from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut latent maps into non-overlapping square patches, one token per patch.

    Args:
        latents: Tensor of shape (T, B, C, H, W).
        patch_size: Side length of each square patch; must divide H and W.

    Returns:
        Tensor of shape (T, B, (H/p)*(W/p), C*p*p), patches in the order
        produced by ``F.unfold``.

    Raises:
        ValueError: If ``latents`` is not 5-D, ``patch_size`` is not positive,
            or H/W are not divisible by ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")
    T, B, C, H, W = latents.shape
    if H % patch_size != 0 or W % patch_size != 0:
        raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={patch_size}")
    flat = latents.reshape(T * B, C, H, W)
    # unfold yields (T*B, C*p*p, L); move the patch axis in front of features.
    patches = F.unfold(flat, kernel_size=patch_size, stride=patch_size)
    patches = patches.transpose(1, 2).contiguous()
    return patches.reshape(T, B, patches.shape[1], C * patch_size * patch_size)


def spatial_pool_tokens(
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """2D adaptive average pool on a flattened (src_h*src_w, D) token grid.

    Preserves the 2D spatial layout (row-major) while pooling.

    Args:
        tokens: Tensor of shape (src_h*src_w, D).
        pool_h: Output grid height.
        pool_w: Output grid width.
        src_h: Source grid height.
        src_w: Source grid width.

    Returns:
        Tensor of shape (pool_h*pool_w, D).

    Raises:
        ValueError: If ``tokens`` is not 2-D or its row count does not equal
            ``src_h * src_w``.
    """
    if tokens.ndim != 2:
        raise ValueError("tokens must have shape (N, D)")
    N, D = tokens.shape
    if N != src_h * src_w:
        # Fail with a clear message instead of an opaque reshape error.
        raise ValueError(
            f"tokens has {N} rows but src_h*src_w={src_h}*{src_w}={src_h * src_w}"
        )
    # (N, D) -> (1, D, src_h, src_w) so the pooling op sees a channels-first image.
    spatial = tokens.reshape(src_h, src_w, D).permute(2, 0, 1).unsqueeze(0)
    pooled = F.adaptive_avg_pool2d(spatial, (pool_h, pool_w))
    return pooled.squeeze(0).permute(1, 2, 0).reshape(-1, D)


class SpatialConv2DMemoryProjector(nn.Module):
    """Project latent maps to DiT hidden tokens while preserving the HxW grid.

    A 1x1 convolution maps latent channels to ``mid_channels``; a second
    "same"-padded convolution maps those to ``dit_hidden_size``.
    """

    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Build the two-stage convolutional projector.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        kernel_size = int(kernel_size)
        if kernel_size <= 0 or kernel_size % 2 == 0:
            raise ValueError("kernel_size must be a positive odd integer")
        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = kernel_size
        self.out_features = self.dit_hidden_size
        self.proj_in = nn.Conv2d(self.latent_channels, self.mid_channels, kernel_size=1)
        # Odd kernel with padding k // 2 keeps H and W unchanged.
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project (T, B, C, H, W) latents to (B, T, H*W, dit_hidden_size) tokens.

        Raises:
            ValueError: If ``latents`` is not 5-D or its channel count differs
                from ``latent_channels``.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        T, B, C, H, W = latents.shape
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")
        x = latents.reshape(T * B, C, H, W)
        x = self.proj_spatial(self.proj_in(x))
        x = x.reshape(T, B, self.dit_hidden_size, H, W)
        # Batch-first, channels-last, then flatten the spatial grid row-major.
        return x.permute(1, 0, 3, 4, 2).reshape(B, T, H * W, self.dit_hidden_size).contiguous()


class CausalConv3DDynamicCompressor(nn.Module):
    """Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.

    Replaces ShortTermLatentCompressor (slot cross-attention).

    - Operates directly on raw (T_src, B, C, H, W) latents
    - Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1]
    - Causal padding prepends temporal zeros and right-aligns fixed outputs
    - Zero-padded to max_source_frames for fixed output shape
    - No slot cross-attention, no chunking
    """

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        patch_size: int = 2,
        conv_kernel_t: int = 3,
        conv_stride_t: int = 2,
        max_source_frames: int = 8,
        exclude_latest_local_frames: int = 4,
    ):
        """Build the strided Conv3D compressor.

        Args:
            latent_channels: Channel count C of the input latents.
            dit_hidden_size: Output token width.
            patch_size: Spatial kernel size and stride of the convolution.
            conv_kernel_t: Temporal kernel size.
            conv_stride_t: Temporal stride.
            max_source_frames: Fixed temporal window length fed to the conv.
            exclude_latest_local_frames: Default margin; a source frame is
                used for a target only if its index is strictly below
                ``target - exclude_latest_local_frames``.

        Raises:
            ValueError: If any size argument is not a positive integer.
        """
        super().__init__()
        sizes = {
            "latent_channels": int(latent_channels),
            "dit_hidden_size": int(dit_hidden_size),
            "patch_size": int(patch_size),
            "conv_kernel_t": int(conv_kernel_t),
            "conv_stride_t": int(conv_stride_t),
            "max_source_frames": int(max_source_frames),
        }
        # Validate up front: a zero stride would otherwise surface as a
        # ZeroDivisionError and other bad sizes as late conv errors.
        for arg_name, arg_value in sizes.items():
            if arg_value <= 0:
                raise ValueError(f"{arg_name} must be a positive integer, got {arg_value}")

        self.latent_channels = sizes["latent_channels"]
        self.dit_hidden_size = sizes["dit_hidden_size"]
        self.patch_size = sizes["patch_size"]
        self.conv_kernel_t = sizes["conv_kernel_t"]
        self.conv_stride_t = sizes["conv_stride_t"]
        self.max_source_frames = sizes["max_source_frames"]
        self.exclude_latest_local_frames = int(exclude_latest_local_frames)
        self.causal_pad = self._temporal_left_pad()
        self.conv3d = nn.Conv3d(
            self.latent_channels,
            self.dit_hidden_size,
            kernel_size=(self.conv_kernel_t, self.patch_size, self.patch_size),
            stride=(self.conv_stride_t, self.patch_size, self.patch_size),
            padding=0,
        )
        self.out_norm = nn.LayerNorm(self.dit_hidden_size)
        self._init_temporal_as_delta()

    def _init_temporal_as_delta(self) -> None:
        """Initialise the conv so each output channel patch-averages the current delta."""
        with torch.no_grad():
            self.conv3d.weight.zero_()
            k_t, p = self.conv_kernel_t, self.patch_size
            D_out, D_in = self.conv3d.weight.shape[:2]
            scale = 1.0 / (p * p)
            # Delta preprocessing happens in forward. Initialize every output
            # channel to read a patch-averaged current delta, repeating latent
            # channels across the wider DiT hidden dimension.
            for d in range(D_out):
                self.conv3d.weight[d, d % D_in, k_t - 1, :, :] = scale
            if self.conv3d.bias is not None:
                nn.init.zeros_(self.conv3d.bias)

    def _temporal_output_count(self) -> int:
        """Return the number of temporal outputs: ceil(max_source_frames / stride)."""
        return math.ceil(self.max_source_frames / self.conv_stride_t)

    def _temporal_left_pad(self) -> int:
        """Return how many zero frames to prepend so the last window ends on the latest frame.

        Zero when the unpadded windows already fit inside the source window.
        """
        t_out = self._temporal_output_count()
        latest_output_end = (t_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
        latest_source = self.max_source_frames - 1
        return max(0, latest_output_end - latest_source)

    def _output_time_indices(self, device: torch.device) -> torch.Tensor:
        """Return, per temporal output, the unpadded index of the last frame in its window.

        Values may be negative when a window lies entirely inside the padding.
        """
        t_out = self._temporal_output_count()
        return (
            torch.arange(t_out, device=device, dtype=torch.long) * self.conv_stride_t
            + self.conv_kernel_t
            - 1
            - self.causal_pad
        )

    def tokens_per_target(self, H: int, W: int) -> int:
        """Return the number of memory tokens emitted per target for an HxW latent."""
        p = self.patch_size
        T_out = self._temporal_output_count()
        return T_out * (H // p) * (W // p)

    def forward(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor],
        target_frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress the most recent eligible source latents for every target frame.

        Args:
            latents: Source latents of shape (T_src, B, C, H, W).
            frame_indices: Frame index of each source, shape (T_src, B).
            pose: Unused; accepted for interface compatibility.
            target_frame_indices: Target frame indices, shape (T_tgt, B) or
                (T_tgt,) (broadcast over the batch).
            source_is_generated: Unused; accepted for interface compatibility.
            exclude_latest_local_frames: Optional per-call override of the
                margin configured at construction.

        Returns:
            ``(tokens, mask)`` where ``tokens`` has shape
            (B, T_tgt, T_out*(H/p)*(W/p), dit_hidden_size) and ``mask`` is a
            bool tensor of shape (B, T_tgt, T_out*(H/p)*(W/p)). A slot is True
            when the last frame of its temporal window is a valid source.

        Raises:
            ValueError: On wrong rank, channel count, spatial divisibility, or
                index tensor shapes.
        """
        del pose, source_is_generated
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T_src,B,C,H,W)")
        exclude_latest_local_frames = (
            self.exclude_latest_local_frames
            if exclude_latest_local_frames is None
            else int(exclude_latest_local_frames)
        )
        T_src, B, C, H, W = latents.shape
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
        if frame_indices.shape != (T_src, B):
            raise ValueError("frame_indices must have shape (T_src,B)")
        if target_frame_indices.ndim == 1:
            target_frame_indices = target_frame_indices[:, None].expand(-1, B)
        if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
            raise ValueError("target_frame_indices must have shape (T_tgt,B)")

        device = latents.device
        frame_indices = frame_indices.to(device=device)
        target_frame_indices = target_frame_indices.to(device=device)
        T_tgt = target_frame_indices.shape[0]
        n_spatial = (H // p) * (W // p)
        T_out = self._temporal_output_count()
        num_slots = T_out * n_spatial
        n_window = self.max_source_frames
        output_time_idx = self._output_time_indices(device)

        if T_src == 0:
            # Nothing to compress: all-zero tokens, all-False mask.
            out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
            out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
            return out_tokens, out_mask

        # --- Select, per (batch, target) query, the latest eligible sources ---
        source_frames = frame_indices.transpose(0, 1).contiguous()  # (B, T_src)
        target_frames = target_frame_indices.transpose(0, 1).contiguous()  # (B, T_tgt)
        valid = source_frames[:, None, :] < (
            target_frames[:, :, None] - int(exclude_latest_local_frames)
        )
        valid_flat = valid.reshape(B * T_tgt, T_src)
        source_frames_flat = (
            source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(B * T_tgt, T_src)
        )

        topk = min(n_window, T_src)
        # Ineligible sources rank as -inf so they sort last and are detectable.
        rank = source_frames_flat.to(dtype=torch.float64).masked_fill(~valid_flat, -float("inf"))
        top = torch.topk(rank, k=topk, dim=1, largest=True, sorted=True)
        # Flip to ascending frame order; ineligible picks end up on the left.
        selected_idx = top.indices.flip(dims=(1,))
        selected_valid = torch.isfinite(top.values).flip(dims=(1,))
        if topk < n_window:
            # Left-pad with invalid entries so the window length is fixed.
            pad_count = n_window - topk
            selected_idx = torch.cat(
                [
                    torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.long),
                    selected_idx,
                ],
                dim=1,
            )
            selected_valid = torch.cat(
                [
                    torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.bool),
                    selected_valid,
                ],
                dim=1,
            )
        selected_idx_clamped = selected_idx.to(device=device, dtype=torch.long).clamp(
            min=0, max=max(0, T_src - 1)
        )
        has_valid = selected_valid.any(dim=1)

        # --- Gather the selected frames -> (B*T_tgt, n_window, C, H, W) ---
        # Index (time, batch) pairs directly instead of first replicating the
        # whole latent sequence once per target.
        batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
        chunk = latents[selected_idx_clamped, batch_ids[:, None]]
        chunk = chunk.masked_fill(~selected_valid[:, :, None, None, None], 0.0)

        # --- Delta preprocessing: first frame raw, then frame differences ---
        inp = torch.cat([chunk[:, :1], chunk[:, 1:] - chunk[:, :-1]], dim=1)

        # --- Causal Conv3D over (C, T, H, W), then LayerNorm channels-last ---
        x = inp.permute(0, 2, 1, 3, 4)
        x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))
        x = self.conv3d(x)
        x = self.out_norm(x.permute(0, 2, 3, 4, 1))
        tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
        # Queries with no eligible source emit exact zeros.
        tokens_flat = torch.where(
            has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat)
        )
        out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)

        # --- Slot mask: valid iff the window's last frame is a valid source ---
        clamped_time_idx = output_time_idx.clamp(min=0, max=n_window - 1)
        temporal_mask = (
            (output_time_idx >= 0)
            & (output_time_idx < n_window)
            & selected_valid.index_select(1, clamped_time_idx)
        )
        out_mask = (
            temporal_mask[:, :, None]
            .expand(B * T_tgt, T_out, n_spatial)
            .reshape(B, T_tgt, num_slots)
        )
        return out_tokens, out_mask