Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| import torch | |
| from torch import Tensor | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet | |
@dataclass
class InnerModelConfig:
    """Hyper-parameters of the :class:`InnerModel` denoising network.

    Fields without a default must be supplied at construction time; the two
    optional fields are filled in later by other components.
    """

    img_channels: int  # channels per image frame; also the width of the network output
    num_steps_conditioning: int  # number of conditioning steps (divides cond_channels for the action embedding)
    cond_channels: int  # width of the conditioning vector passed to the UNet
    depths: List[int]  # per-level depths, forwarded to UNet
    channels: List[int]  # per-level channel widths, forwarded to UNet; channels[0] is the stem/output width
    attn_depths: List[bool]  # per-level attention flags, forwarded to UNet
    num_actions: Optional[int] = None  # set by trainer after env creation
    is_upsampler: Optional[bool] = None  # set by Denoiser; None is treated as False by InnerModel
class InnerModel(nn.Module):
    """UNet-based network conditioned on noise levels, past frames and actions.

    The conditioning frames ``obs`` and the noisy target ``noisy_next_obs`` are
    concatenated along dim 1 and passed through a 3x3 stem convolution, a UNet
    conditioned on a projected embedding vector, and a 3x3 output convolution
    with ``cfg.img_channels`` output channels.
    """

    def __init__(self, cfg: InnerModelConfig) -> None:
        super().__init__()
        # `is_upsampler` defaults to None in the config; normalise it once so
        # the action-embedding switch and the input-channel count agree
        # (`int(None)` would raise TypeError).
        is_upsampler = bool(cfg.is_upsampler)

        self.noise_emb = FourierFeatures(cfg.cond_channels)
        self.noise_cond_emb = FourierFeatures(cfg.cond_channels)
        # Each conditioning step gets cond_channels // num_steps_conditioning
        # features; flattening concatenates the steps into one vector.
        self.act_emb = None if is_upsampler else nn.Sequential(
            nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning),
            nn.Flatten(),  # b t e -> b (t e)
        )
        self.cond_proj = nn.Sequential(
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
            nn.SiLU(),
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
        )
        # Input = conditioning frames (+1 extra frame for an upsampler) + the
        # noisy target frame, each with img_channels channels.
        # NOTE(review): assumes obs carries exactly
        # (num_steps_conditioning + is_upsampler) * img_channels channels — confirm with caller.
        in_frames = cfg.num_steps_conditioning + int(is_upsampler) + 1
        self.conv_in = Conv3x3(in_frames * cfg.img_channels, cfg.channels[0])
        self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
        self.norm_out = GroupNorm(cfg.channels[0])
        self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels)
        # Zero-initialise the output projection weights (presumably so the
        # untrained network starts with a near-constant output).
        nn.init.zeros_(self.conv_out.weight)

    def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, c_noise_cond: Tensor, obs: Tensor, act: Optional[Tensor]) -> Tensor:
        """Run the network once.

        Args:
            noisy_next_obs: noisy target frame; concatenated after ``obs`` along dim 1.
            c_noise: noise-level conditioning, embedded with ``noise_emb``.
            c_noise_cond: second noise-level conditioning, embedded with ``noise_cond_emb``.
            obs: conditioning frames; concatenated before ``noisy_next_obs`` along dim 1.
            act: for a non-upsampler, either integer action indices of shape
                (b, t) or a 0/1 multi-hot tensor of shape (b, t, num_actions);
                must be ``None`` for an upsampler.

        Returns:
            The output of ``conv_out`` with ``cfg.img_channels`` channels.
        """
        if self.act_emb is not None:
            embedding, flatten = self.act_emb[0], self.act_emb[1]
            assert act is not None and (
                act.ndim == 2
                or (
                    act.ndim == 3
                    and act.size(2) == embedding.num_embeddings
                    and set(act.unique().tolist()).issubset({0, 1})
                )
            )
            if act.ndim == 2:
                act_emb = self.act_emb(act)
            else:
                # Multi-hot actions: the matmul sums the embedding rows of all
                # active actions per step. Cast to the table's dtype so this
                # also works when the model runs in half / bfloat16.
                act_emb = flatten(act.to(embedding.weight.dtype) @ embedding.weight)
        else:
            assert act is None
            act_emb = 0
        # NOTE(review): the sum requires act_emb to have cond_channels features,
        # i.e. t * (cond_channels // num_steps_conditioning) == cond_channels.
        cond = self.cond_proj(self.noise_emb(c_noise) + self.noise_cond_emb(c_noise_cond) + act_emb)
        x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1))
        x, _, _ = self.unet(x, cond)  # only the first of UNet's three outputs is used
        x = self.conv_out(F.silu(self.norm_out(x)))
        return x