# gnn_wm/Ctrl-World-Graph/graphwm/models/ctrl_world_graph.py
# EndeavourDD: Add files using upload-large-folder tool
# da7bf91 (verified)
from __future__ import annotations
import torch
import torch.nn as nn
import einops
from graphwm.config_graph import GraphWMArgs
from graphwm.models.graph_encoder_pyg import GraphSpatialEncoder
from graphwm.models.graph_resampler import GraphResampler
from graphwm.models.temporal_graph_conditioner import TemporalGraphConditioner
from graphwm.original_ctrl_world import import_original_modules
class GraphConditioner(nn.Module):
    """Map a sequence of per-frame graph batches to conditioning tokens.

    Each frame goes through the spatial graph encoder and then the resampler;
    the per-frame results are stacked along dim 1 and handed to the temporal
    conditioner, whose output is returned unchanged.
    """

    def __init__(self, args: GraphWMArgs):
        """Build the three sub-modules from the ``graph_*`` fields of ``args``."""
        super().__init__()
        # Values shared by more than one sub-module.
        width = args.graph_hidden_dim
        drop_rate = args.graph_dropout

        # Registration order (spatial, resampler, temporal) is kept on purpose:
        # it determines parameter ordering in the state dict.
        self.spatial = GraphSpatialEncoder(
            node_in_dim=args.graph_in_dim,
            edge_in_dim=args.edge_in_dim,
            hidden_dim=width,
            num_layers=args.graph_num_layers,
            dropout=drop_rate,
            backbone=args.graph_backbone,
            num_heads=args.graph_num_heads,
        )
        self.resampler = GraphResampler(
            hidden_dim=width,
            num_tokens=args.graph_num_tokens,
            num_heads=args.graph_num_heads,
            dropout=drop_rate,
        )
        self.temporal = TemporalGraphConditioner(
            hidden_dim=width,
            cond_dim=args.graph_cond_dim,
            num_layers=args.graph_temporal_layers,
            num_heads=args.graph_temporal_heads,
            dropout=drop_rate,
        )

    def _encode_frame(self, graph_batch):
        """Encode one frame's graph batch into its resampled tokens."""
        # ``graph_batch.batch`` is forwarded as-is to the resampler
        # (presumably the PyG node-to-graph assignment vector -- TODO confirm).
        return self.resampler(self.spatial(graph_batch), graph_batch.batch)

    def forward(self, graph_seq):
        """Encode every frame in ``graph_seq`` and fuse the results temporally.

        Args:
            graph_seq: Iterable of per-frame graph batches; each element must
                expose a ``batch`` attribute.

        Returns:
            Whatever ``self.temporal`` produces for the frame-stacked tokens.
        """
        stacked = torch.stack(
            [self._encode_frame(frame) for frame in graph_seq],
            dim=1,  # the frame axis sits right after the leading axis
        )
        return self.temporal(stacked)
class CtrlWorldGraph(nn.Module):
    """Graph-conditioned wrapper around the original Ctrl-World backbone.

    Loads a Stable Video Diffusion pipeline, swaps in the Ctrl-World UNet
    (initialised from the pipeline's UNet weights), freezes the VAE and image
    encoder, and trains the UNet together with a ``GraphConditioner`` whose
    output is passed to the UNet as ``encoder_hidden_states``.
    """

    def __init__(self, args: GraphWMArgs):
        """Load the pretrained pipeline and build the trainable modules.

        Args:
            args: Configuration; this constructor reads ``ctrl_world_root`` and
                ``svd_model_path`` and passes ``args`` on to
                ``GraphConditioner``.
        """
        super().__init__()
        self.args = args

        original = import_original_modules(args.ctrl_world_root)
        StableVideoDiffusionPipeline = original["StableVideoDiffusionPipeline"]
        UNetSpatioTemporalConditionModel = original["UNetSpatioTemporalConditionModel"]

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(args.svd_model_path)

        # Replace the pipeline's UNet by a freshly constructed Ctrl-World UNet.
        # strict=False: parameters that do not match between the two
        # architectures are skipped instead of raising.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Only the UNet (and the graph conditioner below) receive gradients.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        self.graph_conditioner = GraphConditioner(args)

    def encode_graph_condition(self, batch) -> torch.Tensor:
        """Run the graph conditioner on ``batch["graph_seq"]``."""
        return self.graph_conditioner(batch["graph_seq"])

    @torch.no_grad()
    def encode_rgb_to_latents(self, rgb: torch.Tensor) -> torch.Tensor:
        """Encode RGB clips [B, T, 3, H, W] in [0,1] into VAE latents.

        Args:
            rgb: 5-D tensor of clips with values in [0, 1].

        Returns:
            Latents of shape ``[B, T, ...]`` (trailing axes as produced by the
            VAE), multiplied by ``vae.config.scaling_factor`` and cast to the
            UNet dtype.

        Raises:
            ValueError: If ``rgb`` is not 5-dimensional.
        """
        if rgb.ndim != 5:
            raise ValueError(
                f"Expected rgb of shape [B, T, C, H, W], got {tuple(rgb.shape)}."
            )

        rgb = rgb.to(self.unet.device)
        bsz, num_frames = rgb.shape[:2]
        # Merge batch and time so the image VAE sees a flat batch of frames,
        # and rescale [0, 1] -> [-1, 1].
        flat_rgb = rgb.flatten(0, 1) * 2.0 - 1.0

        original_vae_dtype = self.vae.dtype
        needs_upcasting = original_vae_dtype == torch.float16 and self.vae.config.force_upcast
        encode_dtype = torch.float32 if needs_upcasting else original_vae_dtype

        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        try:
            posterior = self.vae.encode(flat_rgb.to(encode_dtype)).latent_dist
            flat_latents = posterior.sample() * self.vae.config.scaling_factor
        finally:
            # Put the VAE back into the dtype it had on entry (not the UNet's
            # dtype, which may differ), even if encoding raised.
            if needs_upcasting:
                self.vae.to(dtype=original_vae_dtype)

        latents = flat_latents.reshape(bsz, num_frames, *flat_latents.shape[1:])
        return latents.to(self.unet.dtype)

    @staticmethod
    def _apply_condition_dropout(hidden: torch.Tensor, drop_prob: float = 0.05) -> torch.Tensor:
        """Zero the whole conditioning tensor of randomly chosen samples.

        One uniform draw is made per sample (leading axis); a sample is kept
        when the draw exceeds ``drop_prob``. The mask is reshaped to match the
        rank of ``hidden`` so it broadcasts per sample for any rank.
        """
        keep = torch.rand(hidden.shape[0], device=hidden.device) > drop_prob
        keep = keep.view(-1, *([1] * (hidden.ndim - 1)))
        return torch.where(keep, hidden, torch.zeros_like(hidden))

    def forward(self, batch):
        """Compute the weighted denoising loss for one training batch.

        Args:
            batch: Mapping that contains ``"graph_seq"`` and either
                ``"latent"`` (pre-encoded latents, indexed as
                ``[B, F, C, H, W]``) or ``"rgb"`` (clips that are encoded
                with ``encode_rgb_to_latents``).

        Returns:
            Tuple ``(loss, zero)``: the scalar loss and a constant ``0.0``
            tensor on the UNet's device/dtype (presumably a placeholder for
            an auxiliary loss term -- confirm against the training loop).

        Raises:
            KeyError: If ``batch`` has neither ``"latent"`` nor ``"rgb"``.
        """
        if "latent" in batch:
            latents = batch["latent"]
        elif "rgb" in batch:
            latents = self.encode_rgb_to_latents(batch["rgb"])
        else:
            raise KeyError("Batch must contain either 'latent' or 'rgb'.")

        device = self.unet.device
        dtype = self.unet.dtype
        # Mean / std of the normal distribution that log(sigma) is drawn from.
        P_mean = 0.7
        P_std = 1.6
        noise_aug_strength = 0.0
        num_history = self.args.num_history

        latents = latents.to(device)
        bsz, num_frames = latents.shape[:2]

        # --- Conditioning image: the frame at index ``num_history``, lightly
        # noised (sigma uniform in [0, 0.2)) and rescaled, then repeated over
        # every frame.
        current_img = latents[:, num_history:(num_history + 1)][:, 0]
        sigma = torch.rand([bsz, 1, 1, 1], device=device) * 0.2
        c_in = 1 / (sigma**2 + 1) ** 0.5
        current_img = c_in * (current_img + torch.randn_like(current_img) * sigma)
        condition_latent = einops.repeat(current_img, "b c h w -> b f c h w", f=num_frames)
        if self.args.his_cond_zero:
            condition_latent[:, :num_history] = 0.0

        # --- Graph conditioning with per-sample dropout (5% of samples get an
        # all-zero condition).
        graph_hidden = self.encode_graph_condition(batch).to(device=device, dtype=dtype)
        graph_hidden = self._apply_condition_dropout(graph_hidden, drop_prob=0.05)

        # --- Noise level per sample (log-normal) and the EDM-style
        # preconditioning coefficients derived from it.
        rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
        sigma = (rnd_normal * P_std + P_mean).exp()
        c_skip = 1 / (sigma**2 + 1)
        c_out = -sigma / (sigma**2 + 1) ** 0.5
        c_in = 1 / (sigma**2 + 1) ** 0.5
        c_noise = (sigma.log() / 4).reshape([bsz])
        loss_weight = (sigma**2 + 1) / sigma**2
        noisy_latents = latents + torch.randn_like(latents) * sigma

        # --- History frames get their own small, independently drawn noise
        # level instead of the main sigma.
        sigma_h = torch.randn([bsz, num_history, 1, 1, 1], device=device) * 0.3
        history = latents[:, :num_history]
        noisy_history = 1 / (sigma_h**2 + 1) ** 0.5 * (history + sigma_h * torch.randn_like(history))

        # Frames: [noised history | scaled noisy future]; then the conditioning
        # latent (with the VAE scaling factor divided back out) is appended on
        # the channel axis.
        input_latents = torch.cat([noisy_history, c_in * noisy_latents[:, num_history:]], dim=1)
        input_latents = torch.cat(
            [input_latents, condition_latent / self.vae.config.scaling_factor], dim=2
        )

        added_time_ids = self.pipeline._get_add_time_ids(
            self.args.fps,
            self.args.motion_bucket_id,
            noise_aug_strength,
            graph_hidden.dtype,
            bsz,
            1,
            False,
        ).to(device)

        model_pred = self.unet(
            input_latents,
            c_noise,
            encoder_hidden_states=graph_hidden,
            added_time_ids=added_time_ids,
            frame_level_cond=self.args.frame_level_cond,
        ).sample

        # Denoised estimate; the loss only covers frames from ``num_history`` on.
        predict_x0 = c_out * model_pred + c_skip * noisy_latents
        loss = ((predict_x0[:, num_history:] - latents[:, num_history:]) ** 2 * loss_weight).mean()
        return loss, torch.tensor(0.0, device=device, dtype=dtype)