File size: 3,358 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Any

import torch
from torch import nn

from .frame_memory import FrameMemory
from .latent_video_init import LatentVideoInitializer
from .motion_tokens import MotionTokenBank
from .temporal_adapters import TemporalAdapter, TrainableTemporalMotionAdapter
from .temporal_unet_wrapper import SharedUNetVideoWrapper


@dataclass()
class FusedVideoStackOutput:
    """Container returned by ``FusedVideoStack.forward``.

    Attributes:
        video: The fused clip tensor (``FusedVideoStack.forward`` clamps it to
            ``[0, 1]`` before storing it here).
        metadata: Diagnostic key/value pairs describing how the clip was produced.
    """

    # Generated clip; shape is whatever the producing stack emits.
    video: torch.Tensor
    # Free-form diagnostics (flags, shapes, parameter counts, ...).
    metadata: dict[str, Any]


class FusedVideoStack(nn.Module):
    """Turn an image pipeline's generator into a video stack via temporal modules.

    ``forward`` performs these steps:

    1. Expands an anchor into a clip with ``LatentVideoInitializer``.
    2. Passes the clip through a ``SharedUNetVideoWrapper`` built around the image
       pipeline's ``adapter.image_generator``.
    3. Blends the result with a temporal adapter conditioned on motion tokens and
       frame memory.

    Args:
        image_pipe: Pipeline object; must expose ``.adapter``. Its
            ``adapter.image_generator`` (or ``None`` if that attribute is absent)
            is wrapped by the shared UNet.
        channels: Passed to ``FrameMemory`` and to the temporal adapter.
        motion_dim: Passed as ``dim`` to ``MotionTokenBank``, as ``memory_dim`` to
            ``FrameMemory`` and as ``motion_dim`` to the trainable adapter.
        adapter_profile: ``"trainable_v1_1"`` selects
            ``TrainableTemporalMotionAdapter``. Any other value falls back to
            ``TemporalAdapter``.
        adapter_hidden: ``hidden`` width of the trainable adapter (trainable
            profile only).
        adapter_layers: ``layers`` of the trainable adapter (trainable profile only).
        transfer_init: Call ``initialize_transfer_weights`` on the trainable
            adapter. Has no effect for other profiles; a warning is emitted in that
            case.
        transfer_strength: ``strength`` forwarded to ``initialize_transfer_weights``.
    """

    def __init__(
        self,
        image_pipe: Any,
        *,
        channels: int = 3,
        motion_dim: int = 32,
        adapter_profile: str = "safe",
        adapter_hidden: int = 128,
        adapter_layers: int = 3,
        transfer_init: bool = False,
        transfer_strength: float = 0.25,
    ):
        super().__init__()
        # Bypass nn.Module.__setattr__ so the pipeline is stored as a plain
        # attribute. It is never registered as a submodule, so its parameters stay
        # out of parameters()/state_dict() even if it is an nn.Module.
        object.__setattr__(self, "image_pipe", image_pipe)
        self.adapter_profile = adapter_profile
        self.shared_unet = SharedUNetVideoWrapper(self._image_generator(image_pipe))
        self.motion_tokens = MotionTokenBank(dim=motion_dim)
        self.frame_memory = FrameMemory(channels=channels, memory_dim=motion_dim)
        if adapter_profile == "trainable_v1_1":
            self.temporal_adapter = TrainableTemporalMotionAdapter(
                channels=channels,
                hidden=adapter_hidden,
                motion_dim=motion_dim,
                layers=adapter_layers,
            )
            if transfer_init:
                self.temporal_adapter.initialize_transfer_weights(strength=transfer_strength)
        else:
            if transfer_init:
                # Previously dropped silently; surface the ignored request.
                warnings.warn(
                    f"transfer_init=True is ignored for adapter_profile={adapter_profile!r}; "
                    "it only applies to 'trainable_v1_1'.",
                    stacklevel=2,
                )
            self.temporal_adapter = TemporalAdapter(channels=channels)
        self.initializer = LatentVideoInitializer()

    @staticmethod
    def _image_generator(image_pipe: Any) -> Any:
        """Return ``image_pipe.adapter.image_generator``, or ``None`` if the adapter lacks it."""
        return getattr(image_pipe.adapter, "image_generator", None)

    def forward(self, anchor: torch.Tensor, *, frames: int, motion: str, seed: int | None = None) -> FusedVideoStackOutput:
        """Generate a ``frames``-frame clip from ``anchor``.

        Args:
            anchor: Input handed to ``LatentVideoInitializer.from_anchor``.
            frames: Frame count requested from the initializer.
            motion: Motion descriptor string passed to ``MotionTokenBank``.
            seed: Optional seed forwarded to the initializer.

        Returns:
            ``FusedVideoStackOutput``. Its ``video`` is the adapter output clamped
            to ``[0, 1]``. Its ``metadata`` holds diagnostic flags and shapes.
        """
        video = self.initializer.from_anchor(anchor, frames=frames, seed=seed)
        # Dimension 0 of the initialized clip is treated as the batch size.
        motion_tokens = self.motion_tokens(motion, batch=video.shape[0], device=video.device)
        frame_memory = self.frame_memory(video)
        shared = self.shared_unet.forward_shared(video)
        fused = self.temporal_adapter(shared, motion_tokens=motion_tokens, frame_memory=frame_memory)

        # Snapshot the gate once; the zero-init and transfer flags both derive from it.
        gate = self.temporal_adapter.gate.detach()
        gate_is_zero = bool(torch.allclose(gate, torch.zeros_like(gate)))
        return FusedVideoStackOutput(
            video=fused.clamp(0, 1),
            metadata={
                "base_unet_shared": self.shared_unet.shares_with(self._image_generator(self.image_pipe)),
                "temporal_adapter_zero_init": gate_is_zero,
                # float() requires a single-element gate tensor.
                "temporal_adapter_gate": float(gate.float().cpu()),
                # Inferred from the current gate value, not from the constructor's
                # transfer_init flag. Any non-zero gate on the trainable profile
                # (e.g. after training) reports True.
                "transfer_initialized": bool(self.adapter_profile == "trainable_v1_1" and not gate_is_zero),
                "parameter_overhead": self.parameter_overhead,
                "adapter_profile": self.adapter_profile,
                "motion_token_shape": tuple(motion_tokens.shape),
                "frame_memory_shape": tuple(frame_memory.shape),
            },
        )

    @property
    def parameter_overhead(self) -> int:
        """Total element count over every parameter registered on this module.

        ``image_pipe`` is stored unregistered, so its parameters are excluded.
        NOTE(review): whether the base UNet's weights are counted depends on
        whether ``SharedUNetVideoWrapper`` registers it as a submodule. Confirm
        before treating this as "added" parameters only.
        """
        return sum(param.numel() for param in self.parameters())