| """ |
| References: |
| - DiT: https://github.com/facebookresearch/DiT/blob/main/models.py |
| - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py |
| - Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py |
| """ |
|
|
| from typing import Optional, Literal |
| import torch |
| from torch import nn |
| from .rotary_embedding_torch import RotaryEmbedding |
| from einops import rearrange |
| from .attention import SpatialAxialAttention, TemporalAxialAttention |
| from timm.models.vision_transformer import Mlp |
| from timm.layers.helpers import to_2tuple |
| import math |
| from collections import namedtuple |
| from typing import Optional, Callable |
|
|
def modulate(x, shift, scale):
    """Apply adaLN-style modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are first tiled along dim 0 until their leading
    size matches ``x.shape[0]`` (integer ratio), then singleton axes are
    inserted just before the last dim until ``shift`` has as many dims as
    ``x``, so the final arithmetic broadcasts over the middle axes of ``x``.
    ``scale`` is assumed to have the same rank as ``shift``.
    """
    # Repeat only along the leading axis; every other axis is left untouched.
    keep_dims = (1,) * (shift.dim() - 1)
    shift = shift.repeat(x.shape[0] // shift.shape[0], *keep_dims)
    scale = scale.repeat(x.shape[0] // scale.shape[0], *keep_dims)
    # Insert broadcast axes in front of the feature (last) dimension.
    for _ in range(x.dim() - shift.dim()):
        shift = shift.unsqueeze(-2)
        scale = scale.unsqueeze(-2)
    return x * (1 + scale) + shift
|
|
def gate(x, g):
    """Scale ``x`` element-wise by the gate ``g``.

    ``g`` is tiled along dim 0 to match ``x.shape[0]`` (integer ratio) and
    then padded with singleton axes just before its last dim until it has as
    many dims as ``x``, so the product broadcasts over the middle axes.
    """
    keep_dims = (1,) * (g.dim() - 1)
    g = g.repeat(x.shape[0] // g.shape[0], *keep_dims)
    # Insert broadcast axes in front of the feature (last) dimension.
    for _ in range(x.dim() - g.dim()):
        g = g.unsqueeze(-2)
    return g * x
|
|
|
|
class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    A convolution whose kernel and stride both equal the patch size projects
    each non-overlapping patch to an ``embed_dim`` vector. The output is
    either a flat token sequence or a channel-last patch grid, depending on
    ``flatten``.
    """

    def __init__(
        self,
        img_height=256,
        img_width=256,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        patch_hw = to_2tuple(patch_size)
        self.img_size = (img_height, img_width)
        self.patch_size = patch_hw
        # Number of whole patches along each axis (partial patches are dropped).
        self.grid_size = (img_height // patch_hw[0], img_width // patch_hw[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x, random_sample=False):
        """Embed a ``(B, C, H, W)`` batch.

        Returns ``(B, H' * W', C')`` when ``self.flatten`` is set and
        ``(B, H', W', C')`` otherwise. Unless ``random_sample`` is true, the
        input height and width must equal the configured image size.
        """
        _, _, H, W = x.shape
        assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        patches = self.proj(x)
        # Move channels last; optionally merge the patch grid into one token axis.
        layout = "B C H W -> B (H W) C" if self.flatten else "B C H W -> B H W C"
        return self.norm(rearrange(patches, layout))
|
|
|
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Each scalar is expanded into sinusoidal features of size
    ``frequency_embedding_size`` (see ``timestep_embedding``) and then
    projected to ``hidden_size`` by a Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.freq_type = freq_type

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings
                           (only used when ``freq_type == 'time_step'``).
        :param freq_type: frequency schedule, one of ``'time_step'``
                          (geometric, from 1 down towards ``1 / max_period``),
                          ``'spatial'`` (``k * pi`` for ``k = 1..dim // 2``) or
                          ``'angle'`` (``k * pi / 180`` for ``k = 1..dim // 2``).
        :return: an (N, D) Tensor of positional embeddings.
        :raises ValueError: if ``freq_type`` is not one of the supported values.
        """
        half = dim // 2

        # Frequencies are built on the default device and then moved, exactly
        # as before, so the resulting values do not depend on t's device.
        if freq_type == 'time_step':
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
            ).to(device=t.device)
        elif freq_type == 'spatial':
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
        elif freq_type == 'angle':
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
        else:
            # Previously an unknown value fell through every branch and failed
            # later with an opaque UnboundLocalError on ``freqs``.
            raise ValueError(
                f"freq_type must be one of 'time_step', 'spatial' or 'angle', got {freq_type!r}"
            )

        args = t[:, None].float() * freqs[None]

        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad one zero column after the cos/sin halves.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Return the MLP projection of the sinusoidal features of ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
        return self.mlp(t_freq)
|
|
|
|
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    Normalises the tokens without learned affine parameters, applies a
    conditioning-dependent shift/scale, and projects every token to
    ``patch_size * patch_size * out_channels`` values.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        patch_values = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_values, bias=True)
        # Produces the shift and scale halves consumed in ``forward``.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Modulate the normalised ``x`` with shift/scale derived from ``c`` and project it."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
|
|
|
|
# Human-readable names of the memory stream types; the position in this tuple
# is the integer type id used for embeddings, gates and diagnostics.
MEMORY_TYPE_NAMES = ("anchor", "dynamic", "revisit")
# Integer ids, in the same order as MEMORY_TYPE_NAMES.
MEMORY_TYPE_ANCHOR, MEMORY_TYPE_DYNAMIC, MEMORY_TYPE_REVISIT = range(3)
|
|
|
|
class MemoryTokenCrossAttention(nn.Module):
    """adaLN-modulated cross-attention adapter from video tokens to memory tokens.

    The query is a ``(B, T, H, W, D)`` feature map. Memory tokens are either
    rank 3 ``(B, M, D)`` (one memory set per sample, attended by all frames at
    once) or rank 4 ``(B, T, M, D)`` (one memory set per frame). Memory tokens
    may carry a validity mask, a per-token gate in attention-weight space, and
    integer type ids that select a learned type embedding/scale and a
    conditioning-dependent per-type gate.

    The last layers of ``adaLN_modulation`` and ``memory_type_gate`` are
    zero-initialised, so right after construction both residual gates are 0
    and the block returns ``residual_base`` unchanged.

    After every forward pass a few scalar diagnostics are stored on the module
    (``last_gate_mean``, ``last_delta_ratio``, ``last_valid_fraction``,
    ``last_type_gate_mean`` and ``last_type_gate_<name>_mean``).
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, num_memory_types=3):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.num_heads = num_heads
        self.num_memory_types = num_memory_types
        self.norm_q = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_mem = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
        self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
        self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
        # Diagnostics refreshed by every forward pass (None until the first one).
        self.last_gate_mean = None
        self.last_delta_ratio = None
        self.last_valid_fraction = None
        self.last_type_gate_mean = None
        for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
            setattr(self, f"last_type_gate_{type_name}_mean", None)
        nn.init.normal_(self.memory_type_embed.weight, std=0.02)
        self.reset_identity_init()

    def reset_identity_init(self):
        """Zero the last layers of the adaLN modulation and the type gate.

        With zero modulation output both residual gates are 0, so the block is
        an identity on ``residual_base``; the type gate evaluates to
        ``sigmoid(0) = 0.5`` for every memory type.
        """
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.memory_type_gate[-1].weight, 0)
        nn.init.constant_(self.memory_type_gate[-1].bias, 0)

    def _attend(self, query, memory_tokens, memory_token_mask=None, memory_token_gate=None):
        """Cross-attend ``query`` (N, L, D) to ``memory_tokens`` (N, M, D).

        ``memory_token_mask`` (N, M) marks valid tokens; ``memory_token_gate``
        (N, M) is added to the attention logits as ``log(gate)`` and tokens
        with a non-positive gate are treated as masked.

        Returns ``(out, valid_rows)``. ``valid_rows`` is ``None`` when neither
        a mask nor a gate was given; otherwise it is a bool (N,) tensor and
        rows without any usable memory token get an all-zero output.

        Raises ``ValueError`` when the gate shape is not (N, M).
        """
        if memory_token_mask is None and memory_token_gate is None:
            out, _ = self.attn(query, memory_tokens, memory_tokens, need_weights=False)
            return out, None

        if memory_token_mask is None:
            memory_token_mask = torch.ones(
                memory_tokens.shape[:2],
                device=memory_tokens.device,
                dtype=torch.bool,
            )
        else:
            # Keep the mask on the tokens' device so it can be combined with the gate.
            memory_token_mask = memory_token_mask.to(device=memory_tokens.device, dtype=torch.bool)
        gate_tensor = None
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:2]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:2])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            gate_tensor = memory_token_gate.to(device=memory_tokens.device, dtype=query.dtype)
            memory_token_mask = memory_token_mask & (gate_tensor > 0)
        valid_rows = memory_token_mask.any(dim=1)
        out = torch.zeros_like(query)
        # Only rows with at least one usable key are attended: a fully masked
        # row would otherwise produce NaNs in the softmax.
        if valid_rows.any():
            attn_mask = None
            key_padding_mask = ~memory_token_mask[valid_rows]
            if gate_tensor is not None:
                # Adding log(gate) to the logits multiplies the unnormalised
                # attention weight of each key by its gate.
                gate_bias = torch.log(gate_tensor[valid_rows].clamp_min(1.0e-6))
                gate_bias = gate_bias[:, None, :].expand(-1, query.shape[1], -1)
                # 3-D attention masks are laid out as (N * num_heads, L, M).
                attn_mask = gate_bias.repeat_interleave(self.num_heads, dim=0)
                # Use a float padding mask so both masks share one dtype.
                float_padding_mask = torch.zeros_like(gate_tensor[valid_rows], dtype=query.dtype)
                key_padding_mask = float_padding_mask.masked_fill(key_padding_mask, float("-inf"))
            attended, _ = self.attn(
                query[valid_rows],
                memory_tokens[valid_rows],
                memory_tokens[valid_rows],
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
                need_weights=False,
            )
            out[valid_rows] = attended.to(out.dtype)
        return out, valid_rows

    def _apply_memory_type(self, memory_tokens, memory_type_ids):
        """Scale and shift memory tokens by their learned per-type parameters.

        Returns the tokens unchanged when ``memory_type_ids`` is ``None``.
        Type parameters are left-padded with singleton axes so that ids of
        shape (M,) broadcast over the leading token axes.
        """
        if memory_type_ids is None:
            return memory_tokens
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        type_embed = self.memory_type_embed(memory_type_ids).to(memory_tokens.dtype)
        type_scale = self.memory_type_scale[memory_type_ids].to(memory_tokens.dtype)
        while type_embed.dim() < memory_tokens.dim():
            type_embed = type_embed.unsqueeze(0)
            type_scale = type_scale.unsqueeze(0)
        return memory_tokens * type_scale + type_embed

    def _store_type_gate_diagnostics(self, stage_gate):
        """Record the overall and per-type mean of the sigmoid type gate."""
        with torch.no_grad():
            detached = stage_gate.detach().float()
            self.last_type_gate_mean = detached.mean()
            for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
                setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())

    def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
        """Look up the conditioning-dependent type gate for every memory token.

        The gate is ``sigmoid(memory_type_gate(c))``; the gather below indexes
        it with a (B, T, M) index, so ``c`` is expected to be (B, T, hidden).
        Returns ``None`` without type ids, a (B, T, M) gate for rank-4 memory,
        and a (B, M) gate (averaged over frames) for rank-3 memory.

        Raises ``ValueError`` for unsupported id shapes or memory ranks.
        """
        if memory_type_ids is None:
            return None
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
        self._store_type_gate_diagnostics(stage_gate)
        if memory_tokens.dim() == 4:
            batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
            if memory_type_ids.dim() == 1:
                gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, num_frames, num_tokens)
            elif tuple(memory_type_ids.shape) == (batch_size, num_frames, num_tokens):
                gather_ids = memory_type_ids
            else:
                raise ValueError(
                    "rank-4 memory_type_ids must have shape (M,) or (B,T,M), "
                    f"got {tuple(memory_type_ids.shape)}"
                )
            return torch.gather(stage_gate, dim=-1, index=gather_ids)
        if memory_tokens.dim() == 3:
            batch_size, num_tokens = memory_tokens.shape[:2]
            if memory_type_ids.dim() != 1:
                raise ValueError("rank-3 memory_type_ids must have shape (M,)")
            gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, stage_gate.shape[1], num_tokens)
            # Shared memory has no frame axis, so the per-frame gates are averaged.
            return torch.gather(stage_gate, dim=-1, index=gather_ids).mean(dim=1)
        raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")

    def _combine_memory_gate(self, memory_tokens, memory_token_gate, type_stage_gate):
        """Multiply the caller-supplied token gate with the type gate.

        Either gate may be ``None``; the result is ``None`` only when both
        are. Raises ``ValueError`` when ``memory_token_gate`` does not match
        the token shape without its feature axis.
        """
        combined_gate = type_stage_gate
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:-1]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:-1])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            stream_gate = memory_token_gate.to(device=memory_tokens.device, dtype=memory_tokens.dtype)
            combined_gate = stream_gate if combined_gate is None else combined_gate * stream_gate
        return combined_gate

    def _valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Reshape ``valid_rows`` to broadcast against a (B, T, H, W, D) tensor.

        Accepts B values (per sample) or B*T values (per frame, sample-major);
        returns ``None`` when ``valid_rows`` is ``None`` and raises
        ``ValueError`` for any other size.
        """
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return valid_rows.reshape(batch_size, num_frames)[:, :, None, None, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _gate_valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Reshape ``valid_rows`` to broadcast against (B, T, K) gate values.

        Same size rules as ``_valid_mask``, but with a single trailing axis.
        """
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return valid_rows.reshape(batch_size, num_frames)[:, :, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _residual_gate(self, residual_gate, batch_size, num_frames, dtype, device):
        """Normalise ``residual_gate`` to a tensor broadcastable to (B, T, H, W, D).

        Accepted inputs: ``None`` (returned as is), a Python number, or a
        tensor of rank 0, rank 1 with B or B*T values, rank 2 of shape (B, T),
        rank 3 starting with (B, T), or rank >= 4 (right-padded with singleton
        axes up to rank 5). Raises ``ValueError`` for incompatible shapes.
        """
        if residual_gate is None:
            return None
        if not torch.is_tensor(residual_gate):
            return torch.tensor(float(residual_gate), dtype=dtype, device=device).view(1, 1, 1, 1, 1)
        gate_tensor = residual_gate.to(device=device, dtype=dtype)
        if gate_tensor.dim() == 0:
            gate_tensor = gate_tensor.view(1, 1, 1, 1, 1)
        elif gate_tensor.dim() == 1:
            if gate_tensor.numel() == batch_size:
                gate_tensor = gate_tensor.view(batch_size, 1, 1, 1, 1)
            elif gate_tensor.numel() == batch_size * num_frames:
                gate_tensor = gate_tensor.reshape(batch_size, num_frames)[:, :, None, None, None]
            else:
                raise ValueError(f"residual_gate has incompatible shape: {tuple(gate_tensor.shape)}")
        elif gate_tensor.dim() == 2:
            if tuple(gate_tensor.shape) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must have shape (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, None, None, None]
        elif gate_tensor.dim() == 3:
            if tuple(gate_tensor.shape[:2]) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must start with (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, :, None, None]
        else:
            while gate_tensor.dim() < 5:
                gate_tensor = gate_tensor.unsqueeze(-1)
        return gate_tensor

    def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
        """Record gate magnitude, valid-row fraction and relative update size.

        ``last_gate_mean`` is the mean absolute value of both adaLN gates over
        the positions whose row had usable memory; ``last_delta_ratio`` is
        ``||output - base|| / (||base|| + 1e-6)``.
        """
        with torch.no_grad():
            batch_size, num_frames = base.shape[:2]
            gate_values = torch.cat(
                [gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
                dim=-1,
            )
            gate_mask = self._gate_valid_mask(
                valid_rows,
                batch_size,
                num_frames,
                dtype=gate_values.dtype,
                device=gate_values.device,
            )
            if gate_mask is not None:
                masked_values = gate_values * gate_mask
                self.last_valid_fraction = valid_rows.detach().float().mean()
                # Count every element the mask actually covers after
                # broadcasting. A per-sample mask (B, 1, 1) spans all T frames
                # of the gate values, so counting only the mask entries would
                # inflate the mean by a factor of T.
                valid_count = gate_mask.expand_as(masked_values).sum().clamp_min(1.0)
                self.last_gate_mean = masked_values.sum() / valid_count
            else:
                self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
                self.last_gate_mean = gate_values.mean()

            delta_norm = (output.detach().float() - base.detach().float()).norm()
            base_norm = base.detach().float().norm()
            self.last_delta_ratio = delta_norm / (base_norm + 1e-6)

    def forward(
        self,
        x,
        c,
        memory_tokens,
        memory_token_mask=None,
        residual_base=None,
        return_delta=False,
        residual_gate=None,
        memory_type_ids=None,
        memory_token_gate=None,
    ):
        """Add a gated memory read and a gated MLP update to ``residual_base``.

        :param x: (B, T, H, W, D) query features.
        :param c: conditioning fed to the adaLN modulation and the type gate.
        :param memory_tokens: (B, M, D) shared memory or (B, T, M, D) per-frame memory.
        :param memory_token_mask: optional validity mask, (B, M) or (B, T, M) to match.
        :param residual_base: tensor the updates are added to; defaults to ``x``.
        :param return_delta: return only the summed update instead of the output.
        :param residual_gate: optional extra multiplier on both updates
            (see ``_residual_gate`` for accepted shapes).
        :param memory_type_ids: optional integer type id per memory token.
        :param memory_token_gate: optional per-token gate, same shape as the mask.
        :return: ``residual_base`` plus both updates, or just the updates when
            ``return_delta`` is true.
        :raises ValueError: on mismatching mask/gate shapes or memory rank.
        """
        B, T, H, W, D = x.shape
        if residual_base is None:
            residual_base = x
        m_shift_msa, m_scale_msa, m_gate_msa, m_shift_mlp, m_scale_mlp, m_gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        query_source = modulate(self.norm_q(x), m_shift_msa, m_scale_msa)
        type_stage_gate = self._type_stage_gate(c, memory_tokens, memory_type_ids)
        effective_token_gate = self._combine_memory_gate(memory_tokens, memory_token_gate, type_stage_gate)
        if memory_tokens.dim() == 3:
            # Shared memory: all frames of a sample attend in one sequence.
            query = rearrange(query_source, "b t h w d -> b (t h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            valid_rows = None
            if memory_token_mask is not None:
                if tuple(memory_token_mask.shape) != tuple(memory_tokens.shape[:2]):
                    raise ValueError(
                        f"legacy memory mask must have shape {tuple(memory_tokens.shape[:2])}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "b (t h w) d -> b t h w d", t=T, h=H, w=W)
        elif memory_tokens.dim() == 4:
            assert memory_tokens.shape[:2] == (B, T), (
                f"per-frame memory tokens must have shape (B, T, M, D), got {tuple(memory_tokens.shape)}"
            )
            # Per-frame memory: fold frames into the batch axis.
            query = rearrange(query_source, "b t h w d -> (b t) (h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            memory_tokens = rearrange(memory_tokens, "b t m d -> (b t) m d")
            if effective_token_gate is not None:
                effective_token_gate = rearrange(effective_token_gate, "b t m -> (b t) m")
            valid_rows = None
            if memory_token_mask is not None:
                expected_mask_shape = (B, T, memory_tokens.shape[1])
                if tuple(memory_token_mask.shape) != expected_mask_shape:
                    raise ValueError(
                        f"per-frame memory mask must have shape {expected_mask_shape}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
                memory_token_mask = rearrange(memory_token_mask.bool(), "b t m -> (b t) m")
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "(b t) (h w) d -> b t h w d", b=B, t=T, h=H, w=W)
        else:
            raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")

        valid_mask = self._valid_mask(valid_rows, B, T, dtype=out.dtype, device=out.device)
        residual_gate_tensor = self._residual_gate(residual_gate, B, T, dtype=out.dtype, device=out.device)
        attn_delta = gate(out, m_gate_msa)
        if valid_mask is not None:
            attn_delta = attn_delta * valid_mask
        if residual_gate_tensor is not None:
            attn_delta = attn_delta * residual_gate_tensor
        output = residual_base + attn_delta

        mlp_delta = gate(self.mlp(modulate(self.norm_mlp(output), m_shift_mlp, m_scale_mlp)), m_gate_mlp)
        # Rows without usable memory receive no update at all, MLP included.
        if valid_mask is not None:
            mlp_delta = mlp_delta * valid_mask
        if residual_gate_tensor is not None:
            mlp_delta = mlp_delta * residual_gate_tensor
        output = output + mlp_delta
        self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
        if return_delta:
            return attn_delta + mlp_delta
        return output
|
|
class SpatioTemporalDiTBlock(nn.Module):
    """DiT block with spatial attention, temporal attention and optional memory cross-attention.

    Every sub-layer (attention or MLP) is wrapped in adaLN modulation and a
    gated residual. ``ref_mode`` selects how the temporal branch is combined
    with the rest:

    * ``'sequential'``: the temporal output replaces ``x`` before the memory
      adapter runs.
    * ``'parallel'``: the memory adapter runs on the spatial output and its
      result is linearly mapped and added to the temporal output.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        reference_length,
        mlp_ratio=4.0,
        is_causal=True,
        spatial_rotary_emb: Optional["RotaryEmbedding"] = None,
        temporal_rotary_emb: Optional["RotaryEmbedding"] = None,
        use_memory_token_cross_attention=False,
        ref_mode='sequential'
    ):
        super().__init__()
        self.is_causal = is_causal
        self.reference_length = reference_length
        self.use_memory_token_cross_attention = use_memory_token_cross_attention
        self.ref_mode = ref_mode

        head_dim = hidden_size // num_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def make_norm():
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def make_mlp():
            return Mlp(
                in_features=hidden_size,
                hidden_features=mlp_hidden_dim,
                act_layer=lambda: nn.GELU(approximate="tanh"),
                drop=0,
            )

        def make_modulation():
            # Six chunks: shift/scale/gate for attention and for the MLP.
            return nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

        # Spatial branch.
        self.s_norm1 = make_norm()
        self.s_attn = SpatialAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=head_dim,
            rotary_emb=spatial_rotary_emb,
        )
        self.s_norm2 = make_norm()
        self.s_mlp = make_mlp()
        self.s_adaLN_modulation = make_modulation()

        # Temporal branch.
        self.t_norm1 = make_norm()
        self.t_attn = TemporalAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=head_dim,
            is_causal=is_causal,
            rotary_emb=temporal_rotary_emb,
            reference_length=reference_length,
        )
        self.t_norm2 = make_norm()
        self.t_mlp = make_mlp()
        self.t_adaLN_modulation = make_modulation()

        if self.use_memory_token_cross_attention:
            self.memory_token_cross_attn = MemoryTokenCrossAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio)

        if self.ref_mode == 'parallel':
            self.parallel_map = nn.Linear(hidden_size, hidden_size)

    def _expand_memory_stream(self, tokens, mask, stream_gate, type_idx, batch_size, num_frames):
        """Bring one memory stream into per-frame (B, T, M, ...) form.

        Returns ``None`` for a missing or empty stream, otherwise the tuple
        ``(tokens (B,T,M,D), bool mask (B,T,M), gate (B,T,M), type_ids (M,))``.
        Rank-3 tokens (B, M, D) are shared by every frame; a missing mask
        means all tokens are valid. Raises ``ValueError`` on shape mismatches.
        """
        if tokens is None or tokens.shape[-2] == 0:
            return None

        token_rank = tokens.dim()
        if token_rank not in (3, 4):
            raise ValueError(f"memory stream tokens must be rank 3 or 4, got rank {tokens.dim()}")

        if token_rank == 3:
            if tokens.shape[0] != batch_size:
                raise ValueError(f"rank-3 memory tokens must start with B={batch_size}, got {tuple(tokens.shape)}")
            # Share the same tokens across all frames without copying.
            tokens = tokens.unsqueeze(1).expand(-1, num_frames, -1, -1)
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask.unsqueeze(1).expand(-1, num_frames, -1)
                elif mask.dim() != 3:
                    raise ValueError(f"rank-3 stream mask must have rank 2 or 3, got {tuple(mask.shape)}")
        else:
            if tuple(tokens.shape[:2]) != (batch_size, num_frames):
                raise ValueError(
                    f"rank-4 memory tokens must start with (B,T)={(batch_size, num_frames)}, "
                    f"got {tuple(tokens.shape)}"
                )
            if mask is not None and mask.dim() != 3:
                raise ValueError(f"rank-4 stream mask must have rank 3, got {tuple(mask.shape)}")

        if mask is None:
            mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool)
        if tuple(mask.shape) != tuple(tokens.shape[:3]):
            raise ValueError(f"memory stream mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")

        gate_tensor = self._expand_memory_stream_gate(stream_gate, tokens)
        type_ids = torch.full((tokens.shape[2],), int(type_idx), device=tokens.device, dtype=torch.long)
        return tokens, mask.to(device=tokens.device, dtype=torch.bool), gate_tensor, type_ids

    def _expand_memory_stream_gate(self, stream_gate, tokens):
        """Broadcast ``stream_gate`` to the (B, T, M) shape of ``tokens``.

        Accepts ``None`` (all ones), a Python number, or a tensor of shape
        ``()``, ``(B,)``, ``(B, T)``, ``(B, M)``, ``(B, T, 1)`` or
        ``(B, T, M)``. A rank-2 gate is read as ``(B, T)`` first, so that
        interpretation wins when T equals M. Raises ``ValueError`` otherwise.
        """
        full_shape = tokens.shape[:3]
        batch_size, num_frames, num_tokens = full_shape
        if stream_gate is None:
            return torch.ones(full_shape, device=tokens.device, dtype=tokens.dtype)
        if not torch.is_tensor(stream_gate):
            return torch.full(full_shape, float(stream_gate), device=tokens.device, dtype=tokens.dtype)

        gate_tensor = stream_gate.to(device=tokens.device, dtype=tokens.dtype)
        gate_rank = gate_tensor.dim()
        gate_shape = tuple(gate_tensor.shape)
        if gate_rank == 0:
            return gate_tensor.view(1, 1, 1).expand(batch_size, num_frames, num_tokens)
        if gate_rank == 1:
            if gate_tensor.numel() != batch_size:
                raise ValueError(f"rank-1 memory gate must have B={batch_size} values, got {tuple(gate_tensor.shape)}")
            return gate_tensor.view(batch_size, 1, 1).expand(batch_size, num_frames, num_tokens)
        if gate_rank == 2:
            if gate_shape == (batch_size, num_frames):
                return gate_tensor.unsqueeze(2).expand(batch_size, num_frames, num_tokens)
            if gate_shape == (batch_size, num_tokens):
                return gate_tensor.unsqueeze(1).expand(batch_size, num_frames, num_tokens)
            raise ValueError(
                f"rank-2 memory gate must have shape (B,T) or (B,M), got {tuple(gate_tensor.shape)}"
            )
        if gate_rank == 3:
            if gate_shape == (batch_size, num_frames, 1):
                return gate_tensor.expand(batch_size, num_frames, num_tokens)
            if gate_shape == (batch_size, num_frames, num_tokens):
                return gate_tensor
            raise ValueError(
                f"rank-3 memory gate must have shape (B,T,1) or (B,T,M), got {tuple(gate_tensor.shape)}"
            )
        raise ValueError(f"memory gate rank must be <=3, got rank {gate_tensor.dim()}")

    def _pack_typed_memory_streams(
        self,
        batch_size,
        num_frames,
        memory_tokens=None,
        memory_token_mask=None,
        memory_dynamic_tokens=None,
        memory_dynamic_mask=None,
        memory_retrieval_tokens=None,
        memory_retrieval_mask=None,
        memory_anchor_gate=None,
        memory_dynamic_gate=None,
        memory_retrieval_gate=None,
    ):
        """Concatenate the anchor, dynamic and retrieval streams along the token axis.

        Returns ``None`` when no stream is present, otherwise
        ``(tokens, mask, gate, type_ids, residual_gate)`` where
        ``residual_gate`` (B, T) is the largest gate among the valid tokens of
        each frame (masked tokens count as 0).
        """
        stream_specs = (
            (memory_tokens, memory_token_mask, memory_anchor_gate, MEMORY_TYPE_ANCHOR),
            (memory_dynamic_tokens, memory_dynamic_mask, memory_dynamic_gate, MEMORY_TYPE_DYNAMIC),
            (memory_retrieval_tokens, memory_retrieval_mask, memory_retrieval_gate, MEMORY_TYPE_REVISIT),
        )
        expanded_streams = (
            self._expand_memory_stream(*spec, batch_size, num_frames) for spec in stream_specs
        )
        present = [stream for stream in expanded_streams if stream is not None]
        if not present:
            return None

        token_parts, mask_parts, gate_parts, type_id_parts = zip(*present)
        packed_tokens = torch.cat(token_parts, dim=2)
        packed_mask = torch.cat(mask_parts, dim=2)
        packed_gate = torch.cat(gate_parts, dim=2)
        packed_type_ids = torch.cat(type_id_parts, dim=0)
        # Masked-out tokens must not contribute to the per-frame residual gate.
        residual_gate = packed_gate.masked_fill(~packed_mask, 0).max(dim=2).values
        return packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate

    def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
                pose_cond=None, mode="training", c_action_cond=None, reference_length=None,
                memory_tokens=None, memory_token_mask=None, memory_dynamic_tokens=None, memory_dynamic_mask=None,
                memory_retrieval_tokens=None, memory_retrieval_mask=None, memory_anchor_gate=None,
                memory_dynamic_gate=None, memory_retrieval_gate=None):
        """Run the spatial, temporal and (optionally) memory sub-blocks on ``x`` (B, T, H, W, D).

        ``c`` conditions the spatial branch and the memory adapter;
        ``c_action_cond`` conditions the temporal branch when given, otherwise
        ``c`` is used there too. ``current_frame``, ``timestep``,
        ``is_last_block``, ``pose_cond``, ``mode`` and ``reference_length``
        are accepted but not used in this method.
        """
        B, T, _, _, _ = x.shape

        # Spatial attention + MLP.
        s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = (
            self.s_adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
        x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)

        # Temporal attention + MLP, preferring the action-aware conditioning.
        temporal_cond = c if c_action_cond is None else c_action_cond
        t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = (
            self.t_adaLN_modulation(temporal_cond).chunk(6, dim=-1)
        )
        x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
        x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)

        if self.ref_mode == 'sequential':
            x = x_t

        if self.use_memory_token_cross_attention:
            packed_memory = self._pack_typed_memory_streams(
                B,
                T,
                memory_tokens=memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_dynamic_tokens=memory_dynamic_tokens,
                memory_dynamic_mask=memory_dynamic_mask,
                memory_retrieval_tokens=memory_retrieval_tokens,
                memory_retrieval_mask=memory_retrieval_mask,
                memory_anchor_gate=memory_anchor_gate,
                memory_dynamic_gate=memory_dynamic_gate,
                memory_retrieval_gate=memory_retrieval_gate,
            )
            # Without any memory stream the adapter is skipped entirely.
            if packed_memory is not None:
                packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate = packed_memory
                x = self.memory_token_cross_attn(
                    x,
                    c,
                    packed_tokens,
                    packed_mask,
                    residual_gate=residual_gate,
                    memory_type_ids=packed_type_ids,
                    memory_token_gate=packed_gate,
                )

        if self.ref_mode == 'parallel':
            x = x_t + self.parallel_map(x)

        return x
|
|
|
|
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Frames are patch-embedded independently, processed by a stack of
    ``SpatioTemporalDiTBlock`` modules conditioned on the timestep embedding
    (plus an optional action embedding for the temporal branch), and projected
    back to latent patches by ``FinalLayer``.
    """

    def __init__(
        self,
        input_h=18,
        input_w=32,
        patch_size=2,
        in_channels=16,
        hidden_size=1024,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        action_cond_dim=25,
        max_frames=32,
        reference_length=8,
        memory_token_cross_attention=False,
        memory_cross_attn_layers=None,
        ref_mode='sequential'
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.max_frames = max_frames

        head_dim = hidden_size // num_heads
        self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Rotary embeddings are shared by all blocks.
        self.spatial_rotary_emb = RotaryEmbedding(dim=head_dim // 2, freqs_for="pixel", max_freq=256)
        self.temporal_rotary_emb = RotaryEmbedding(dim=head_dim)

        if action_cond_dim > 0:
            self.external_cond = nn.Linear(action_cond_dim, hidden_size)
        else:
            self.external_cond = nn.Identity()

        # Optional whitelist of block indices that get a memory adapter;
        # None means "every block" (when memory cross-attention is enabled).
        memory_cross_attn_layer_set = None
        if memory_cross_attn_layers is not None:
            memory_cross_attn_layer_set = {int(layer_idx) for layer_idx in memory_cross_attn_layers}
            invalid_layers = sorted(
                layer_idx for layer_idx in memory_cross_attn_layer_set if not 0 <= layer_idx < depth
            )
            if invalid_layers:
                raise ValueError(
                    f"memory_cross_attn_layers contains invalid indices {invalid_layers} for depth={depth}"
                )

        def wants_memory_adapter(block_idx):
            return memory_token_cross_attention and (
                memory_cross_attn_layer_set is None or block_idx in memory_cross_attn_layer_set
            )

        self.blocks = nn.ModuleList(
            SpatioTemporalDiTBlock(
                hidden_size,
                num_heads,
                mlp_ratio=mlp_ratio,
                is_causal=True,
                reference_length=reference_length,
                spatial_rotary_emb=self.spatial_rotary_emb,
                temporal_rotary_emb=self.temporal_rotary_emb,
                use_memory_token_cross_attention=wants_memory_adapter(block_idx),
                ref_mode=ref_mode,
            )
            for block_idx in range(depth)
        )
        self.memory_token_cross_attention = memory_token_cross_attention
        if memory_cross_attn_layer_set is None:
            self.memory_cross_attn_layers = None
        else:
            self.memory_cross_attn_layers = tuple(sorted(memory_cross_attn_layer_set))
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all weights; modulation and output layers start at zero."""

        def _xavier_linear(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        # Default for every linear layer in the model.
        self.apply(_xavier_linear)

        # Patch embedding: treat the conv kernel as a flattened linear map.
        proj_weight = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(proj_weight.view([proj_weight.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Timestep embedding MLP.
        for layer_idx in (0, 2):
            nn.init.normal_(self.t_embedder.mlp[layer_idx].weight, std=0.02)

        # Zero the adaLN modulation output of every block.
        for block in self.blocks:
            for modulation in (block.s_adaLN_modulation, block.t_adaLN_modulation):
                nn.init.constant_(modulation[-1].weight, 0)
                nn.init.constant_(modulation[-1].bias, 0)

        # Zero the output layers.
        for zeroed in (self.final_layer.adaLN_modulation[-1], self.final_layer.linear):
            nn.init.constant_(zeroed.weight, 0)
            nn.init.constant_(zeroed.bias, 0)

        # The blanket xavier pass above overwrote the adapters' identity init.
        if self.memory_token_cross_attention:
            for block in self.blocks:
                memory_adapter = getattr(block, "memory_token_cross_attn", None)
                if memory_adapter is not None:
                    memory_adapter.reset_identity_init()

    def memory_adapter_delta_diagnostics(self):
        """Aggregate the diagnostics stored by the memory adapters into a dict of floats.

        Only keys with at least one recorded value are present:
        ``memory_adapter_delta_ratio_max`` / ``_mean``,
        ``memory_adapter_type_gate_mean`` and
        ``memory_adapter_type_gate_<name>_mean`` for each memory type name.
        """

        def _as_float_tensor(value):
            return torch.as_tensor(value).detach().float()

        delta_ratios = []
        overall_type_gates = []
        per_type_gates = {type_name: [] for type_name in MEMORY_TYPE_NAMES}
        for block in self.blocks:
            adapter = getattr(block, "memory_token_cross_attn", None)
            if adapter is None:
                continue
            ratio = getattr(adapter, "last_delta_ratio", None)
            if ratio is not None:
                delta_ratios.append(_as_float_tensor(ratio))
            type_gate = getattr(adapter, "last_type_gate_mean", None)
            if type_gate is not None:
                overall_type_gates.append(_as_float_tensor(type_gate))
            for type_name, collected in per_type_gates.items():
                value = getattr(adapter, f"last_type_gate_{type_name}_mean", None)
                if value is not None:
                    collected.append(_as_float_tensor(value))

        diagnostics = {}
        if delta_ratios:
            stacked = torch.stack(delta_ratios)
            diagnostics["memory_adapter_delta_ratio_max"] = float(stacked.max().item())
            diagnostics["memory_adapter_delta_ratio_mean"] = float(stacked.mean().item())
        if overall_type_gates:
            diagnostics["memory_adapter_type_gate_mean"] = float(torch.stack(overall_type_gates).mean().item())
        for type_name, collected in per_type_gates.items():
            if collected:
                diagnostics[f"memory_adapter_type_gate_{type_name}_mean"] = float(
                    torch.stack(collected).mean().item()
                )
        return diagnostics

    def unpatchify(self, x):
        """
        x: (N, H, W, patch_size**2 * C)
        imgs: (N, C, H * patch_size, W * patch_size)
        """
        channels = self.out_channels
        p = self.x_embedder.patch_size[0]
        n, h, w = x.shape[0], x.shape[1], x.shape[2]

        # Split the last axis into (patch row, patch col, channel), then
        # interleave patch rows/cols with the grid rows/cols.
        patches = x.reshape(n, h, w, p, p, channels)
        patches = patches.permute(0, 5, 1, 3, 2, 4)
        return patches.reshape(n, channels, h * p, w * p)

    def forward(
        self,
        x,
        t,
        action_cond=None,
        pose_cond=None,
        current_frame=None,
        mode=None,
        reference_length=None,
        frame_idx=None,
        memory_tokens=None,
        memory_token_mask=None,
        memory_dynamic_tokens=None,
        memory_dynamic_mask=None,
        memory_retrieval_tokens=None,
        memory_retrieval_mask=None,
        memory_anchor_gate=None,
        memory_dynamic_gate=None,
        memory_retrieval_gate=None,
    ):
        """
        Forward pass of DiT.
        x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (B, T,) tensor of diffusion timesteps
        action_cond: optional tensor; when given, its ``external_cond``
            projection is added to the timestep embedding for the temporal branch.
        The ``memory_*`` arguments are forwarded unchanged to every block.
        ``pose_cond`` and ``frame_idx`` are accepted but not used here.
        Returns a (B, T, C, H, W) tensor.
        """
        B, T, C, H, W = x.shape

        # Embed each frame independently, then restore the frame axis.
        frames = rearrange(x, "b t c h w -> (b t) c h w")
        x = rearrange(self.x_embedder(frames), "(b t) h w d -> b t h w d", t=T)

        # One conditioning vector per (sample, frame).
        t = rearrange(t, "b t -> (b t)")
        c = rearrange(self.t_embedder(t).clone(), "(b t) d -> b t d", t=T)

        c_action_cond = c + self.external_cond(action_cond) if torch.is_tensor(action_cond) else None

        memory_kwargs = dict(
            memory_tokens=memory_tokens,
            memory_token_mask=memory_token_mask,
            memory_dynamic_tokens=memory_dynamic_tokens,
            memory_dynamic_mask=memory_dynamic_mask,
            memory_retrieval_tokens=memory_retrieval_tokens,
            memory_retrieval_mask=memory_retrieval_mask,
            memory_anchor_gate=memory_anchor_gate,
            memory_dynamic_gate=memory_dynamic_gate,
            memory_retrieval_gate=memory_retrieval_gate,
        )
        last_block_idx = len(self.blocks) - 1
        for block_idx, block in enumerate(self.blocks):
            x = block(
                x,
                c,
                current_frame=current_frame,
                timestep=t,
                is_last_block=(block_idx == last_block_idx),
                mode=mode,
                c_action_cond=c_action_cond,
                reference_length=reference_length,
                **memory_kwargs,
            )
        x = self.final_layer(x, c)

        # Back to image layout, one frame at a time.
        x = rearrange(x, "b t h w d -> (b t) h w d")
        x = self.unpatchify(x)
        return rearrange(x, "(b t) c h w -> b t c h w", t=T)
|
|
|
|
def DiT_S_2(
    action_cond_dim,
    reference_length,
    ref_mode,
    memory_token_cross_attention=False,
    memory_cross_attn_layers=None,
):
    """Build the "DiT-S/2" configuration: patch size 2, width 1024, 16 blocks, 16 heads.

    All remaining ``DiT`` options keep their defaults; the arguments of this
    factory are forwarded unchanged.
    """
    architecture = dict(patch_size=2, hidden_size=1024, depth=16, num_heads=16)
    return DiT(
        action_cond_dim=action_cond_dim,
        reference_length=reference_length,
        memory_token_cross_attention=memory_token_cross_attention,
        memory_cross_attn_layers=memory_cross_attn_layers,
        ref_mode=ref_mode,
        **architecture,
    )


# Registry mapping model names to their factory functions.
DiT_models = {"DiT-S/2": DiT_S_2}
|
|