File size: 10,922 Bytes
b47a1ce 1dae740 b47a1ce 1dae740 b47a1ce 1dae740 b47a1ce 1dae740 b47a1ce 1dae740 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 | from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut every latent frame into non-overlapping square patches.

    Args:
        latents: Tensor of shape ``(T, B, C, H, W)``.
        patch_size: Side length of a patch; must be positive and divide
            both ``H`` and ``W``.

    Returns:
        Tensor of shape ``(T, B, N, C * patch_size * patch_size)`` where
        ``N = (H // patch_size) * (W // patch_size)``. Each patch is
        flattened in the layout produced by ``F.unfold``.

    Raises:
        ValueError: If ``latents`` is not 5-D, ``patch_size`` is not
            positive, or ``H``/``W`` are not multiples of ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")

    num_frames, batch, channels, height, width = latents.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"latent H,W=({height},{width}) must be divisible by patch_size={patch_size}"
        )

    # Fold time into the batch axis so unfold sees ordinary 4-D images.
    frames = latents.reshape(num_frames * batch, channels, height, width)
    columns = F.unfold(frames, kernel_size=patch_size, stride=patch_size)
    # unfold returns (N_img, C*p*p, L); put the patch axis before the features.
    per_patch = columns.transpose(1, 2).contiguous()
    num_patches = per_patch.shape[1]
    patch_dim = channels * patch_size * patch_size
    return per_patch.reshape(num_frames, batch, num_patches, patch_dim)
def spatial_pool_tokens(
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """Adaptive-average-pool a flattened token grid down to ``pool_h x pool_w``.

    The ``(src_h * src_w, D)`` tokens are viewed as a ``src_h x src_w`` grid,
    pooled in 2D so the spatial layout is kept, then flattened again.

    Args:
        tokens: Tensor of shape ``(src_h * src_w, D)``.
        pool_h: Number of rows in the pooled grid.
        pool_w: Number of columns in the pooled grid.
        src_h: Number of rows in the source grid.
        src_w: Number of columns in the source grid.

    Returns:
        Tensor of shape ``(pool_h * pool_w, D)``.

    Raises:
        ValueError: If ``tokens`` is not 2-D.
    """
    if tokens.ndim != 2:
        raise ValueError("tokens must have shape (N, D)")

    feature_dim = tokens.shape[-1]
    grid = tokens.reshape(src_h, src_w, feature_dim)
    # adaptive_avg_pool2d wants (N, C, H, W): features become channels.
    channels_first = grid.permute(2, 0, 1).unsqueeze(0)
    pooled = F.adaptive_avg_pool2d(channels_first, (pool_h, pool_w))
    # Drop the singleton batch axis and move features back to the last axis.
    channels_last = pooled[0].permute(1, 2, 0)
    return channels_last.reshape(-1, feature_dim)
class SpatialConv2DMemoryProjector(nn.Module):
    """Map latent feature maps to DiT hidden tokens, one token per grid cell.

    A 1x1 convolution lifts the latent channels to ``mid_channels``; a second
    ``kernel_size x kernel_size`` convolution with "same" padding produces
    ``dit_hidden_size`` channels, so the HxW grid is preserved.
    """

    # Class-level flag exposed on every instance.
    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Build the two-stage convolutional projector.

        Args:
            latent_channels: Channel count of the incoming latents.
            dit_hidden_size: Channel count of the produced tokens.
            mid_channels: Width of the intermediate 1x1 projection.
            kernel_size: Spatial kernel of the second convolution; must be a
                positive odd integer so padding keeps the grid size.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        kernel_size = int(kernel_size)
        is_positive_odd = kernel_size > 0 and kernel_size % 2 == 1
        if not is_positive_odd:
            raise ValueError("kernel_size must be a positive odd integer")

        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = kernel_size
        self.out_features = self.dit_hidden_size

        self.proj_in = nn.Conv2d(
            self.latent_channels,
            self.mid_channels,
            kernel_size=1,
        )
        same_padding = kernel_size // 2
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=kernel_size,
            padding=same_padding,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project ``(T, B, C, H, W)`` latents to ``(B, T, H*W, dit_hidden_size)``.

        Raises:
            ValueError: If ``latents`` is not 5-D or its channel count differs
                from ``latent_channels``.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        num_frames, batch, channels, height, width = latents.shape
        if channels != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {channels}")

        # Fold time into batch so the 2D convolutions see plain images.
        frames = latents.reshape(num_frames * batch, channels, height, width)
        hidden = self.proj_in(frames)
        features = self.proj_spatial(hidden)
        features = features.reshape(num_frames, batch, self.dit_hidden_size, height, width)

        # (T, B, D, H, W) -> (B, T, H, W, D), then flatten the grid.
        batch_first = features.permute(1, 0, 3, 4, 2)
        tokens = batch_first.reshape(batch, num_frames, height * width, self.dit_hidden_size)
        return tokens.contiguous()
class CausalConv3DDynamicCompressor(nn.Module):
    """Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.

    Replaces ShortTermLatentCompressor (slot cross-attention).
    - Operates directly on (T, C, H, W) raw latents
    - Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1]
    - Causal padding prepends temporal zeros and right-aligns fixed outputs
    - Zero-padded to max_source_frames for fixed output shape
    - No slot cross-attention, no chunking
    """

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        patch_size: int = 2,
        conv_kernel_t: int = 3,
        conv_stride_t: int = 2,
        max_source_frames: int = 8,
        exclude_latest_local_frames: int = 4,
    ):
        """Build the strided Conv3D compressor.

        Args:
            latent_channels: Channel count of the incoming latents.
            dit_hidden_size: Channel count of the produced tokens.
            patch_size: Spatial kernel and stride of the convolution.
            conv_kernel_t: Temporal kernel size of the convolution.
            conv_stride_t: Temporal stride of the convolution.
            max_source_frames: Size of the fixed temporal window fed to the
                convolution; shorter histories are left-padded with zeros.
            exclude_latest_local_frames: Default gap: a source frame is
                eligible for a target only if its index is strictly below
                ``target_index - exclude_latest_local_frames``.

        Raises:
            ValueError: If ``patch_size``, ``conv_kernel_t``,
                ``conv_stride_t`` or ``max_source_frames`` is not positive.
        """
        super().__init__()
        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.patch_size = int(patch_size)
        self.conv_kernel_t = int(conv_kernel_t)
        self.conv_stride_t = int(conv_stride_t)
        self.max_source_frames = int(max_source_frames)
        self.exclude_latest_local_frames = int(exclude_latest_local_frames)

        # Fail early with a clear message instead of a ZeroDivisionError in
        # the padding arithmetic or an opaque error from nn.Conv3d.
        for name in ("patch_size", "conv_kernel_t", "conv_stride_t", "max_source_frames"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")

        self.causal_pad = self._temporal_left_pad()
        self.conv3d = nn.Conv3d(
            self.latent_channels,
            self.dit_hidden_size,
            kernel_size=(self.conv_kernel_t, self.patch_size, self.patch_size),
            stride=(self.conv_stride_t, self.patch_size, self.patch_size),
            padding=0,
        )
        self.out_norm = nn.LayerNorm(self.dit_hidden_size)
        self._init_temporal_as_delta()

    def _init_temporal_as_delta(self) -> None:
        """Initialise the conv so each output reads a patch-averaged current delta.

        Delta preprocessing happens in ``forward``. Every output channel ``d``
        gets weight ``1 / patch_size**2`` on the last temporal tap of input
        channel ``d % latent_channels`` (repeating latent channels across the
        wider hidden dimension); all other weights and the bias are zero.
        """
        weight = self.conv3d.weight
        d_out, d_in = weight.shape[:2]
        scale = 1.0 / (self.patch_size * self.patch_size)
        with torch.no_grad():
            weight.zero_()
            out_channels = torch.arange(d_out, device=weight.device)
            # One vectorised write instead of a Python loop over d_out.
            weight[out_channels, out_channels % d_in, self.conv_kernel_t - 1, :, :] = scale
            if self.conv3d.bias is not None:
                nn.init.zeros_(self.conv3d.bias)

    def _temporal_output_count(self) -> int:
        """Return the number of temporal outputs produced per query."""
        return math.ceil(self.max_source_frames / self.conv_stride_t)

    def _temporal_left_pad(self) -> int:
        """Return how many zero frames to prepend so the last window ends on the last source slot."""
        t_out = self._temporal_output_count()
        latest_output_end = (t_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
        latest_source = self.max_source_frames - 1
        return max(0, latest_output_end - latest_source)

    def _output_time_indices(self, device: torch.device) -> torch.Tensor:
        """Return, per temporal output, the window position hit by its last kernel tap.

        Positions index the unpadded ``max_source_frames`` window; values may
        fall outside ``[0, max_source_frames)`` for unusual kernel/stride
        combinations and are range-checked by the caller.
        """
        t_out = self._temporal_output_count()
        return (
            torch.arange(t_out, device=device, dtype=torch.long) * self.conv_stride_t
            + self.conv_kernel_t
            - 1
            - self.causal_pad
        )

    def tokens_per_target(self, H: int, W: int) -> int:
        """Return the number of tokens emitted per target frame for an ``H x W`` latent."""
        p = self.patch_size
        return self._temporal_output_count() * (H // p) * (W // p)

    def _select_source_frames(
        self,
        source_frames: torch.Tensor,
        target_frames: torch.Tensor,
        exclude_latest: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the most recent eligible source positions for every (batch, target) query.

        Args:
            source_frames: ``(B, T_src)`` frame indices, ``T_src >= 1``.
            target_frames: ``(B, T_tgt)`` frame indices.
            exclude_latest: Gap between a target and its newest eligible source.

        Returns:
            ``(selected_idx, selected_valid)``, both of shape
            ``(B * T_tgt, max_source_frames)``. ``selected_idx`` holds
            positions into the source time axis in ascending frame order;
            ``selected_valid`` is False for ineligible or padding entries,
            which are always at the front of the window.
        """
        B, T_src = source_frames.shape
        T_tgt = target_frames.shape[1]
        num_queries = B * T_tgt
        device = source_frames.device

        eligible = source_frames[:, None, :] < (target_frames[:, :, None] - exclude_latest)
        eligible = eligible.reshape(num_queries, T_src)
        frames_per_query = (
            source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(num_queries, T_src)
        )

        # Rank by frame index; ineligible frames sink to -inf so top-k returns
        # the newest eligible ones first, and the flip restores time order.
        k = min(self.max_source_frames, T_src)
        rank = frames_per_query.to(dtype=torch.float64).masked_fill(~eligible, -float("inf"))
        top = torch.topk(rank, k=k, dim=1, largest=True, sorted=True)
        selected_idx = top.indices.flip(dims=(1,))
        selected_valid = torch.isfinite(top.values).flip(dims=(1,))

        pad_count = self.max_source_frames - k
        if pad_count > 0:
            # Left-pad so real frames stay right-aligned in the fixed window.
            selected_idx = torch.cat(
                [selected_idx.new_zeros((num_queries, pad_count)), selected_idx], dim=1
            )
            selected_valid = torch.cat(
                [selected_valid.new_zeros((num_queries, pad_count)), selected_valid], dim=1
            )
        return selected_idx, selected_valid

    def forward(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor],
        target_frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress the eligible source history of every target into fixed-size tokens.

        Args:
            latents: Source latents of shape ``(T_src, B, C, H, W)``.
            frame_indices: Frame index of every source latent, ``(T_src, B)``.
            pose: Unused; accepted for interface compatibility.
            target_frame_indices: Target frame indices, ``(T_tgt, B)`` or
                ``(T_tgt,)`` (broadcast over the batch).
            source_is_generated: Unused; accepted for interface compatibility.
            exclude_latest_local_frames: Per-call override of the constructor
                default; ``None`` keeps the default.

        Returns:
            ``(tokens, mask)`` where ``tokens`` has shape
            ``(B, T_tgt, num_slots, dit_hidden_size)`` and ``mask`` is a bool
            tensor of shape ``(B, T_tgt, num_slots)``, with
            ``num_slots = tokens_per_target(H, W)``. A slot is marked valid
            when the window position under the last temporal tap of its
            convolution window holds an eligible source frame. Tokens of
            queries with no eligible source frame are all zero.

        Raises:
            ValueError: On malformed input shapes, ``H``/``W`` not divisible by
                ``patch_size``, or (for non-empty input) a channel count that
                differs from ``latent_channels``.
        """
        del pose, source_is_generated
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T_src,B,C,H,W)")
        exclude_latest = (
            self.exclude_latest_local_frames
            if exclude_latest_local_frames is None
            else int(exclude_latest_local_frames)
        )

        T_src, B, C, H, W = latents.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
        if frame_indices.shape != (T_src, B):
            raise ValueError("frame_indices must have shape (T_src,B)")
        if target_frame_indices.ndim == 1:
            target_frame_indices = target_frame_indices[:, None].expand(-1, B)
        if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
            raise ValueError("target_frame_indices must have shape (T_tgt,B)")

        device = latents.device
        T_tgt = target_frame_indices.shape[0]
        n_spatial = (H // p) * (W // p)
        T_out = self._temporal_output_count()
        num_slots = T_out * n_spatial

        if T_src == 0:
            # No history at all: fixed-shape zeros with an all-False mask.
            out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
            out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
            return out_tokens, out_mask

        # Checked after the empty-history return so that path keeps accepting
        # any placeholder tensor, exactly as before.
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")

        source_frames = frame_indices.to(device=device).transpose(0, 1).contiguous()
        target_frames = target_frame_indices.to(device=device).transpose(0, 1).contiguous()
        selected_idx, selected_valid = self._select_source_frames(
            source_frames, target_frames, exclude_latest
        )
        has_valid = selected_valid.any(dim=1)

        # Gather only the selected frames: (B*T_tgt, max_source_frames, C, H, W).
        # Indexing time and batch jointly avoids materialising a full copy of
        # the source history for every target.
        batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
        chunk = latents[selected_idx, batch_ids[:, None]]
        chunk = chunk.masked_fill(~selected_valid[:, :, None, None, None], 0)

        # Delta preprocessing: first slot is kept, later slots hold frame
        # differences. Invalid slots are zero and sit at the front, so the
        # first valid frame passes through unchanged.
        deltas = torch.cat([chunk[:, :1], chunk[:, 1:] - chunk[:, :-1]], dim=1)

        x = deltas.permute(0, 2, 1, 3, 4)  # (N, C, T, H, W)
        x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))  # left-pad time only
        x = self.conv3d(x)
        x = self.out_norm(x.permute(0, 2, 3, 4, 1))  # normalise over channels

        tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
        tokens_flat = torch.where(
            has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat)
        )
        out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)

        output_time_idx = self._output_time_indices(device)
        in_window = (output_time_idx >= 0) & (output_time_idx < self.max_source_frames)
        clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
        temporal_mask = in_window & selected_valid.index_select(1, clamped_time_idx)
        out_mask = (
            temporal_mask[:, :, None]
            .expand(B * T_tgt, T_out, n_spatial)
            .reshape(B, T_tgt, num_slots)
        )
        return out_tokens, out_mask
|