from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch
from torch import nn

from .frame_memory import FrameMemory
from .latent_video_init import LatentVideoInitializer
from .motion_tokens import MotionTokenBank
from .temporal_adapters import TemporalAdapter, TrainableTemporalMotionAdapter
from .temporal_unet_wrapper import SharedUNetVideoWrapper


@dataclass
class FusedVideoStackOutput:
    """Value returned by ``FusedVideoStack.forward``."""

    # Temporal-adapter output, clamped to the [0, 1] range.
    video: torch.Tensor
    # Diagnostic values describing the stack; keys are listed in FusedVideoStack.forward.
    metadata: dict[str, Any]


class FusedVideoStack(nn.Module):
    """Combine an image pipeline's generator with motion tokens, frame memory and a temporal adapter.

    Args:
        image_pipe: Image pipeline. ``image_pipe.adapter.image_generator`` (or ``None``
            when that attribute is missing) is wrapped by ``SharedUNetVideoWrapper``.
            The pipeline itself is stored via ``object.__setattr__``, bypassing
            ``nn.Module.__setattr__``, so it is never registered as a submodule.
        channels: Passed as ``channels`` to ``FrameMemory`` and to the temporal adapter.
        motion_dim: Passed as ``dim`` to ``MotionTokenBank``, ``memory_dim`` to
            ``FrameMemory`` and (trainable profile only) ``motion_dim`` to the adapter.
        adapter_profile: ``"trainable_v1_1"`` selects ``TrainableTemporalMotionAdapter``;
            every other value falls back to ``TemporalAdapter``.
        adapter_hidden: Hidden size for the trainable adapter (ignored otherwise).
        adapter_layers: Layer count for the trainable adapter (ignored otherwise).
        transfer_init: When true and the trainable profile is selected, call the
            adapter's ``initialize_transfer_weights``. Ignored for other profiles.
        transfer_strength: ``strength`` forwarded to ``initialize_transfer_weights``.
    """

    # Profile name that switches on the trainable motion adapter.
    _TRAINABLE_PROFILE = "trainable_v1_1"

    def __init__(
        self,
        image_pipe: Any,
        *,
        channels: int = 3,
        motion_dim: int = 32,
        adapter_profile: str = "safe",
        adapter_hidden: int = 128,
        adapter_layers: int = 3,
        transfer_init: bool = False,
        transfer_strength: float = 0.25,
    ):
        super().__init__()
        # Plain attribute assignment (not nn.Module registration): the pipeline is
        # kept out of this module's submodules, parameters() and state_dict().
        object.__setattr__(self, "image_pipe", image_pipe)
        self.adapter_profile = adapter_profile

        # Submodules are registered in the same order as before so that
        # parameters()/state_dict() ordering is unchanged.
        self.shared_unet = SharedUNetVideoWrapper(self._base_generator())
        self.motion_tokens = MotionTokenBank(dim=motion_dim)
        self.frame_memory = FrameMemory(channels=channels, memory_dim=motion_dim)

        if adapter_profile != self._TRAINABLE_PROFILE:
            self.temporal_adapter = TemporalAdapter(channels=channels)
        else:
            trainable = TrainableTemporalMotionAdapter(
                channels=channels,
                hidden=adapter_hidden,
                motion_dim=motion_dim,
                layers=adapter_layers,
            )
            self.temporal_adapter = trainable
            if transfer_init:
                trainable.initialize_transfer_weights(strength=transfer_strength)

        self.initializer = LatentVideoInitializer()

    def _base_generator(self) -> Any:
        """Return ``image_pipe.adapter.image_generator``, or ``None`` if it is absent."""
        return getattr(self.image_pipe.adapter, "image_generator", None)

    def forward(
        self,
        anchor: torch.Tensor,
        *,
        frames: int,
        motion: str,
        seed: int | None = None,
    ) -> FusedVideoStackOutput:
        """Build a video from ``anchor`` and return it together with diagnostic metadata.

        Args:
            anchor: Input handed to ``LatentVideoInitializer.from_anchor``. Its expected
                shape is defined by that initializer (TODO confirm). The initialized
                video's first dimension is used as the batch size for motion tokens.
            frames: Frame count forwarded to the initializer.
            motion: Motion descriptor forwarded to the motion-token bank.
            seed: Optional seed forwarded to the initializer.

        Returns:
            ``FusedVideoStackOutput`` whose ``video`` is the temporal adapter's output
            clamped to [0, 1]. Its ``metadata`` holds, in order:

            - ``base_unet_shared``: ``shared_unet.shares_with`` applied to the pipeline's
              current image generator.
            - ``temporal_adapter_zero_init``: whether the adapter gate is all-close to zero.
            - ``temporal_adapter_gate``: the gate as a Python float. ``float()`` on a
              tensor requires a single element, so this raises for a multi-element gate.
            - ``transfer_initialized``: True only for the trainable profile with a
              non-zero gate. NOTE(review): this is inferred from the current gate value,
              so it also reads True once a zero-initialised gate has been trained; confirm
              that is the intended meaning.
            - ``parameter_overhead``, ``adapter_profile``, ``motion_token_shape``,
              ``frame_memory_shape``.
        """
        initial = self.initializer.from_anchor(anchor, frames=frames, seed=seed)
        tokens = self.motion_tokens(motion, batch=initial.shape[0], device=initial.device)
        memory = self.frame_memory(initial)
        shared_features = self.shared_unet.forward_shared(initial)
        adapted = self.temporal_adapter(
            shared_features,
            motion_tokens=tokens,
            frame_memory=memory,
        )

        base_shared = self.shared_unet.shares_with(self._base_generator())
        # The zero-gate test feeds two metadata entries; evaluate it once.
        gate = self.temporal_adapter.gate.detach()
        gate_is_zero = bool(torch.allclose(gate, torch.zeros_like(gate)))
        is_trainable = self.adapter_profile == self._TRAINABLE_PROFILE

        metadata: dict[str, Any] = {}
        metadata["base_unet_shared"] = base_shared
        metadata["temporal_adapter_zero_init"] = gate_is_zero
        metadata["temporal_adapter_gate"] = float(gate.float().cpu())
        metadata["transfer_initialized"] = bool(is_trainable and not gate_is_zero)
        metadata["parameter_overhead"] = self.parameter_overhead
        metadata["adapter_profile"] = self.adapter_profile
        metadata["motion_token_shape"] = tuple(tokens.shape)
        metadata["frame_memory_shape"] = tuple(memory.shape)

        return FusedVideoStackOutput(video=adapted.clamp(0, 1), metadata=metadata)

    @property
    def parameter_overhead(self) -> int:
        """Total element count over ``self.parameters()``, which recurses into all submodules.

        NOTE(review): if ``SharedUNetVideoWrapper`` registers the wrapped image generator
        as a submodule, its weights are counted here too. Confirm that matches the
        "overhead" meaning.
        """
        total = 0
        for parameter in self.parameters():
            total += parameter.numel()
        return total