# Phillnet-2 / VideoGen / fused_video_stack.py
# Uploaded by ayjays132 ("Upload 478 files", revision 101858b, verified).
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch import nn
from .frame_memory import FrameMemory
from .latent_video_init import LatentVideoInitializer
from .motion_tokens import MotionTokenBank
from .temporal_adapters import TemporalAdapter, TrainableTemporalMotionAdapter
from .temporal_unet_wrapper import SharedUNetVideoWrapper
@dataclass()
class FusedVideoStackOutput:
    """Result bundle returned by ``FusedVideoStack.forward``.

    A plain data holder; the dataclass decorator generates ``__init__``,
    ``__repr__`` and ``__eq__`` for the two fields below.
    """

    # Generated clip (``FusedVideoStack.forward`` clamps it to [0, 1]).
    video: torch.Tensor
    # Diagnostic key/value pairs describing how the clip was produced.
    metadata: dict[str, Any]
class FusedVideoStack(nn.Module):
    """Video stack built on top of an image pipeline's generator.

    ``forward`` builds an initial clip from an anchor tensor, derives motion
    tokens and a frame-memory summary from it, passes the clip through a
    ``SharedUNetVideoWrapper`` constructed around the pipeline's image
    generator, fuses the result with a temporal adapter, and clamps the output
    to ``[0, 1]``.
    """

    def __init__(
        self,
        image_pipe: Any,
        *,
        channels: int = 3,
        motion_dim: int = 32,
        adapter_profile: str = "safe",
        adapter_hidden: int = 128,
        adapter_layers: int = 3,
        transfer_init: bool = False,
        transfer_strength: float = 0.25,
    ):
        """Assemble the stack around ``image_pipe``.

        Args:
            image_pipe: Image pipeline; ``image_pipe.adapter.image_generator``
                (or ``None`` if the adapter lacks that attribute) is handed to
                the shared-UNet wrapper.
            channels: Channel count given to the frame memory and adapters.
            motion_dim: Width of the motion tokens and frame-memory features.
            adapter_profile: ``"trainable_v1_1"`` selects
                ``TrainableTemporalMotionAdapter``; any other value selects
                ``TemporalAdapter``.
            adapter_hidden: Hidden width of the trainable adapter.
            adapter_layers: Layer count of the trainable adapter.
            transfer_init: Call ``initialize_transfer_weights`` on the
                trainable adapter. Ignored for non-trainable profiles.
            transfer_strength: ``strength`` passed to
                ``initialize_transfer_weights``.
        """
        super().__init__()
        # Bypass nn.Module.__setattr__ so the pipeline (if it is an nn.Module)
        # is not registered as a submodule; its weights then stay out of this
        # stack's parameters(), state_dict() and .to() calls.
        object.__setattr__(self, "image_pipe", image_pipe)
        self.adapter_profile = adapter_profile
        self.shared_unet = SharedUNetVideoWrapper(self._resolve_image_generator(image_pipe))
        self.motion_tokens = MotionTokenBank(dim=motion_dim)
        self.frame_memory = FrameMemory(channels=channels, memory_dim=motion_dim)

        trainable = adapter_profile == "trainable_v1_1"
        if trainable:
            self.temporal_adapter = TrainableTemporalMotionAdapter(
                channels=channels,
                hidden=adapter_hidden,
                motion_dim=motion_dim,
                layers=adapter_layers,
            )
            if transfer_init:
                self.temporal_adapter.initialize_transfer_weights(strength=transfer_strength)
        else:
            self.temporal_adapter = TemporalAdapter(channels=channels)

        # Record whether transfer initialisation actually ran instead of
        # inferring it later from the gate value: a zero-initialised gate can
        # also become non-zero simply through training or checkpoint loading.
        self.transfer_initialized = bool(trainable and transfer_init)
        self.initializer = LatentVideoInitializer()

    @staticmethod
    def _resolve_image_generator(image_pipe: Any) -> Any:
        """Return ``image_pipe.adapter.image_generator``, or ``None`` if the adapter lacks it."""
        return getattr(image_pipe.adapter, "image_generator", None)

    def _adapter_metadata(self) -> dict[str, Any]:
        """Summarise the temporal adapter's gate and how the adapter was initialised.

        ``float()`` on the gate requires it to hold exactly one element.
        """
        gate = self.temporal_adapter.gate.detach()
        return {
            "temporal_adapter_zero_init": bool(torch.allclose(gate, torch.zeros_like(gate))),
            "temporal_adapter_gate": float(gate.float().cpu()),
            "transfer_initialized": self.transfer_initialized,
        }

    def forward(
        self,
        anchor: torch.Tensor,
        *,
        frames: int,
        motion: str,
        seed: int | None = None,
    ) -> FusedVideoStackOutput:
        """Generate a clip from ``anchor`` and return it with diagnostic metadata.

        Args:
            anchor: Anchor tensor; its expected layout is defined by
                ``LatentVideoInitializer.from_anchor`` (not visible here).
            frames: Frame count requested from the initializer.
            motion: Motion descriptor passed to the ``MotionTokenBank``.
            seed: Optional seed forwarded to the initializer.

        Returns:
            ``FusedVideoStackOutput`` whose ``video`` is the adapter output
            clamped to ``[0, 1]`` and whose ``metadata`` describes the run.
        """
        video = self.initializer.from_anchor(anchor, frames=frames, seed=seed)
        # Batch size for the motion tokens is taken from dim 0 of the clip.
        motion_tokens = self.motion_tokens(motion, batch=video.shape[0], device=video.device)
        frame_memory = self.frame_memory(video)
        shared = self.shared_unet.forward_shared(video)
        fused = self.temporal_adapter(shared, motion_tokens=motion_tokens, frame_memory=frame_memory)
        metadata: dict[str, Any] = {
            # Re-resolved on every call so a swapped pipeline generator is detected.
            "base_unet_shared": self.shared_unet.shares_with(
                self._resolve_image_generator(self.image_pipe)
            ),
            **self._adapter_metadata(),
            "parameter_overhead": self.parameter_overhead,
            "adapter_profile": self.adapter_profile,
            "motion_token_shape": tuple(motion_tokens.shape),
            "frame_memory_shape": tuple(frame_memory.shape),
        }
        return FusedVideoStackOutput(video=fused.clamp(0, 1), metadata=metadata)

    @property
    def parameter_overhead(self) -> int:
        """Total element count of every parameter registered on this module.

        ``image_pipe`` is excluded because it is never registered. Whether the
        base UNet's weights are counted depends on whether
        ``SharedUNetVideoWrapper`` registers them (not visible here).
        """
        return sum(param.numel() for param in self.parameters())