| | |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | import einops |
| | import math |
| | import functools |
| | from typing import Sequence, Optional, Union, Dict, Tuple |
| | import sys |
| | from enum import Enum |
| |
|
class StrEnum(str, Enum):
    """Enum base class whose members are also ``str`` instances.

    ``str(member)`` returns the member's value instead of the
    ``ClassName.MEMBER`` text produced by the stock ``Enum.__str__``.
    """

    def __str__(self):
        # Render the member as its underlying value.
        member_value = self.value
        return str(member_value)
| |
|
| |
|
class AttentionType(StrEnum):
    """Selects which axes an ``Attention`` layer treats as its token sequence."""

    # Tokens are the (h, w) grid of a single frame; batch and time are folded together.
    SPATIAL = "spatial"
    # Tokens are the frames at one grid location; batch, h and w are folded together.
    TEMPORAL = "temporal"
| |
|
| |
|
class RotaryType(StrEnum):
    """Rotary position-embedding variants understood by ``rope_nd``."""

    # Integer positions 0..n-1 with geometrically spaced inverse frequencies.
    STANDARD = "standard"
    # Positions spread linearly over [-1, 1] with linearly spaced frequencies times pi.
    PIXEL = "pixel"
| |
|
| |
|
@functools.lru_cache
def rope_nd(
    shape: Sequence[int],
    dim: int = 64,
    base: float = 10_000.0,
    rotary_type: RotaryType = RotaryType.STANDARD,
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build the cos/sin table of an N-dimensional rotary embedding.

    The ``dim`` channels are split evenly between the ``len(shape)`` axes.
    Every axis contributes ``dim // len(shape)`` channels laid out as
    ``[cos(theta) | sin(theta)]``, where ``theta`` is the outer product of the
    axis coordinate and the inverse-frequency vector.

    Results are memoized with ``functools.lru_cache``: all arguments must be
    hashable, and the returned tensor is shared between callers, so it should
    not be modified in place.

    Args:
        shape: Size of the position grid along each axis.
        dim: Total number of embedding channels; must be divisible by
            ``2 * len(shape)``.
        base: Frequency base used by ``RotaryType.STANDARD``.
        rotary_type: Which coordinate/frequency scheme to use.
        dtype: Dtype of the generated table.
        device: Device of the generated table.

    Returns:
        Tensor of shape ``(*shape, dim)``.

    Raises:
        AssertionError: If ``dim`` is not divisible by ``2 * len(shape)``.
        NotImplementedError: If ``rotary_type`` is not a known variant.
    """
    D = len(shape)
    assert dim % (2 * D) == 0, (
        f"`dim` must be divisible by 2 × D (got dim={dim}, D={D})"
    )

    factory = {"device": device, "dtype": dtype}
    # Number of distinct rotation frequencies available to each axis.
    half = (dim // D) // 2

    if rotary_type == RotaryType.STANDARD:
        exponents = torch.arange(half, **factory) / half
        inv_freq = 1.0 / (base ** exponents)
        coords = [torch.arange(n, **factory) for n in shape]
    elif rotary_type == RotaryType.PIXEL:
        inv_freq = torch.linspace(1.0, 256.0 / 2, half, **factory) * math.pi
        coords = [torch.linspace(-1, +1, steps=n, **factory) for n in shape]
    else:
        raise NotImplementedError(f"invalid rotary type: {rotary_type}")

    # One coordinate tensor per axis, each already broadcast to the full grid.
    grids = torch.meshgrid(*coords, indexing="ij")
    angles_per_axis = (grid.unsqueeze(-1) * inv_freq for grid in grids)
    per_axis = [
        torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        for angles in angles_per_axis
    ]
    return torch.cat(per_axis, dim=-1)
| |
|
| |
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map every adjacent channel pair ``(a, b)`` of the last axis to ``(-b, a)``.

    Args:
        x: Tensor whose last dimension has even length.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Group the last axis into (pair index, position within pair).
    pairs = x.view(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    swapped = torch.stack([-second, first], dim=-1)
    # Merge the pair axis back into a flat channel axis.
    return swapped.flatten(start_dim=-2)
| |
|
| |
|
def rope_mix(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the channel pairs of ``x`` by the angles encoded in ``cos``/``sin``.

    Args:
        x: Tensor to rotate; its last dimension is treated as adjacent pairs.
        cos: Cosine table with one entry per channel pair along the last axis.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        ``x * cos + rotate_half(x) * sin`` with the tables expanded so that
        both channels of a pair share one angle.
    """
    # Duplicate each table entry so it lines up with both members of its pair.
    cos_per_channel = cos.repeat_interleave(2, dim=-1)
    sin_per_channel = sin.repeat_interleave(2, dim=-1)
    quarter_turn = rotate_half(x)
    return x * cos_per_channel + quarter_turn * sin_per_channel
| |
|
| |
|
def apply_rope_nd(
    q: torch.Tensor,
    k: torch.Tensor,
    shape: Tuple[int, ...],
    rotary_type: RotaryType,
    *,
    base: float = 10_000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply N-dimensional rotary position embeddings to queries and keys.

    Args:
        q: Query tensor whose trailing dimensions are ``(*shape, head_dim)``.
        k: Key tensor with the same trailing layout as ``q``.
        shape: Size of the position grid along each axis. Any sequence of
            ints is accepted; it is converted to a tuple before use.
        rotary_type: Rotary variant forwarded to ``rope_nd``.
        base: Frequency base forwarded to ``rope_nd``.

    Returns:
        The rotated ``(q, k)`` pair, with the same shapes as the inputs.
    """
    # `rope_nd` is memoized with functools.lru_cache, which needs hashable
    # arguments. Normalizing here lets callers pass a list without a TypeError.
    shape = tuple(shape)
    head_dim = q.shape[-1]
    table = rope_nd(
        shape, head_dim, base, rotary_type=rotary_type, dtype=q.dtype, device=q.device
    )
    # The table stores, for each axis, a [cos | sin] pair of equal halves;
    # expose that as (*shape, axis, cos/sin, freq) and split on cos/sin.
    table = table.view(*shape, len(shape), 2, -1)
    cos, sin = table.unbind(-2)
    # Concatenate the per-axis frequencies back into one channel axis.
    cos = cos.reshape(*shape, -1)
    sin = sin.reshape(*shape, -1)

    return rope_mix(q, cos, sin), rope_mix(k, cos, sin)
| |
|
| |
|
class FinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    The projection emits ``patch_size * patch_size * out_channels`` values per
    token.
    """

    def __init__(self, dim: int, patch_size: int, out_channels: int) -> None:
        super().__init__()
        values_per_token = patch_size * patch_size * out_channels
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(dim, values_per_token, bias=True)
        # Produces a shift and a scale (``dim`` channels each) from the conditioning.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 2 * dim, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Modulate and project tokens.

        Args:
            x: Tokens of shape ``(b, t, h, w, dim)``.
            c: Conditioning of shape ``(b, t, dim)``.

        Returns:
            Tensor of shape ``(b, t, h, w, patch_size**2 * out_channels)``.
        """
        _, _, height, width, _ = x.shape
        modulation = self.adaLN_modulation(c)
        # Share the per-frame modulation with every spatial location.
        modulation = einops.repeat(
            modulation, "b t d -> b t h w d", h=height, w=width
        )
        halves = modulation.chunk(2, dim=-1)
        shift, scale = halves[0], halves[1]
        return self.linear(self.norm(x) * (1 + scale) + shift)
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings over space or time.

    For ``AttentionType.SPATIAL`` the token sequence is the ``(h, w)`` grid of
    each frame; for ``AttentionType.TEMPORAL`` it is the ``t`` axis at each
    grid location. All remaining axes are folded into the batch dimension.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        is_causal: bool,
        attention_type: AttentionType,
        rotary_type: RotaryType = RotaryType.STANDARD,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.attention_type = attention_type
        self.rotary_type = rotary_type
        # Fused projection producing queries, keys and values in one matmul.
        self.qkv_proj = nn.Linear(dim, dim * 3, bias=False)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        """Attend over the configured axes.

        Args:
            x: Tokens of shape ``(b, t, h, w, dim)``.

        Returns:
            Tensor of shape ``(b, t, h, w, dim)``.

        Raises:
            NotImplementedError: If ``attention_type`` is not a known variant.
        """
        _, T, H, W, _ = x.shape

        is_spatial = self.attention_type == AttentionType.SPATIAL
        is_temporal = (not is_spatial) and self.attention_type == AttentionType.TEMPORAL
        if not (is_spatial or is_temporal):
            raise NotImplementedError(f"invalid attention type: {self.attention_type}")

        # Fold every axis that is not attended over into the batch dimension.
        if is_spatial:
            tokens = einops.rearrange(x, "b t h w d -> (b t) h w d")
        else:
            tokens = einops.rearrange(x, "b t h w d -> (b h w) t d")
        # Position-grid shape: (h, w) for spatial, (t,) for temporal.
        grid_shape = tokens.shape[1:-1]

        split_heads = "B ... (head d) -> B head ... d"
        q, k, v = (
            einops.rearrange(part, split_heads, head=self.num_heads)
            for part in self.qkv_proj(tokens).chunk(3, dim=-1)
        )

        # Rotary embeddings are applied while the grid axes are still separate.
        q, k = apply_rope_nd(q, k, grid_shape, rotary_type=self.rotary_type)

        flatten_grid = "B head ... d -> B head (...) d"
        q, k, v = (einops.rearrange(part, flatten_grid) for part in (q, k, v))

        out = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        out = einops.rearrange(out, "B head seq d -> B seq (head d)")
        out = self.out_proj(out)

        # Undo the batch folding performed above.
        if is_spatial:
            return einops.rearrange(out, "(b t) (h w) d -> b t h w d", t=T, h=H, w=W)
        return einops.rearrange(out, "(b h w) t d -> b t h w d", h=H, w=W)
| |
|
| |
|
class DiTBlock(nn.Module):
    """Transformer block with adaLN conditioning: gated attention then gated MLP.

    The conditioning is projected to six ``dim``-sized chunks: shift, scale and
    gate for the attention branch, followed by shift, scale and gate for the
    feed-forward branch.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        attention_type: AttentionType,
        rotary_type: RotaryType,
        is_causal: bool,
    ) -> None:
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )
        # The norms carry no learnable affine; scale and shift come from adaLN.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads,
            is_causal=is_causal,
            attention_type=attention_type,
            rotary_type=rotary_type,
        )
        hidden = dim * 4
        self.ffwd = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Args:
            x: Tokens of shape ``(b, t, h, w, dim)``.
            c: Conditioning of shape ``(b, t, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, _, height, width, _ = x.shape
        modulation = self.adaLN_modulation(c)
        # Share the per-frame modulation with every spatial location.
        modulation = einops.repeat(
            modulation, "b t d -> b t h w d", h=height, w=width
        )
        (
            attn_shift,
            attn_scale,
            attn_gate,
            mlp_shift,
            mlp_scale,
            mlp_gate,
        ) = modulation.chunk(6, dim=-1)

        attn_in = self.norm1(x) * (1 + attn_scale) + attn_shift
        x = x + self.attn(attn_in) * attn_gate

        mlp_in = self.norm2(x) * (1 + mlp_scale) + mlp_shift
        x = x + self.ffwd(mlp_in) * mlp_gate
        return x
| |
|
| |
|
class Block(nn.Module):
    """A spatial ``DiTBlock`` followed by a temporal ``DiTBlock``.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads in both sub-blocks.
        rope_config: Optional mapping from attention type to rotary variant.
            Any attention type missing from the mapping (or the whole mapping
            being ``None``/empty) falls back to ``RotaryType.STANDARD``.
        temporal_causal: Whether the temporal sub-block uses causal attention.
            The spatial sub-block is never causal.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        rope_config: Optional[Dict[AttentionType, RotaryType]] = None,
        temporal_causal: bool = True,
    ) -> None:
        super().__init__()
        # Resolve each key independently so a partial config (only one of the
        # two attention types) no longer raises KeyError.
        config = rope_config or {}
        spatial_rotary = config.get(AttentionType.SPATIAL, RotaryType.STANDARD)
        temporal_rotary = config.get(AttentionType.TEMPORAL, RotaryType.STANDARD)

        self.s_block = DiTBlock(
            dim,
            num_heads,
            is_causal=False,
            attention_type=AttentionType.SPATIAL,
            rotary_type=spatial_rotary,
        )
        self.t_block = DiTBlock(
            dim,
            num_heads,
            is_causal=temporal_causal,
            attention_type=AttentionType.TEMPORAL,
            rotary_type=temporal_rotary,
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Apply spatial attention, then temporal attention, with conditioning ``c``."""
        x = self.s_block(x, c)
        x = self.t_block(x, c)
        return x
| |
|
| |
|
class ActionEmbedder(nn.Module):
    """Embed an action sequence and optionally compress it along time.

    Pipeline: per-step MLP -> strided 1-D convolutions over time -> output MLP.

    Args:
        action_dim: Number of features per action step.
        dim: Embedding width.
        compress_rate: Temporal compression factor. A power of two ``2**k``
            (``k >= 1``) builds ``k`` stride-2 convolutions separated by SiLU,
            so 2 and 4 produce the same modules as before and 8, 16, ... are
            now supported. Any other value leaves the sequence length
            unchanged (``nn.Identity``).
    """

    def __init__(self, action_dim: int, dim: int, compress_rate: int = 4):
        super().__init__()
        self.compress_rate = compress_rate
        self.mlp_in = nn.Sequential(
            nn.Linear(action_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.downsample = self._build_downsample(dim, compress_rate)
        self.mlp_out = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    @staticmethod
    def _num_halvings(compress_rate) -> int:
        """Return ``k`` when ``compress_rate == 2**k`` with ``k >= 1``, else 0."""
        remaining, halvings = compress_rate, 0
        while remaining > 1 and remaining % 2 == 0:
            remaining //= 2
            halvings += 1
        return halvings if remaining == 1 else 0

    @staticmethod
    def _build_downsample(dim: int, compress_rate) -> nn.Module:
        """Stack one stride-2 ``Conv1d`` per factor of two in ``compress_rate``."""
        stages = ActionEmbedder._num_halvings(compress_rate)
        if stages == 0:
            return nn.Identity()
        layers = []
        for stage in range(stages):
            if stage > 0:
                # Non-linearity between consecutive convolutions only.
                layers.append(nn.SiLU())
            layers.append(nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)

    def forward(self, action: torch.Tensor) -> torch.Tensor:
        """Embed ``action``.

        Args:
            action: Tensor of shape ``(batch, length, action_dim)``.

        Returns:
            Tensor of shape ``(batch, length_out, dim)``. Each stride-2
            convolution maps a length ``n`` to ``(n - 1) // 2 + 1``.
        """
        x = self.mlp_in(action)

        if self.compress_rate > 1:
            # Conv1d expects channels before length.
            x = x.permute(0, 2, 1)
            x = self.downsample(x)
            x = x.permute(0, 2, 1)

        x = self.mlp_out(x)
        return x
| |
|
| |
|
class DiT(nn.Module):
    """Diffusion transformer over video latents, conditioned on timesteps and actions.

    Frames are patchified with a strided convolution, processed by a stack of
    spatial/temporal ``Block`` modules modulated by a per-frame conditioning
    vector, and projected back to pixel patches by ``FinalLayer``.
    """

    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        dim: int = 1152,
        num_layers: int = 28,
        num_heads: int = 16,
        action_dim: int = 0,
        action_compress_rate: int = 4,
        max_frames: int = 16,
        rope_config: Optional[Dict[AttentionType, RotaryType]] = None,
        action_dropout_prob: float = 0.1,
        temporal_causal: bool = True,
    ) -> None:
        super().__init__()
        # Plain hyper-parameters kept for later use by the helper methods.
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.action_dim = action_dim
        self.action_compress_rate = action_compress_rate
        self.action_dropout_prob = action_dropout_prob
        self.max_frames = max_frames

        # Non-overlapping patch embedding: kernel and stride both equal patch_size.
        self.x_proj = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        # Input width 256 matches the default `dim` of `timestep_embedding`.
        self.timestep_mlp = nn.Sequential(
            nn.Linear(256, dim, bias=True),
            nn.SiLU(),
            nn.Linear(dim, dim, bias=True),
        )
        self.action_embedder = ActionEmbedder(
            action_dim, dim, compress_rate=action_compress_rate
        )
        self.blocks = nn.ModuleList(
            [
                Block(dim, num_heads, rope_config, temporal_causal=temporal_causal)
                for _ in range(num_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, patch_size, in_channels)
        self.initialize_weights()

    def timestep_embedding(
        self, t: torch.Tensor, dim: int = 256, max_period: int = 10000
    ) -> torch.Tensor:
        """Return sinusoidal embeddings of the 1-D timestep tensor ``t``.

        The result has shape ``(len(t), dim)`` and is laid out as
        ``[cos | sin]``, with one trailing zero column when ``dim`` is odd.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        freqs = torch.exp(-math.log(max_period) * steps / half)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            zero_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_column], dim=-1)
        return embedding

    def initialize_weights(self) -> None:
        """Initialize all parameters.

        Order matters because later steps overwrite earlier ones: Xavier for
        every linear layer, then the patch embedding, the timestep MLP, the
        action embedder, and finally zeros for every adaLN projection and for
        the output projection.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Treat the conv kernel as a (dim, fan_in) matrix for Xavier init.
        patch_weight = self.x_proj.weight.data
        nn.init.xavier_uniform_(patch_weight.view([patch_weight.shape[0], -1]))
        nn.init.constant_(self.x_proj.bias, 0)

        for index in (0, 2):
            nn.init.normal_(self.timestep_mlp[index].weight, std=0.02)

        for module in self.action_embedder.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )

        # Zero adaLN projections: every modulation chunk (shift, scale, gate) starts at 0.
        for block in self.blocks:
            for sub_block in (block.s_block, block.t_block):
                projection = sub_block.adaLN_modulation[-1]
                nn.init.constant_(projection.weight, 0)
                nn.init.constant_(projection.bias, 0)

        # Zero the output head as well, so the initial prediction is all zeros.
        for layer in (self.final_layer.adaLN_modulation[-1], self.final_layer.linear):
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``(b, t, h, w, c)`` frames into ``(b, t, h', w', dim)`` patch tokens."""
        _, T, _, _, _ = x.shape
        frames = einops.rearrange(x, "b t h w c -> (b t) c h w")
        tokens = self.x_proj(frames)
        return einops.rearrange(tokens, "(b t) d h w -> b t h w d", t=T)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Expand ``(b, h, w, p*p*c)`` patch predictions to ``(b, h*p, w*p, c)`` images."""
        patch = self.patch_size
        return einops.rearrange(
            x,
            "b h w (p1 p2 c) -> b (h p1) (w p2) c",
            p1=patch,
            p2=patch,
            c=self.in_channels,
        )

    def get_null_cond(self, action: torch.Tensor) -> torch.Tensor:
        """Return the null action: zeros everywhere except a 1 in the last feature.

        NOTE(review): the last feature presumably acts as a "no action" flag;
        confirm against the action encoding used by callers.
        """
        null_action = torch.zeros_like(action)
        null_action[..., -1] = 1
        return null_action

    def get_cond(self, t: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Build the per-frame conditioning ``(b, t, dim)``.

        The conditioning is the timestep embedding plus the action embedding.
        In training mode, each batch element's whole action sequence is
        replaced by the null action with probability ``action_dropout_prob``.

        Args:
            t: Timesteps of shape ``(b, t)``.
            action: Actions of shape ``(b, length, action_dim)``; after the
                action embedder the length must be broadcastable with ``t``'s
                frame count.
        """
        B, T = t.shape
        flat_t = einops.rearrange(t, "b t -> (b t)")
        c = self.timestep_mlp(self.timestep_embedding(flat_t))
        c = einops.rearrange(c, "(b t) d -> b t d", t=T)

        if self.training and self.action_dropout_prob > 0:
            # One draw per batch element, broadcast over steps and features.
            draw = torch.rand((B, 1, 1), device=action.device)
            should_drop = draw < self.action_dropout_prob
            action = torch.where(should_drop, self.get_null_cond(action), action)

        c += self.action_embedder(action)
        return c

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Run the model.

        Args:
            x: Input of shape ``(b, t, h, w, c)`` (channels last).
            t: Timesteps of shape ``(b, t)``.
            action: Action sequence of shape ``(b, length, action_dim)``.

        Returns:
            Tensor of shape ``(b, t, h, w, c)`` when ``h`` and ``w`` are
            multiples of ``patch_size``.
        """
        _, T, _, _, _ = x.shape
        tokens = self.patchify(x)
        cond = self.get_cond(t, action)
        for block in self.blocks:
            tokens = block(tokens, cond)
        tokens = self.final_layer(tokens, cond)
        # `unpatchify` works on single frames, so fold time into the batch.
        frames = einops.rearrange(tokens, "b t h w d -> (b t) h w d")
        frames = self.unpatchify(frames)
        return einops.rearrange(frames, "(b t) h w c -> b t h w c", t=T)
| |
|
def _run_smoke_test() -> None:
    """Build a small DiT, run one forward pass and check the output shape."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    rope_config = {
        AttentionType.SPATIAL: RotaryType.STANDARD,
        AttentionType.TEMPORAL: RotaryType.STANDARD,
    }

    model = DiT(
        in_channels=4,
        patch_size=2,
        dim=256,
        num_layers=4,
        num_heads=8,
        action_dim=16,
        max_frames=16,
        rope_config=rope_config,
        temporal_causal=False,
    ).to(device)

    B, T, H, W, C = 2, 9, 32, 32, 4
    x = torch.randn(B, T, H, W, C).to(device)
    t = torch.randint(0, 1000, (B, T)).to(device)

    # Action length chosen so the embedder's two stride-2 convolutions
    # (DiT's default action_compress_rate of 4) reduce it to exactly T steps.
    L = 4 * (T - 1) + 1
    action = torch.randn(B, L, 16).to(device)

    print(f"Running forward pass on device: {device}...")
    output = model(x, t, action)

    for label, tensor in (
        ("Input shape", x),
        ("Timestep shape", t),
        ("Action shape", action),
        ("Output shape", output),
    ):
        print(f"{label}: {tensor.shape}")

    assert output.shape == x.shape, "Output shape mismatch!"
    print("Forward pass successful!")


if __name__ == "__main__":
    _run_smoke_test()
| |
|