| from __future__ import annotations |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut every latent frame into non-overlapping square patches, one token per patch.

    Args:
        latents: Tensor of shape (T, B, C, H, W).
        patch_size: Side length of the square patch; must be positive and
            divide both H and W.

    Returns:
        Tensor of shape (T, B, N, C * patch_size * patch_size), where N is the
        number of patches per frame as reported by ``F.unfold``. Patch order
        and per-token feature order are those produced by ``F.unfold``.

    Raises:
        ValueError: If ``latents`` is not 5-D, ``patch_size`` is not positive,
            or H / W are not multiples of ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")

    n_frames, n_batch, n_chan, H, W = latents.shape
    if H % patch_size or W % patch_size:
        raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={patch_size}")

    # Fold time and batch together so unfold sees an ordinary (N, C, H, W) image batch.
    frames = latents.reshape(n_frames * n_batch, n_chan, H, W)
    columns = F.unfold(frames, kernel_size=patch_size, stride=patch_size)
    # unfold yields (N, C*p*p, L); move the patch axis in front of the feature axis.
    per_patch = columns.transpose(1, 2).contiguous()
    token_dim = n_chan * patch_size * patch_size
    return per_patch.reshape(n_frames, n_batch, per_patch.shape[1], token_dim)
|
|
|
|
def spatial_pool_tokens(
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """Adaptive-average-pool a flattened token grid down to ``pool_h x pool_w``.

    The (N, D) input is interpreted as a row-major ``src_h x src_w`` grid of
    D-dimensional tokens, pooled in 2D, and flattened back to row-major order.

    Args:
        tokens: Tensor of shape (src_h * src_w, D).
        pool_h: Number of rows in the pooled grid.
        pool_w: Number of columns in the pooled grid.
        src_h: Number of rows in the source grid.
        src_w: Number of columns in the source grid.

    Returns:
        Tensor of shape (pool_h * pool_w, D).

    Raises:
        ValueError: If ``tokens`` is not 2-D.
    """
    if tokens.ndim != 2:
        raise ValueError("tokens must have shape (N, D)")

    feat_dim = tokens.size(-1)
    # (src_h, src_w, D) -> (D, src_h, src_w): pooling expects channels first.
    channels_first = tokens.reshape(src_h, src_w, feat_dim).movedim(-1, 0)
    # Add and then drop a dummy batch axis around the pooling call.
    reduced = F.adaptive_avg_pool2d(channels_first[None], (pool_h, pool_w))[0]
    # Back to channels-last and flatten the pooled grid row-major.
    return reduced.movedim(0, -1).reshape(-1, feat_dim)
|
|
|
|
class SpatialConv2DMemoryProjector(nn.Module):
    """Map latent feature maps to DiT-width tokens, one token per grid cell.

    A 1x1 convolution lifts the latent channels to ``mid_channels``; a second,
    "same"-padded convolution with an odd kernel then produces
    ``dit_hidden_size`` channels at the unchanged HxW resolution.
    """

    # Class-level flag exposed on every instance.
    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Build the two convolutions.

        Args:
            latent_channels: Channel count C of the incoming latents.
            dit_hidden_size: Output token width (also stored as ``out_features``).
            mid_channels: Channel count between the two convolutions.
            kernel_size: Spatial kernel of the second convolution; must be a
                positive odd integer so that ``kernel_size // 2`` padding keeps HxW.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        k = int(kernel_size)
        if k < 1 or k % 2 != 1:
            raise ValueError("kernel_size must be a positive odd integer")

        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = k
        self.out_features = self.dit_hidden_size

        # Pointwise channel lift followed by the spatial mixing convolution.
        self.proj_in = nn.Conv2d(self.latent_channels, self.mid_channels, kernel_size=1)
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=k,
            padding=k // 2,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project latents of shape (T, B, C, H, W) to tokens of shape (B, T, H*W, hidden).

        Raises:
            ValueError: If ``latents`` is not 5-D or C differs from ``latent_channels``.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        T, B, C, H, W = latents.shape
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")

        hidden = self.dit_hidden_size
        # Merge time and batch so the 2D convolutions see a plain image batch.
        frames = latents.reshape(T * B, C, H, W)
        lifted = self.proj_in(frames)
        feats = self.proj_spatial(lifted).reshape(T, B, hidden, H, W)
        # (T, B, hidden, H, W) -> (B, T, H, W, hidden), then flatten the grid.
        batch_first = feats.movedim(2, -1).transpose(0, 1)
        tokens = batch_first.reshape(B, T, H * W, hidden)
        return tokens.contiguous()
|
|
|
|
class CausalConv3DDynamicCompressor(nn.Module):
    """Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.

    Replaces ShortTermLatentCompressor (slot cross-attention).
    - Operates directly on raw latents passed as (T_src, B, C, H, W)
    - For every target frame, selects up to ``max_source_frames`` of the most
      recent eligible source frames, in ascending frame order
    - Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1]
    - Causal padding prepends temporal zeros and right-aligns fixed outputs
    - Zero-padded to max_source_frames for fixed output shape
    - No slot cross-attention, no chunking
    """

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        patch_size: int = 2,
        conv_kernel_t: int = 3,
        conv_stride_t: int = 2,
        max_source_frames: int = 8,
        exclude_latest_local_frames: int = 4,
    ):
        """Build the strided Conv3D and output LayerNorm.

        Args:
            latent_channels: Channel count C of the incoming latents.
            dit_hidden_size: Width of the produced tokens.
            patch_size: Spatial kernel and stride of the convolution.
            conv_kernel_t: Temporal kernel size of the convolution.
            conv_stride_t: Temporal stride of the convolution.
            max_source_frames: Fixed length of the per-target source window.
            exclude_latest_local_frames: Default margin; a source frame is
                eligible for a target only if its index is strictly less than
                ``target - exclude_latest_local_frames``.

        Raises:
            ValueError: If any size/stride argument is not a positive integer.
        """
        super().__init__()
        # Normalise to int and fail early with a clear message; a zero stride
        # would otherwise surface as a ZeroDivisionError in the padding math.
        self.latent_channels = self._positive_int("latent_channels", latent_channels)
        self.dit_hidden_size = self._positive_int("dit_hidden_size", dit_hidden_size)
        self.patch_size = self._positive_int("patch_size", patch_size)
        self.conv_kernel_t = self._positive_int("conv_kernel_t", conv_kernel_t)
        self.conv_stride_t = self._positive_int("conv_stride_t", conv_stride_t)
        self.max_source_frames = self._positive_int("max_source_frames", max_source_frames)
        self.exclude_latest_local_frames = int(exclude_latest_local_frames)
        self.causal_pad = self._temporal_left_pad()
        self.conv3d = nn.Conv3d(
            self.latent_channels,
            self.dit_hidden_size,
            kernel_size=(self.conv_kernel_t, self.patch_size, self.patch_size),
            stride=(self.conv_stride_t, self.patch_size, self.patch_size),
            padding=0,
        )
        self.out_norm = nn.LayerNorm(self.dit_hidden_size)
        self._init_temporal_as_delta()

    @staticmethod
    def _positive_int(name: str, value: int) -> int:
        """Return ``int(value)``, raising ValueError unless it is strictly positive."""
        value = int(value)
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")
        return value

    def _init_temporal_as_delta(self) -> None:
        """Initialise the convolution as a patch average of the latest temporal tap.

        All weights are zeroed; output channel ``d`` then reads input channel
        ``d % C`` at the last temporal kernel position with weight
        ``1 / patch_size**2`` on every spatial tap. The bias is zeroed.
        """
        with torch.no_grad():
            weight = self.conv3d.weight
            weight.zero_()
            d_out, d_in = weight.shape[:2]
            scale = 1.0 / (self.patch_size * self.patch_size)
            out_idx = torch.arange(d_out, device=weight.device)
            # One indexed write instead of a Python loop over output channels.
            weight[out_idx, out_idx % d_in, self.conv_kernel_t - 1, :, :] = scale
            if self.conv3d.bias is not None:
                nn.init.zeros_(self.conv3d.bias)

    def _temporal_output_count(self) -> int:
        """Number of temporal outputs: ceil(max_source_frames / conv_stride_t)."""
        return math.ceil(self.max_source_frames / self.conv_stride_t)

    def _temporal_left_pad(self) -> int:
        """Zeros to prepend in time so the last conv window ends on the newest source slot."""
        t_out = self._temporal_output_count()
        # Last (unpadded) position touched by the final output window.
        latest_output_end = (t_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
        latest_source = self.max_source_frames - 1
        return max(0, latest_output_end - latest_source)

    def _output_time_indices(self, device: torch.device) -> torch.Tensor:
        """Source-window slot index of the newest frame seen by each temporal output."""
        t_out = self._temporal_output_count()
        return (
            torch.arange(t_out, device=device, dtype=torch.long) * self.conv_stride_t
            + self.conv_kernel_t
            - 1
            - self.causal_pad
        )

    def tokens_per_target(self, H: int, W: int) -> int:
        """Number of tokens emitted per target frame for an HxW latent map."""
        p = self.patch_size
        T_out = self._temporal_output_count()
        return T_out * (H // p) * (W // p)

    def forward(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor],
        target_frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress, per target frame, the most recent eligible source latents into tokens.

        Args:
            latents: Source latents of shape (T_src, B, C, H, W).
            frame_indices: Frame index of every source latent, shape (T_src, B).
            pose: Ignored by this module.
            target_frame_indices: Target frame indices, shape (T_tgt, B), or
                (T_tgt,) to share the same targets across the batch.
            source_is_generated: Ignored by this module.
            exclude_latest_local_frames: Optional per-call override of the
                constructor's margin.

        Returns:
            ``(tokens, mask)`` where ``tokens`` has shape
            (B, T_tgt, num_slots, dit_hidden_size) and ``mask`` is a bool tensor
            of shape (B, T_tgt, num_slots), with
            ``num_slots == tokens_per_target(H, W)``. A slot is marked valid
            when the newest source slot covered by its temporal window holds an
            eligible frame. Tokens of targets with no eligible source are zero.

        Raises:
            ValueError: On malformed shapes, H/W not divisible by the patch
                size, or (for non-empty sources) a channel-count mismatch.
        """
        del pose, source_is_generated
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T_src,B,C,H,W)")
        exclude_latest_local_frames = (
            self.exclude_latest_local_frames
            if exclude_latest_local_frames is None
            else int(exclude_latest_local_frames)
        )
        T_src, B, C, H, W = latents.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
        if frame_indices.shape != (T_src, B):
            raise ValueError("frame_indices must have shape (T_src,B)")
        if target_frame_indices.ndim == 1:
            target_frame_indices = target_frame_indices[:, None].expand(-1, B)
        if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
            raise ValueError("target_frame_indices must have shape (T_tgt,B)")

        device = latents.device
        frame_indices = frame_indices.to(device=device)
        target_frame_indices = target_frame_indices.to(device=device)
        T_tgt = target_frame_indices.shape[0]
        n_spatial = (H // p) * (W // p)
        T_out = self._temporal_output_count()
        num_slots = T_out * n_spatial
        max_src = self.max_source_frames

        # No sources at all: everything is padding.
        if T_src == 0:
            out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
            out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
            return out_tokens, out_mask

        # Checked after the empty-source shortcut so empty placeholders with any
        # channel count keep working; otherwise the convolution would reject it.
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")

        # --- Eligibility: one row per (batch, target) query -------------------
        source_frames = frame_indices.transpose(0, 1).contiguous()
        target_frames = target_frame_indices.transpose(0, 1).contiguous()
        valid = source_frames[:, None, :] < (target_frames[:, :, None] - exclude_latest_local_frames)
        valid_flat = valid.reshape(B * T_tgt, T_src)
        source_frames_flat = source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(B * T_tgt, T_src)

        # --- Pick the newest eligible frames, oldest first --------------------
        topk = min(max_src, T_src)
        # Ineligible frames rank as -inf so they sort last and are detectable.
        rank = source_frames_flat.to(dtype=torch.float64).masked_fill(~valid_flat, -float("inf"))
        top = torch.topk(rank, k=topk, dim=1, largest=True, sorted=True)
        # topk is descending; flip so time increases and invalid slots lead.
        selected_idx = top.indices.flip(dims=(1,))
        selected_valid = torch.isfinite(top.values).flip(dims=(1,))
        if topk < max_src:
            # Left-pad with invalid slots so the window always has max_src entries.
            pad_count = max_src - topk
            selected_idx = torch.cat([
                torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.long),
                selected_idx,
            ], dim=1)
            selected_valid = torch.cat([
                torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.bool),
                selected_valid,
            ], dim=1)

        # --- Gather the selected latents and blank out invalid slots ----------
        selected_idx_clamped = selected_idx.clamp(min=0, max=max(0, T_src - 1))
        has_valid = selected_valid.any(dim=1)
        batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
        latents_by_batch = latents.permute(1, 0, 2, 3, 4).contiguous()
        latents_per_query = latents_by_batch.index_select(0, batch_ids)
        gather_idx = selected_idx_clamped.reshape(B * T_tgt, max_src, 1, 1, 1).expand(
            -1, -1, C, H, W
        )
        chunk = torch.gather(latents_per_query, 1, gather_idx)
        chunk = torch.where(
            selected_valid[:, :, None, None, None],
            chunk,
            torch.zeros((), device=device, dtype=latents.dtype),
        )

        # --- Temporal delta, causal pad, strided Conv3D, LayerNorm ------------
        # First slot is kept as-is; later slots become frame-to-frame differences.
        inp = torch.cat([chunk[:, :1], chunk[:, 1:] - chunk[:, :-1]], dim=1)
        x = inp.permute(0, 2, 1, 3, 4)  # (N, C, T, H, W) for Conv3d
        x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))  # zeros prepended in time only
        x = self.conv3d(x)
        x = self.out_norm(x.permute(0, 2, 3, 4, 1))  # channels last for LayerNorm
        tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
        # Queries with no eligible source at all produce exact zeros.
        tokens_flat = torch.where(has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat))
        out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)

        # --- Validity mask: follow the newest slot of each output window ------
        output_time_idx = self._output_time_indices(device)
        clamped_time_idx = output_time_idx.clamp(min=0, max=max_src - 1)
        temporal_mask = (
            (output_time_idx >= 0)
            & (output_time_idx < max_src)
            & selected_valid.index_select(1, clamped_time_idx)
        )
        out_mask = temporal_mask[:, :, None].expand(B * T_tgt, T_out, n_spatial).reshape(B, T_tgt, num_slots)
        return out_tokens, out_mask
|
|