File size: 6,994 Bytes
da7bf91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from __future__ import annotations

import torch
import torch.nn as nn
import einops

from graphwm.config_graph import GraphWMArgs
from graphwm.models.graph_encoder_pyg import GraphSpatialEncoder
from graphwm.models.graph_resampler import GraphResampler
from graphwm.models.temporal_graph_conditioner import TemporalGraphConditioner
from graphwm.original_ctrl_world import import_original_modules


class GraphConditioner(nn.Module):
    """Encode a sequence of per-frame graph batches into conditioning tokens.

    The module chains three sub-modules, all configured from ``args``:

    * ``spatial``   -- ``GraphSpatialEncoder`` applied to each frame's graph batch.
    * ``resampler`` -- ``GraphResampler`` applied to the encoder output together
      with the frame's ``batch`` assignment vector (configured with
      ``args.graph_num_tokens`` tokens).
    * ``temporal``  -- ``TemporalGraphConditioner`` applied once to the per-frame
      results stacked along ``dim=1``.
    """

    def __init__(self, args: GraphWMArgs):
        """Build the spatial, resampling and temporal sub-modules from ``args``."""
        super().__init__()

        # Values shared by more than one sub-module are read once up front.
        width = args.graph_hidden_dim
        drop_p = args.graph_dropout
        graph_heads = args.graph_num_heads

        self.spatial = GraphSpatialEncoder(
            node_in_dim=args.graph_in_dim,
            edge_in_dim=args.edge_in_dim,
            hidden_dim=width,
            num_layers=args.graph_num_layers,
            dropout=drop_p,
            backbone=args.graph_backbone,
            num_heads=graph_heads,
        )
        self.resampler = GraphResampler(
            hidden_dim=width,
            num_tokens=args.graph_num_tokens,
            num_heads=graph_heads,
            dropout=drop_p,
        )
        self.temporal = TemporalGraphConditioner(
            hidden_dim=width,
            cond_dim=args.graph_cond_dim,
            num_layers=args.graph_temporal_layers,
            num_heads=args.graph_temporal_heads,
            dropout=drop_p,
        )

    def forward(self, graph_seq):
        """Run every frame through encoder + resampler, then the temporal module.

        Args:
            graph_seq: Iterable of per-frame graph batches. Each item is passed
                to ``self.spatial`` and must expose a ``batch`` attribute, which
                is forwarded to ``self.resampler``.

        Returns:
            Whatever ``self.temporal`` returns for the per-frame resampler
            outputs stacked along ``dim=1``.
        """
        tokens_by_frame = [
            self.resampler(self.spatial(frame_graphs), frame_graphs.batch)
            for frame_graphs in graph_seq
        ]
        # The frame axis is inserted at position 1 (directly after the batch axis).
        return self.temporal(torch.stack(tokens_by_frame, dim=1))


class CtrlWorldGraph(nn.Module):
    """Graph-conditioned wrapper around the original Ctrl-World backbone.

    The wrapper loads a Stable Video Diffusion pipeline, swaps in the
    ``UNetSpatioTemporalConditionModel`` exported by the original Ctrl-World
    code, and feeds the UNet with tokens produced by a ``GraphConditioner``
    through ``encoder_hidden_states``. ``forward`` computes the denoising
    training loss for one batch.
    """

    # log(sigma) of the main denoising noise level is drawn from
    # Normal(_LOG_SIGMA_MEAN, _LOG_SIGMA_STD).
    _LOG_SIGMA_MEAN = 0.7
    _LOG_SIGMA_STD = 1.6
    # The conditioning frame gets noise with sigma ~ Uniform(0, this value).
    _COND_FRAME_MAX_SIGMA = 0.2
    # History frames get noise with sigma ~ Normal(0, 1) * this value.
    _HISTORY_SIGMA_SCALE = 0.3
    # A sample keeps its graph condition when its Uniform(0, 1) draw is above
    # this threshold; otherwise the condition is replaced by zeros.
    _COND_DROP_THRESHOLD = 0.05

    def __init__(self, args: GraphWMArgs):
        """Load the pretrained pipeline and attach the graph conditioner.

        Args:
            args: Configuration object. This constructor reads
                ``ctrl_world_root`` and ``svd_model_path``; ``forward`` later
                reads ``num_history``, ``his_cond_zero``, ``fps``,
                ``motion_bucket_id`` and ``frame_level_cond``.
        """
        super().__init__()
        self.args = args

        original = import_original_modules(args.ctrl_world_root)
        StableVideoDiffusionPipeline = original["StableVideoDiffusionPipeline"]
        UNetSpatioTemporalConditionModel = original["UNetSpatioTemporalConditionModel"]

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(args.svd_model_path)
        # Replace the pipeline's UNet with the Ctrl-World variant, initialised
        # from the pretrained weights. strict=False tolerates keys that exist
        # in only one of the two models.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Only the UNet (and the graph conditioner below) receive gradients.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        self.graph_conditioner = GraphConditioner(args)

    def encode_graph_condition(self, batch) -> torch.Tensor:
        """Return the graph conditioner output for ``batch["graph_seq"]``."""
        return self.graph_conditioner(batch["graph_seq"])

    @torch.no_grad()
    def encode_rgb_to_latents(self, rgb: torch.Tensor) -> torch.Tensor:
        """Encode RGB clips [B, T, 3, H, W] in [0,1] into VAE latents.

        Args:
            rgb: 5-D tensor of frames; values are rescaled from [0, 1] to
                [-1, 1] before being passed to the VAE.

        Returns:
            Sampled latents multiplied by ``vae.config.scaling_factor``, with
            the leading two axes restored to ``[B, T]`` and cast to the UNet's
            dtype.

        Raises:
            ValueError: If ``rgb`` is not 5-dimensional.
        """
        device = self.unet.device
        rgb = rgb.to(device)
        if rgb.dim() != 5:
            raise ValueError(
                f"Expected rgb of shape [B, T, C, H, W], got {tuple(rgb.shape)}."
            )
        bsz, num_frames = rgb.shape[:2]
        flat_rgb = rgb.flatten(0, 1) * 2.0 - 1.0

        vae_dtype = self.vae.dtype
        needs_upcasting = vae_dtype == torch.float16 and self.vae.config.force_upcast
        compute_dtype = torch.float32 if needs_upcasting else vae_dtype
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        try:
            posterior = self.vae.encode(flat_rgb.to(compute_dtype)).latent_dist
            flat_latents = posterior.sample() * self.vae.config.scaling_factor
        finally:
            # Restore the VAE's own dtype (not the UNet's) so every call takes
            # the same upcast path, and do so even if encoding raised.
            if needs_upcasting:
                self.vae.to(dtype=vae_dtype)

        latents = flat_latents.reshape(bsz, num_frames, *flat_latents.shape[1:])
        return latents.to(self.unet.dtype)

    def forward(self, batch):
        """Compute the denoising training loss for one batch.

        Args:
            batch: Mapping that contains ``"graph_seq"`` plus either
                ``"latent"`` (precomputed latents, indexed as
                ``[B, F, C, H, W]``) or ``"rgb"`` (encoded on the fly with
                ``encode_rgb_to_latents``). ``"latent"`` wins when both exist.

        Returns:
            ``(loss, zero)`` where ``loss`` is the weighted mean squared error
            between the reconstructed and the clean latents over the frames
            from index ``args.num_history`` onwards, and ``zero`` is a constant
            scalar ``0.0`` tensor on the UNet's device and dtype.

        Raises:
            KeyError: If the batch has neither ``"latent"`` nor ``"rgb"``.
        """
        if "latent" in batch:
            latents = batch["latent"]
        elif "rgb" in batch:
            latents = self.encode_rgb_to_latents(batch["rgb"])
        else:
            raise KeyError("Batch must contain either 'latent' or 'rgb'.")

        device = self.unet.device
        dtype = self.unet.dtype
        noise_aug_strength = 0.0
        num_history = self.args.num_history

        latents = latents.to(device)
        bsz, num_frames = latents.shape[:2]

        # --- Conditioning image: the frame at index `num_history`, lightly noised.
        current_img = latents[:, num_history]
        sigma = torch.rand([bsz, 1, 1, 1], device=device) * self._COND_FRAME_MAX_SIGMA
        c_in = 1 / (sigma**2 + 1) ** 0.5
        current_img = c_in * (current_img + torch.randn_like(current_img) * sigma)
        # Tensor.repeat materialises an independent copy per frame. An expanded
        # (stride-0) view would make the in-place zeroing below write through
        # to every frame instead of only the history slice.
        condition_latent = current_img.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        if self.args.his_cond_zero:
            condition_latent[:, :num_history] = 0.0

        # --- Graph condition, with per-sample dropout to an all-zero condition.
        graph_hidden = self.encode_graph_condition(batch).to(device=device, dtype=dtype)
        keep = torch.rand(graph_hidden.shape[0], device=device) > self._COND_DROP_THRESHOLD
        # Shape the mask as [B, 1, ..., 1] for whatever rank the conditioner
        # returns, so it never broadcasts into an extra leading axis.
        keep = keep.view(-1, *([1] * (graph_hidden.dim() - 1)))
        graph_hidden = torch.where(keep, graph_hidden, torch.zeros_like(graph_hidden))

        # --- Main noise level (log-normal) and the derived scaling coefficients.
        rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
        sigma = (rnd_normal * self._LOG_SIGMA_STD + self._LOG_SIGMA_MEAN).exp()
        c_skip = 1 / (sigma**2 + 1)
        c_out = -sigma / (sigma**2 + 1) ** 0.5
        c_in = 1 / (sigma**2 + 1) ** 0.5
        c_noise = (sigma.log() / 4).reshape([bsz])
        loss_weight = (sigma**2 + 1) / sigma**2
        noisy_latents = latents + torch.randn_like(latents) * sigma

        # --- History frames get their own, smaller, per-frame noise level.
        sigma_h = torch.randn([bsz, num_history, 1, 1, 1], device=device) * self._HISTORY_SIGMA_SCALE
        history = latents[:, :num_history]
        noisy_history = 1 / (sigma_h**2 + 1) ** 0.5 * (history + sigma_h * torch.randn_like(history))

        # Frame axis: noised history followed by the scaled noisy later frames.
        input_latents = torch.cat([noisy_history, c_in * noisy_latents[:, num_history:]], dim=1)
        # Channel axis: append the conditioning image with the VAE scaling undone.
        input_latents = torch.cat(
            [input_latents, condition_latent / self.vae.config.scaling_factor], dim=2
        )

        added_time_ids = self.pipeline._get_add_time_ids(
            self.args.fps,
            self.args.motion_bucket_id,
            noise_aug_strength,
            graph_hidden.dtype,
            bsz,
            1,
            False,
        ).to(device)

        model_pred = self.unet(
            input_latents,
            c_noise,
            encoder_hidden_states=graph_hidden,
            added_time_ids=added_time_ids,
            frame_level_cond=self.args.frame_level_cond,
        ).sample
        predict_x0 = c_out * model_pred + c_skip * noisy_latents

        # History frames are inputs only; they do not contribute to the loss.
        loss = ((predict_x0[:, num_history:] - latents[:, num_history:]) ** 2 * loss_weight).mean()
        return loss, torch.tensor(0.0, device=device, dtype=dtype)