# PIWM / src/models/diffusion/inner_model.py
# Last commit by musictimer: c64c726 "Initial Diamond CSGO AI deployment"
# (file-viewer residue: raw | history | blame | 2.36 kB)
from dataclasses import dataclass
from typing import List, Optional
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet
@dataclass()
class InnerModelConfig:
    """Hyper-parameters consumed by ``InnerModel``.

    The first six fields must be supplied at construction time. The last two
    default to ``None`` and are filled in later by other components (see the
    per-field notes below).
    """

    # Channels of a single image frame (also the output channel count of InnerModel).
    img_channels: int
    # Number of conditioning steps; InnerModel splits cond_channels across this
    # many per-step action embeddings and stacks this many frames in its input.
    num_steps_conditioning: int
    # Width of the conditioning vector fed to the UNet.
    cond_channels: int
    # UNet layout, forwarded unchanged to UNet (presumably one entry per level).
    depths: List[int]
    # channels[0] is also the width of InnerModel's input/output convolutions.
    channels: List[int]
    attn_depths: List[bool]
    # Filled in by the trainer once the environment has been created.
    num_actions: Optional[int] = None
    # Filled in by the Denoiser.
    is_upsampler: Optional[bool] = None
class InnerModel(nn.Module):
    """Conditioned UNet that maps a noisy next observation to an output frame.

    The conditioning vector is the sum of Fourier embeddings of the two noise
    inputs (``c_noise`` and ``c_noise_cond``) and, unless the config marks the
    model as an upsampler, an embedding of the actions. It is then passed
    through a small MLP (``cond_proj``). The image input is ``obs`` and
    ``noisy_next_obs`` concatenated along dim 1 (assumed to be the channel
    axis of NCHW images; TODO confirm against the caller).

    Args:
        cfg: Model hyper-parameters. ``cfg.num_actions`` is required unless
            ``cfg.is_upsampler`` is truthy. ``is_upsampler=None`` is treated
            as ``False``.

    Raises:
        ValueError: if the model embeds actions and ``cfg.cond_channels`` is
            not divisible by ``cfg.num_steps_conditioning``. Otherwise the
            flattened per-step action embeddings would not match the width of
            the noise embeddings.
    """

    def __init__(self, cfg: InnerModelConfig) -> None:
        super().__init__()
        # ``is_upsampler`` defaults to None in the config. Treat None as False
        # consistently, both for the action embedding and when sizing conv_in
        # (``int(None)`` would raise TypeError).
        is_upsampler = bool(cfg.is_upsampler)

        # NOTE: keep the submodule registration order unchanged (noise_emb,
        # noise_cond_emb, act_emb, cond_proj, conv_in, unet, norm_out,
        # conv_out). parameters() follows it, and so does init RNG usage.
        self.noise_emb = FourierFeatures(cfg.cond_channels)
        self.noise_cond_emb = FourierFeatures(cfg.cond_channels)

        if is_upsampler:
            self.act_emb = None
        else:
            if cfg.cond_channels % cfg.num_steps_conditioning != 0:
                raise ValueError(
                    f"cond_channels ({cfg.cond_channels}) must be divisible by "
                    f"num_steps_conditioning ({cfg.num_steps_conditioning})"
                )
            # One embedding of width cond_channels / T per step. Flattening the
            # T per-step embeddings yields a vector of width cond_channels.
            self.act_emb = nn.Sequential(
                nn.Embedding(cfg.num_actions, cfg.cond_channels // cfg.num_steps_conditioning),
                nn.Flatten(),  # b t e -> b (t e)
            )

        self.cond_proj = nn.Sequential(
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
            nn.SiLU(),
            nn.Linear(cfg.cond_channels, cfg.cond_channels),
        )
        # Input channels: obs and noisy_next_obs concatenated on dim 1.
        # That is (num_steps_conditioning + 1) frames, plus one extra frame
        # for the upsampler (presumably the low-resolution frame; TODO confirm).
        self.conv_in = Conv3x3(
            (cfg.num_steps_conditioning + int(is_upsampler) + 1) * cfg.img_channels,
            cfg.channels[0],
        )
        self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
        self.norm_out = GroupNorm(cfg.channels[0])
        self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels)
        # Only the output conv's weight is zeroed (its bias, if any, is left as-is).
        nn.init.zeros_(self.conv_out.weight)

    def forward(
        self,
        noisy_next_obs: Tensor,
        c_noise: Tensor,
        c_noise_cond: Tensor,
        obs: Tensor,
        act: Optional[Tensor],
    ) -> Tensor:
        """Run the conditioned UNet on ``obs`` stacked with ``noisy_next_obs``.

        Args:
            noisy_next_obs: Noised target frame(s), concatenated after ``obs``
                on dim 1.
            c_noise: Input to ``noise_emb``.
            c_noise_cond: Input to ``noise_cond_emb``.
            obs: Conditioning frames, concatenated before ``noisy_next_obs``
                on dim 1.
            act: Actions. Either a 2-D index tensor or a 3-D 0/1 multi-hot
                tensor (see ``_embed_actions``). Must be ``None`` for an
                upsampler.

        Returns:
            The output of ``conv_out``, which has ``cfg.img_channels`` channels.
        """
        act_emb = self._embed_actions(act)
        cond = self.cond_proj(self.noise_emb(c_noise) + self.noise_cond_emb(c_noise_cond) + act_emb)
        x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1))
        x, _, _ = self.unet(x, cond)  # the UNet's two extra outputs are unused here
        x = self.conv_out(F.silu(self.norm_out(x)))
        return x

    def _embed_actions(self, act: Optional[Tensor]):
        """Return the action term added to the conditioning vector.

        ``act`` may be either:
          * a 2-D integer index tensor ``(b, t)``, looked up in the table; or
          * a 3-D multi-hot tensor ``(b, t, num_actions)`` with values in
            {0, 1}. Each step's embedding is then the sum of the table rows of
            its active actions.

        For an upsampler (no action embedding) ``act`` must be ``None``. The
        scalar ``0`` is returned so it broadcasts harmlessly in the sum.
        """
        if self.act_emb is None:
            assert act is None, "upsampler models take no actions"
            return 0
        table = self.act_emb[0]
        if act.ndim == 2:
            return self.act_emb(act)
        assert (
            act.ndim == 3
            and act.size(2) == table.num_embeddings
            and set(act.unique().tolist()) <= {0, 1}
        ), "act must be (b, t) indices or a (b, t, num_actions) 0/1 multi-hot tensor"
        # Multi-hot @ table sums the embedding rows of all active actions.
        # Cast to the table's dtype (instead of always float32) so this also
        # works when the weights are float16/bfloat16.
        return self.act_emb[1](act.to(table.weight.dtype) @ table.weight)