# neural-boy/src/models/diffusion/inner_model.py
# Uploaded by sachinkumarsingh ("Upload folder using huggingface_hub"), commit d548197 (verified), 1.71 kB.
from dataclasses import dataclass
from typing import List, Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet
@dataclass()
class InnerModelConfig:
    """Hyper-parameters used to build :class:`InnerModel`.

    Positional construction follows the field order below.
    """

    # Per-frame channel count; InnerModel's output conv produces this many channels.
    img_channels: int
    # Number of conditioning steps; sets conv_in's input width and the action-embedding split.
    num_steps_conditioning: int
    # Width of the conditioning vector fed to the UNet.
    cond_channels: int
    # UNet layout, forwarded unchanged to UNet (presumably one entry per level — TODO confirm).
    depths: List[int]
    # channels[0] is also the width used by InnerModel's input/output convolutions.
    channels: List[int]
    attn_depths: List[bool]
    # Action vocabulary size for nn.Embedding; defaults to None and must be set before InnerModel is built.
    num_actions: Optional[int] = None
class InnerModel(nn.Module):
    """UNet-based network conditioned on past observations, a noise level and actions.

    The noise-level input and the past actions are embedded, summed and
    projected into one conditioning vector of width ``cfg.cond_channels``.
    The past observations and the noisy next observation are concatenated
    along dim 1 and processed by the UNet under that conditioning. The result
    is mapped back to ``cfg.img_channels`` channels.
    """

    def __init__(self, cfg: InnerModelConfig) -> None:
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: Model hyper-parameters. ``cfg.num_actions`` must be set.

        Raises:
            ValueError: If ``cfg.num_actions`` is ``None``, if
                ``cfg.num_steps_conditioning`` is not positive, or if
                ``cfg.cond_channels`` is not divisible by
                ``cfg.num_steps_conditioning``.
        """
        super().__init__()
        # Validate up front. Otherwise these mistakes surface as a TypeError
        # inside nn.Embedding, a ZeroDivisionError, or a shape mismatch deep
        # inside forward().
        if cfg.num_actions is None:
            raise ValueError(
                "InnerModelConfig.num_actions must be set before building InnerModel (got None)"
            )
        if cfg.num_steps_conditioning <= 0:
            raise ValueError(
                f"num_steps_conditioning must be positive, got {cfg.num_steps_conditioning}"
            )
        if cfg.cond_channels % cfg.num_steps_conditioning != 0:
            # Each conditioning step gets cond_channels // num_steps_conditioning
            # embedding features. After flattening they must add up to exactly
            # cond_channels so they can be summed with the noise embedding.
            raise ValueError(
                f"cond_channels ({cfg.cond_channels}) must be divisible by "
                f"num_steps_conditioning ({cfg.num_steps_conditioning})"
            )

        self.noise_emb = FourierFeatures(cfg.cond_channels)
        self.act_emb = nn.Sequential(
            nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning),
            nn.Flatten(),  # b t e -> b (t e)
        )
        self.cond_proj = nn.Sequential(
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
            nn.SiLU(),
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
        )
        # Input channels: num_steps_conditioning past frames plus the noisy next frame.
        self.conv_in = Conv3x3((cfg.num_steps_conditioning + 1) * cfg.img_channels, cfg.channels[0])

        self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)

        self.norm_out = GroupNorm(cfg.channels[0])
        self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels)
        # Only the weight is zeroed. At initialisation the output therefore does
        # not depend on the UNet features (it reduces to the conv bias, if
        # Conv3x3 has one). Presumably this is meant to start training from a
        # near-constant prediction — TODO confirm intent.
        nn.init.zeros_(self.conv_out.weight)

    def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, obs: Tensor, act: Tensor) -> Tensor:
        """Run the network on one batch.

        Args:
            noisy_next_obs: Noisy frame, concatenated after ``obs`` along dim 1
                (assumed to carry ``img_channels`` channels — TODO confirm).
            c_noise: Noise-level input passed to ``FourierFeatures``
                (presumably one value per batch element — TODO confirm).
            obs: Conditioning observations stacked along dim 1. Together with
                ``noisy_next_obs`` they must supply
                ``(num_steps_conditioning + 1) * img_channels`` channels.
            act: Integer action indices of shape ``(b, t)``. ``t`` should equal
                ``num_steps_conditioning`` so that the flattened embedding has
                ``cond_channels`` features.

        Returns:
            Output of ``conv_out``, with ``img_channels`` channels.
        """
        cond = self.cond_proj(self.noise_emb(c_noise) + self.act_emb(act))
        x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1))
        # The UNet returns three values; only the first (the feature map) is used here.
        x, _, _ = self.unet(x, cond)
        x = self.conv_out(F.silu(self.norm_out(x)))
        return x