DeMemWM / algorithms /worldmem /dememwm /compression.py
BonanDing's picture
Optimize DeMemWM memory retrieval and remove diagnostics
1dae740
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut latent maps into non-overlapping square patches, one token per patch.

    Args:
        latents: Tensor of shape (T, B, C, H, W).
        patch_size: Side length of each square patch. Must be positive and
            divide both H and W.

    Returns:
        Tensor of shape
        (T, B, (H // patch_size) * (W // patch_size), C * patch_size**2).

    Raises:
        ValueError: If ``latents`` is not 5-D, ``patch_size`` is not positive,
            or H / W are not divisible by ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")

    num_frames, batch, channels, height, width = latents.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"latent H,W=({height},{width}) must be divisible by patch_size={patch_size}"
        )

    # Fold time into the batch axis so unfold receives a plain (N, C, H, W) input.
    frames = latents.reshape(num_frames * batch, channels, height, width)
    columns = F.unfold(frames, kernel_size=patch_size, stride=patch_size)
    # unfold yields (N, C*p*p, L); put the patch axis L ahead of the feature axis.
    per_patch = columns.transpose(1, 2).contiguous()
    num_patches = per_patch.shape[1]
    token_dim = channels * patch_size * patch_size
    return per_patch.reshape(num_frames, batch, num_patches, token_dim)
def spatial_pool_tokens(
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """Adaptive-average-pool a flattened token grid down to a smaller grid.

    The (src_h * src_w, D) token list is viewed as a row-major
    ``src_h x src_w`` grid, pooled in 2D so spatial neighbourhoods are
    preserved, and flattened back to a token list.

    Args:
        tokens: Tensor of shape (src_h * src_w, D).
        pool_h: Height of the pooled grid.
        pool_w: Width of the pooled grid.
        src_h: Height of the source grid.
        src_w: Width of the source grid.

    Returns:
        Tensor of shape (pool_h * pool_w, D).

    Raises:
        ValueError: If ``tokens`` is not 2-D.
    """
    if tokens.ndim != 2:
        raise ValueError("tokens must have shape (N, D)")

    feature_dim = tokens.shape[-1]
    grid = tokens.reshape(src_h, src_w, feature_dim)
    # adaptive_avg_pool2d is fed channels-first input with a dummy batch axis.
    channels_first = grid.permute(2, 0, 1).unsqueeze(0)
    pooled = F.adaptive_avg_pool2d(channels_first, (pool_h, pool_w))
    # Drop the dummy batch axis and return to (rows, cols, D) before flattening.
    channels_last = pooled.squeeze(0).permute(1, 2, 0)
    return channels_last.reshape(-1, feature_dim)
class SpatialConv2DMemoryProjector(nn.Module):
    """Map latent feature maps to DiT-width tokens, one token per grid cell.

    A 1x1 convolution lifts the latent channels to ``mid_channels``; a second
    "same"-padded convolution produces ``dit_hidden_size`` channels, so the
    HxW grid is kept intact.
    """

    # Class-level flag; always True for this projector.
    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Build the two convolution stages.

        Args:
            latent_channels: Channel count of the incoming latents.
            dit_hidden_size: Channel count of the produced tokens.
            mid_channels: Channel count between the two convolutions.
            kernel_size: Spatial kernel of the second convolution; must be a
                positive odd integer so that ``kernel_size // 2`` padding
                keeps H and W unchanged.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        ksize = int(kernel_size)
        positive_and_odd = ksize > 0 and ksize % 2 == 1
        if not positive_and_odd:
            raise ValueError("kernel_size must be a positive odd integer")

        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = ksize
        self.out_features = self.dit_hidden_size

        # Pointwise channel lift, then the spatial mixing stage.
        self.proj_in = nn.Conv2d(
            self.latent_channels,
            self.mid_channels,
            kernel_size=1,
        )
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=ksize,
            padding=ksize // 2,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project latents to tokens.

        Args:
            latents: Tensor of shape (T, B, C, H, W) with
                ``C == latent_channels``.

        Returns:
            Contiguous tensor of shape (B, T, H * W, dit_hidden_size).

        Raises:
            ValueError: If ``latents`` is not 5-D or has the wrong channel
                count.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        num_frames, batch, channels, height, width = latents.shape
        if channels != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {channels}")

        # Fold time into the batch axis for the 2D convolutions.
        frames = latents.reshape(num_frames * batch, channels, height, width)
        lifted = self.proj_in(frames)
        features = self.proj_spatial(lifted)

        # (T*B, D, H, W) -> (T, B, D, H*W) -> (B, T, H*W, D)
        flat_grid = features.reshape(num_frames, batch, self.dit_hidden_size, height * width)
        return flat_grid.permute(1, 0, 3, 2).contiguous()
class CausalConv3DDynamicCompressor(nn.Module):
    """Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.

    Replaces ShortTermLatentCompressor (slot cross-attention).

    - Operates directly on (T, C, H, W) raw latents
    - Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1]
    - Causal padding prepends temporal zeros and right-aligns fixed outputs
    - Zero-padded to max_source_frames for fixed output shape
    - No slot cross-attention, no chunking
    """

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        patch_size: int = 2,
        conv_kernel_t: int = 3,
        conv_stride_t: int = 2,
        max_source_frames: int = 8,
        exclude_latest_local_frames: int = 4,
    ):
        """Build the strided Conv3D and its output normalization.

        Args:
            latent_channels: Channel count of the incoming latents.
            dit_hidden_size: Channel count of the produced tokens.
            patch_size: Spatial kernel and stride of the convolution.
            conv_kernel_t: Temporal kernel size of the convolution.
            conv_stride_t: Temporal stride of the convolution.
            max_source_frames: Number of source frames in every (zero-padded)
                temporal window.
            exclude_latest_local_frames: Default gap between a target frame
                and the newest source frame it may use; a source frame is
                eligible when ``source < target - exclude_latest_local_frames``.

        Raises:
            ValueError: If any size hyperparameter is not a positive integer.
        """
        super().__init__()
        sizes = {
            "latent_channels": int(latent_channels),
            "dit_hidden_size": int(dit_hidden_size),
            "patch_size": int(patch_size),
            "conv_kernel_t": int(conv_kernel_t),
            "conv_stride_t": int(conv_stride_t),
            "max_source_frames": int(max_source_frames),
        }
        # Fail early with a clear message instead of a ZeroDivisionError in the
        # output-count computation or an opaque error inside Conv3d.
        for name, value in sizes.items():
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value}")

        self.latent_channels = sizes["latent_channels"]
        self.dit_hidden_size = sizes["dit_hidden_size"]
        self.patch_size = sizes["patch_size"]
        self.conv_kernel_t = sizes["conv_kernel_t"]
        self.conv_stride_t = sizes["conv_stride_t"]
        self.max_source_frames = sizes["max_source_frames"]
        self.exclude_latest_local_frames = int(exclude_latest_local_frames)
        self.causal_pad = self._temporal_left_pad()
        self.conv3d = nn.Conv3d(
            self.latent_channels,
            self.dit_hidden_size,
            kernel_size=(self.conv_kernel_t, self.patch_size, self.patch_size),
            stride=(self.conv_stride_t, self.patch_size, self.patch_size),
            padding=0,
        )
        self.out_norm = nn.LayerNorm(self.dit_hidden_size)
        self._init_temporal_as_delta()

    def _init_temporal_as_delta(self) -> None:
        """Initialize the convolution to pass through the newest delta in its window."""
        with torch.no_grad():
            self.conv3d.weight.zero_()
            k_t, p = self.conv_kernel_t, self.patch_size
            D_out, D_in = self.conv3d.weight.shape[:2]
            scale = 1.0 / (p * p)
            # Delta preprocessing happens in forward. Initialize every output
            # channel to read a patch-averaged current delta, repeating latent
            # channels across the wider DiT hidden dimension.
            for d in range(D_out):
                self.conv3d.weight[d, d % D_in, k_t - 1, :, :] = scale
            if self.conv3d.bias is not None:
                nn.init.zeros_(self.conv3d.bias)

    def _temporal_output_count(self) -> int:
        """Return the number of temporal outputs per window: ceil(max_source_frames / stride)."""
        return math.ceil(self.max_source_frames / self.conv_stride_t)

    def _temporal_left_pad(self) -> int:
        """Return how many zero frames to prepend so the last output ends on the newest frame."""
        t_out = self._temporal_output_count()
        # Index (in unpadded coordinates) of the last tap of the final output.
        latest_output_end = (t_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
        latest_source = self.max_source_frames - 1
        return max(0, latest_output_end - latest_source)

    def _output_time_indices(self, device: torch.device) -> torch.Tensor:
        """Return, for each temporal output, the window index of its newest input frame.

        Indices are relative to the unpadded window and may be negative when
        the newest tap of an output still falls inside the left padding.
        """
        t_out = self._temporal_output_count()
        return (
            torch.arange(t_out, device=device, dtype=torch.long) * self.conv_stride_t
            + self.conv_kernel_t
            - 1
            - self.causal_pad
        )

    def tokens_per_target(self, H: int, W: int) -> int:
        """Return the number of tokens emitted per target frame for an HxW latent."""
        p = self.patch_size
        T_out = self._temporal_output_count()
        return T_out * (H // p) * (W // p)

    def forward(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor],
        target_frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress, for each target frame, its most recent eligible source frames.

        Args:
            latents: Source latents of shape (T_src, B, C, H, W).
            frame_indices: Frame index of every source latent, shape (T_src, B).
            pose: Unused; discarded immediately.
            target_frame_indices: Target frame indices, shape (T_tgt, B) or
                (T_tgt,) (broadcast over the batch).
            source_is_generated: Unused; discarded immediately.
            exclude_latest_local_frames: Per-call override of the constructor
                value; ``None`` keeps the constructor value.

        Returns:
            A pair ``(tokens, mask)``:
            ``tokens`` has shape (B, T_tgt, num_slots, dit_hidden_size) and
            ``mask`` is a bool tensor of shape (B, T_tgt, num_slots), where
            ``num_slots = T_out * (H // patch_size) * (W // patch_size)``.
            A slot is marked valid when the newest source frame in its
            temporal window is an eligible frame. Tokens of targets with no
            eligible source frame are all zero.

        Raises:
            ValueError: If an input has the wrong rank or shape, H / W are not
                divisible by ``patch_size``, or (when source frames are
                present) the channel count differs from ``latent_channels``.
        """
        del pose, source_is_generated
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T_src,B,C,H,W)")
        exclude_latest_local_frames = (
            self.exclude_latest_local_frames
            if exclude_latest_local_frames is None
            else int(exclude_latest_local_frames)
        )
        T_src, B, C, H, W = latents.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
        if frame_indices.shape != (T_src, B):
            raise ValueError("frame_indices must have shape (T_src,B)")
        if target_frame_indices.ndim == 1:
            target_frame_indices = target_frame_indices[:, None].expand(-1, B)
        if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
            raise ValueError("target_frame_indices must have shape (T_tgt,B)")

        device = latents.device
        frame_indices = frame_indices.to(device=device)
        target_frame_indices = target_frame_indices.to(device=device)
        T_tgt = target_frame_indices.shape[0]
        n_spatial = (H // p) * (W // p)
        T_out = self._temporal_output_count()
        num_slots = T_out * n_spatial
        output_time_idx = self._output_time_indices(device)

        # Nothing to compress: return an all-invalid, all-zero memory.
        if T_src == 0:
            out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
            out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
            return out_tokens, out_mask

        # Checked only once frames are present, so the empty-source path above
        # keeps accepting any placeholder tensor; raises a clear ValueError
        # instead of a shape error from inside the convolution.
        if C != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")

        # --- Select, per (batch, target) query, the newest eligible sources ---
        source_frames = frame_indices.transpose(0, 1).contiguous()  # (B, T_src)
        target_frames = target_frame_indices.transpose(0, 1).contiguous()  # (B, T_tgt)
        valid = source_frames[:, None, :] < (
            target_frames[:, :, None] - int(exclude_latest_local_frames)
        )
        valid_flat = valid.reshape(B * T_tgt, T_src)
        source_frames_flat = (
            source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(B * T_tgt, T_src)
        )
        topk = min(int(self.max_source_frames), T_src)
        # Rank by frame index; ineligible frames get -inf so they sort last and
        # can be recognised afterwards through isfinite.
        rank = source_frames_flat.to(dtype=torch.float64).masked_fill(~valid_flat, -float("inf"))
        top = torch.topk(rank, k=topk, dim=1, largest=True, sorted=True)
        # topk returns newest-first; flip to chronological order so that
        # ineligible entries sit on the left and the newest frame is last.
        selected_idx = top.indices.flip(dims=(1,))
        selected_valid = torch.isfinite(top.values).flip(dims=(1,))
        if topk < self.max_source_frames:
            # Left-pad with invalid placeholders up to the fixed window length.
            pad_count = self.max_source_frames - topk
            selected_idx = torch.cat([
                torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.long),
                selected_idx,
            ], dim=1)
            selected_valid = torch.cat([
                torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.bool),
                selected_valid,
            ], dim=1)
        selected_idx_clamped = selected_idx.to(device=device, dtype=torch.long).clamp(
            min=0, max=max(0, T_src - 1)
        )
        has_valid = selected_valid.any(dim=1)

        # --- Gather the selected latents into (B*T_tgt, max_source_frames, C, H, W) ---
        batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
        latents_by_batch = latents.permute(1, 0, 2, 3, 4).contiguous()
        latents_per_query = latents_by_batch.index_select(0, batch_ids)
        gather_idx = selected_idx_clamped.reshape(
            B * T_tgt, self.max_source_frames, 1, 1, 1
        ).expand(-1, -1, C, H, W)
        chunk = torch.gather(latents_per_query, 1, gather_idx)
        # Placeholder / ineligible positions contribute zeros.
        chunk = torch.where(
            selected_valid[:, :, None, None, None],
            chunk,
            torch.zeros((), device=device, dtype=latents.dtype),
        )

        # --- Delta preprocessing: keep frame 0, difference every later frame ---
        inp = torch.cat([chunk[:, :1], chunk[:, 1:] - chunk[:, :-1]], dim=1)

        # --- Causal Conv3D over (N, C, T, H, W) with zeros prepended in time ---
        x = inp.permute(0, 2, 1, 3, 4)
        x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))
        x = self.conv3d(x)
        # Channels last so LayerNorm normalizes over the hidden dimension.
        x = self.out_norm(x.permute(0, 2, 3, 4, 1))
        tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
        # Queries without any eligible source frame emit exact zeros.
        tokens_flat = torch.where(
            has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat)
        )
        out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)

        # --- Slot mask: an output is valid when its newest input frame is valid ---
        clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
        temporal_mask = (
            (output_time_idx >= 0)
            & (output_time_idx < self.max_source_frames)
            & selected_valid.index_select(1, clamped_time_idx)
        )
        out_mask = (
            temporal_mask[:, :, None]
            .expand(B * T_tgt, T_out, n_spatial)
            .reshape(B, T_tgt, num_slots)
        )
        return out_tokens, out_mask