# dit for video from: https://github.com/world-model-eval/world-model-eval/blob/master/src/world_model_eval/model.py
import torch
from torch import nn
import torch.nn.functional as F
import einops
import math
import functools
from typing import Sequence, Optional, Union, Dict, Tuple
import sys
from enum import Enum


class StrEnum(str, Enum):
    def __str__(self):
        return str(self.value)


class AttentionType(StrEnum):
    SPATIAL = "spatial"
    TEMPORAL = "temporal"


class RotaryType(StrEnum):
    STANDARD = "standard"
    PIXEL = "pixel"


@functools.lru_cache
def rope_nd(
    shape: Sequence[int],
    dim: int = 64,
    base: float = 10_000.0,
    rotary_type: RotaryType = RotaryType.STANDARD,
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Build an N-D rotary embedding table of shape (*shape, dim): `dim` is split
    # evenly across the axes, and each axis contributes a [cos, sin] block.
    D = len(shape)
    assert dim % (2 * D) == 0, (
        f"`dim` must be divisible by 2 × D (got dim={dim}, D={D})"
    )
    dim_per_axis = dim // D
    half = dim_per_axis // 2
    if rotary_type == RotaryType.STANDARD:
        inv_freq = 1.0 / (
            base ** (torch.arange(half, device=device, dtype=dtype) / half)
        )
        coords = [torch.arange(n, device=device, dtype=dtype) for n in shape]
    elif rotary_type == RotaryType.PIXEL:
        inv_freq = (
            torch.linspace(1.0, 256.0 / 2, half, device=device, dtype=dtype) * math.pi
        )
        coords = [
            torch.linspace(-1, +1, steps=n, device=device, dtype=dtype) for n in shape
        ]
    else:
        raise NotImplementedError(f"invalid rotary type: {rotary_type}")
    mesh = torch.meshgrid(*coords, indexing="ij")
    embeddings = []
    for pos in mesh:
        theta = pos.unsqueeze(-1) * inv_freq
        emb_axis = torch.cat([torch.cos(theta), torch.sin(theta)], dim=-1)
        embeddings.append(emb_axis)
    return torch.cat(embeddings, dim=-1)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate consecutive (even, odd) feature pairs: (x1, x2) -> (-x2, x1).
    x = x.view(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def rope_mix(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    sin = torch.repeat_interleave(sin, 2, dim=-1)
    return x * cos + rotate_half(x) * sin


def apply_rope_nd(
    q: torch.Tensor,
    k: torch.Tensor,
    shape: Tuple[int, ...],
    rotary_type: RotaryType,
    *,
    base: float = 10_000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Apply N-D RoPE to q/k laid out as (batch, heads, *shape, head_dim).
    dim = q.shape[-1]
    rope = rope_nd(
        shape, dim, base, rotary_type=rotary_type, dtype=q.dtype, device=q.device
    )
    rope = rope.view(*shape, len(shape), 2, -1)
    cos, sin = rope.unbind(-2)
    cos = cos.reshape(*shape, -1)
    sin = sin.reshape(*shape, -1)
    q_rot = rope_mix(q, cos, sin)
    k_rot = rope_mix(k, cos, sin)
    return q_rot, k_rot
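
# Hedged usage sketch (not part of the original module): `apply_rope_nd` rotates
# per-head queries/keys over an N-D grid, and the head dimension must be divisible
# by 2 * len(shape). The helper name and all values below are illustrative only.
def _example_apply_rope_2d() -> None:
    q = torch.randn(2, 8, 16, 16, 32)  # (batch, heads, H, W, head_dim)
    k = torch.randn(2, 8, 16, 16, 32)
    q_rot, k_rot = apply_rope_nd(q, k, (16, 16), rotary_type=RotaryType.STANDARD)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
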
class FinalLayer(nn.Module):
    def __init__(self, dim: int, patch_size: int, out_channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(dim, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 2, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        _, _, H, W, _ = x.shape
        m = self.adaLN_modulation(c)
        m = einops.repeat(m, "b t d -> b t h w d", h=H, w=W).chunk(2, dim=-1)
        x = self.linear(self.norm(x) * (1 + m[1]) + m[0])
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        is_causal: bool,
        attention_type: AttentionType,
        rotary_type: RotaryType = RotaryType.STANDARD,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.dim = dim
        self.is_causal = is_causal
        self.attention_type = attention_type
        self.rotary_type = rotary_type
        self.qkv_proj = nn.Linear(dim, dim * 3, bias=False)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        B, T, H, W, D = x.shape
        # Factorized attention: spatial attention runs over (H, W) within each
        # frame; temporal attention runs over T for each spatial location.
        if self.attention_type == AttentionType.SPATIAL:
            x = einops.rearrange(x, "b t h w d -> (b t) h w d")
        elif self.attention_type == AttentionType.TEMPORAL:
            x = einops.rearrange(x, "b t h w d -> (b h w) t d")
        else:
            raise NotImplementedError(f"invalid attention type: {self.attention_type}")
        sequence_shape = x.shape[1:-1]
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = einops.rearrange(q, "B ... (head d) -> B head ... d", head=self.num_heads)
        k = einops.rearrange(k, "B ... (head d) -> B head ... d", head=self.num_heads)
        v = einops.rearrange(v, "B ... (head d) -> B head ... d", head=self.num_heads)
        q, k = apply_rope_nd(q, k, sequence_shape, rotary_type=self.rotary_type)
        # Flatten the sequence dimension
        q = einops.rearrange(q, "B head ... d -> B head (...) d")
        k = einops.rearrange(k, "B head ... d -> B head (...) d")
        v = einops.rearrange(v, "B head ... d -> B head (...) d")
        x = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        x = einops.rearrange(x, "B head seq d -> B seq (head d)")
        x = self.out_proj(x)
        if self.attention_type == AttentionType.SPATIAL:
            x = einops.rearrange(x, "(b t) (h w) d -> b t h w d", t=T, h=H, w=W)
        elif self.attention_type == AttentionType.TEMPORAL:
            x = einops.rearrange(x, "(b h w) t d -> b t h w d", h=H, w=W)
        return x


class DiTBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        attention_type: AttentionType,
        rotary_type: RotaryType,
        is_causal: bool,
    ) -> None:
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6, bias=True)
        )
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads,
            is_causal=is_causal,
            attention_type=attention_type,
            rotary_type=rotary_type,
        )
        self.ffwd = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(approximate="tanh"),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        _, _, H, W, _ = x.shape
        m = self.adaLN_modulation(c)
        # m -> (shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp)
        m = einops.repeat(m, "b t d -> b t h w d", h=H, w=W).chunk(6, dim=-1)
        x = x + self.attn(self.norm1(x) * (1 + m[1]) + m[0]) * m[2]
        x = x + self.ffwd(self.norm2(x) * (1 + m[4]) + m[3]) * m[5]
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        rope_config: Optional[Dict[AttentionType, RotaryType]] = None,
        temporal_causal: bool = True,
    ) -> None:
        super().__init__()
        self.s_block = DiTBlock(
            dim,
            num_heads,
            is_causal=False,
            attention_type=AttentionType.SPATIAL,
            rotary_type=rope_config[AttentionType.SPATIAL]
            if rope_config
            else RotaryType.STANDARD,
        )
        self.t_block = DiTBlock(
            dim,
            num_heads,
            is_causal=temporal_causal,
            attention_type=AttentionType.TEMPORAL,
            rotary_type=rope_config[AttentionType.TEMPORAL]
            if rope_config
            else RotaryType.STANDARD,
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        x = self.s_block(x, c)
        x = self.t_block(x, c)
        return x

class ActionEmbedder(nn.Module):
    def __init__(self, action_dim: int, dim: int, compress_rate: int = 4):
        super().__init__()
        self.compress_rate = compress_rate
        self.mlp_in = nn.Sequential(
            nn.Linear(action_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        if compress_rate == 4:
            self.downsample = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1),
            )
        elif compress_rate == 2:
            self.downsample = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1),
            )
        else:
            self.downsample = nn.Identity()
        self.mlp_out = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, action: torch.Tensor) -> torch.Tensor:
        # action: [B, L, action_dim] where L = compress_rate * (T-1) + 1
        x = self.mlp_in(action)  # [B, L, dim]
        if self.compress_rate > 1:
            x = x.permute(0, 2, 1)  # [B, dim, L]
            x = self.downsample(x)  # [B, dim, T]
            x = x.permute(0, 2, 1)  # [B, T, dim]
        x = self.mlp_out(x)  # [B, T, dim]
        return x

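
# Hedged note and sketch (not part of the original file): with compress_rate=4, the
# two stride-2 convolutions above (kernel 3, padding 1) each map a length-n sequence
# to (n - 1) // 2 + 1 tokens, so L = 4*(T-1)+1 = 4T-3 -> 2T-1 -> T, i.e. one action
# embedding per frame. The helper name and values below are illustrative only.
def _example_action_embedder_length() -> None:
    T = 9
    embedder = ActionEmbedder(action_dim=16, dim=64, compress_rate=4)
    actions = torch.randn(2, 4 * (T - 1) + 1, 16)  # (B, L, action_dim)
    assert embedder(actions).shape == (2, T, 64)
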
class DiT(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        dim: int = 1152,
        num_layers: int = 28,
        num_heads: int = 16,
        action_dim: int = 0,
        action_compress_rate: int = 4,
        max_frames: int = 16,
        rope_config: Optional[Dict[AttentionType, RotaryType]] = None,
        action_dropout_prob: float = 0.1,
        temporal_causal: bool = True,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.action_dim = action_dim
        self.action_compress_rate = action_compress_rate
        self.action_dropout_prob = action_dropout_prob
        self.x_proj = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        self.timestep_mlp = nn.Sequential(
            nn.Linear(256, dim, bias=True),
            nn.SiLU(),
            nn.Linear(dim, dim, bias=True),
        )
        self.action_embedder = ActionEmbedder(
            action_dim, dim, compress_rate=action_compress_rate
        )
        self.blocks = nn.ModuleList(
            [
                Block(dim, num_heads, rope_config, temporal_causal=temporal_causal)
                for _ in range(num_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, patch_size, in_channels)
        self.max_frames = max_frames
        self.initialize_weights()

    def timestep_embedding(
        self, t: torch.Tensor, dim: int = 256, max_period: int = 10000
    ) -> torch.Tensor:
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.timestep_mlp[0].weight, std=0.02)
        nn.init.normal_(self.timestep_mlp[2].weight, std=0.02)

        # Initialize action embedder:
        for module in self.action_embedder.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.s_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.s_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.t_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.t_block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        B, T, H, W, C = x.shape
        x = einops.rearrange(x, "b t h w c -> (b t) c h w")
        x = self.x_proj(x)
        x = einops.rearrange(x, "(b t) d h w -> b t h w d", t=T)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        return einops.rearrange(
            x,
            "b h w (p1 p2 c) -> b (h p1) (w p2) c",
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.in_channels,
        )

    def get_null_cond(self, action: torch.Tensor) -> torch.Tensor:
        null_action = torch.zeros_like(action)
        # NOTE: all-zero action is still conditional (meaning "do not move"), so we
        # need to reserve the last component of the action vector to indicate null.
        null_action[..., -1] = 1
        return null_action

    def get_cond(self, t: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        B, T = t.shape
        t = einops.rearrange(t, "b t -> (b t)")
        t_freq = self.timestep_embedding(t)
        c = self.timestep_mlp(t_freq)
        c = einops.rearrange(c, "(b t) d -> b t d", t=T)
        if self.training and self.action_dropout_prob > 0:
            # Randomly replace a sample's whole action sequence with the null action
            # (classifier-free-guidance-style conditioning dropout).
            should_drop = (
                torch.rand((B, 1, 1), device=action.device) < self.action_dropout_prob
            )
            null_action = self.get_null_cond(action)
            action = torch.where(should_drop, null_action, action)
        c += self.action_embedder(action)
        return c

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        B, T, H, W, C = x.shape
        x = self.patchify(x)
        c = self.get_cond(t, action)
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = einops.rearrange(x, "b t h w d -> (b t) h w d")
        x = self.unpatchify(x)
        x = einops.rearrange(x, "(b t) h w c -> b t h w c", t=T)
        return x


if __name__ == "__main__":
    # Test DiT instantiation and forward pass
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure RoPE for both spatial and temporal attention
    rope_config = {
        AttentionType.SPATIAL: RotaryType.STANDARD,
        AttentionType.TEMPORAL: RotaryType.STANDARD,
    }

    # Initialize a small DiT model for testing (bidirectional temporal attention)
    model = DiT(
        in_channels=4,  # e.g., latent channels
        patch_size=2,
        dim=256,  # hidden dimension
        num_layers=4,
        num_heads=8,
        action_dim=16,
        max_frames=16,
        rope_config=rope_config,
        temporal_causal=False,  # Test bidirectional temporal attention
    ).to(device)

    # Dummy inputs: (B, T, H, W, C)
    B, T, H, W, C = 2, 9, 32, 32, 4
    x = torch.randn(B, T, H, W, C).to(device)
    t = torch.randint(0, 1000, (B, T)).to(device)
    # Action shape should be (B, 4*(T-1)+1, action_dim) for compress_rate=4
    L = 4 * (T - 1) + 1
    action = torch.randn(B, L, 16).to(device)

    print(f"Running forward pass on device: {device}...")
    output = model(x, t, action)
    print(f"Input shape: {x.shape}")
    print(f"Timestep shape: {t.shape}")
    print(f"Action shape: {action.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape, "Output shape mismatch!"
    print("Forward pass successful!")
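
    # Hedged sketch (not part of the original test): since training drops actions to
    # the null action from get_null_cond(), one plausible way to use the model at
    # sampling time is classifier-free guidance, blending conditional and
    # unconditional outputs. `guidance_scale` is illustrative; how the blended
    # output feeds into a sampler is outside the scope of this file.
    model.eval()
    guidance_scale = 2.0
    with torch.no_grad():
        out_cond = model(x, t, action)
        out_uncond = model(x, t, model.get_null_cond(action))
    out_guided = out_uncond + guidance_scale * (out_cond - out_uncond)
    print(f"CFG-blended output shape: {out_guided.shape}")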