File size: 1,706 Bytes
d548197 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | from dataclasses import dataclass
from typing import List, Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet
@dataclass
class InnerModelConfig:
    """Hyper-parameters used to build an ``InnerModel``."""

    # Channels per observation frame; conv_in takes (num_steps_conditioning + 1) * img_channels
    # input channels and conv_out emits img_channels.
    img_channels: int
    # Number of conditioning steps; also divides cond_channels into the per-step action
    # embedding width (cond_channels // num_steps_conditioning).
    num_steps_conditioning: int
    # Width of the conditioning vector (noise embedding, cond_proj, and the UNet's cond input).
    cond_channels: int
    # UNet layout, passed straight through to UNet; channels[0] is also the width after conv_in.
    # NOTE(review): presumably one entry per UNet level in each list — confirm against UNet.
    depths: List[int]
    channels: List[int]
    attn_depths: List[bool]
    # Size of the action vocabulary for nn.Embedding. It defaults to None but must be set to
    # an int before InnerModel is constructed.
    num_actions: Optional[int] = None
class InnerModel(nn.Module):
    """Action-conditioned UNet applied to a noisy next observation.

    The network input is the channel-wise concatenation of the conditioning
    observations and the noisy next observation. A conditioning vector built
    from the noise level (``c_noise``) and the action history (``act``) is
    passed to the UNet alongside the feature map.
    """

    def __init__(self, cfg: InnerModelConfig) -> None:
        """Build the model from ``cfg``.

        Args:
            cfg: Model hyper-parameters. ``cfg.num_actions`` must be set.

        Raises:
            ValueError: If ``cfg.num_actions`` is ``None``, if
                ``cfg.num_steps_conditioning`` is not positive, or if
                ``cfg.cond_channels`` is not divisible by
                ``cfg.num_steps_conditioning``.
        """
        super().__init__()

        # num_actions defaults to None in the config, but nn.Embedding needs an int.
        # Fail early with a clear message instead of an opaque error from torch.
        if cfg.num_actions is None:
            raise ValueError("InnerModelConfig.num_actions must be set before building InnerModel")
        if cfg.num_steps_conditioning < 1:
            raise ValueError(
                f"num_steps_conditioning must be >= 1, got {cfg.num_steps_conditioning}"
            )
        # Each conditioning step gets cond_channels // num_steps_conditioning embedding dims.
        # After flattening, the steps must add up to exactly cond_channels, because the result
        # is summed with the noise embedding and fed to cond_proj (whose first Linear expects
        # cond_channels features). Integer division would otherwise truncate silently, and the
        # mismatch would only surface later as a shape error in forward().
        # NOTE(review): assumes act carries num_steps_conditioning actions per sample.
        if cfg.cond_channels % cfg.num_steps_conditioning != 0:
            raise ValueError(
                f"cond_channels ({cfg.cond_channels}) must be divisible by "
                f"num_steps_conditioning ({cfg.num_steps_conditioning})"
            )

        self.noise_emb = FourierFeatures(cfg.cond_channels)
        self.act_emb = nn.Sequential(
            nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning),
            nn.Flatten(),  # b t e -> b (t e)
        )
        self.cond_proj = nn.Sequential(
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
            nn.SiLU(),
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
        )
        # Input holds the conditioning observations plus the noisy next observation,
        # each contributing img_channels channels.
        self.conv_in = Conv3x3((cfg.num_steps_conditioning + 1) * cfg.img_channels, cfg.channels[0])

        self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)

        self.norm_out = GroupNorm(cfg.channels[0])
        self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels)
        # Zero the output conv weights so that, at initialization, the output does not
        # depend on the input. NOTE(review): assumes Conv3x3 behaves like a standard conv
        # layer exposing `.weight`.
        nn.init.zeros_(self.conv_out.weight)

    def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, obs: Tensor, act: Tensor) -> Tensor:
        """Run the network once.

        Args:
            noisy_next_obs: Noisy next observation; concatenated after ``obs`` on dim 1.
            c_noise: Noise-level input, embedded by ``FourierFeatures``.
            obs: Conditioning observations. Together with ``noisy_next_obs`` it must supply
                ``(num_steps_conditioning + 1) * img_channels`` channels on dim 1 to match
                ``conv_in``.
            act: Integer action indices for ``nn.Embedding``. Presumably shaped
                (batch, num_steps_conditioning) — TODO confirm against callers.

        Returns:
            The ``conv_out`` output, which has ``img_channels`` channels.
        """
        # Conditioning vector: noise-level embedding plus the flattened per-step action embeddings.
        cond = self.cond_proj(self.noise_emb(c_noise) + self.act_emb(act))
        x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1))
        # UNet returns a 3-tuple; only the feature map is used here.
        x, _, _ = self.unet(x, cond)
        x = self.conv_out(F.silu(self.norm_out(x)))
        return x
|