from __future__ import annotations import torch import torch.nn as nn import einops from graphwm.config_graph import GraphWMArgs from graphwm.models.graph_encoder_pyg import GraphSpatialEncoder from graphwm.models.graph_resampler import GraphResampler from graphwm.models.temporal_graph_conditioner import TemporalGraphConditioner from graphwm.original_ctrl_world import import_original_modules class GraphConditioner(nn.Module): """Per-frame PyG encoder -> fixed-K graph tokens -> temporal transformer.""" def __init__(self, args: GraphWMArgs): super().__init__() self.spatial = GraphSpatialEncoder( node_in_dim=args.graph_in_dim, edge_in_dim=args.edge_in_dim, hidden_dim=args.graph_hidden_dim, num_layers=args.graph_num_layers, dropout=args.graph_dropout, backbone=args.graph_backbone, num_heads=args.graph_num_heads, ) self.resampler = GraphResampler( hidden_dim=args.graph_hidden_dim, num_tokens=args.graph_num_tokens, num_heads=args.graph_num_heads, dropout=args.graph_dropout, ) self.temporal = TemporalGraphConditioner( hidden_dim=args.graph_hidden_dim, cond_dim=args.graph_cond_dim, num_layers=args.graph_temporal_layers, num_heads=args.graph_temporal_heads, dropout=args.graph_dropout, ) def forward(self, graph_seq): per_frame_tokens = [] for graph_batch in graph_seq: node_tokens = self.spatial(graph_batch) frame_tokens = self.resampler(node_tokens, graph_batch.batch) per_frame_tokens.append(frame_tokens) frame_tokens = torch.stack(per_frame_tokens, dim=1) return self.temporal(frame_tokens) class CtrlWorldGraph(nn.Module): """Graph-conditioned wrapper around the original Ctrl-World backbone.""" def __init__(self, args: GraphWMArgs): super().__init__() self.args = args original = import_original_modules(args.ctrl_world_root) StableVideoDiffusionPipeline = original["StableVideoDiffusionPipeline"] UNetSpatioTemporalConditionModel = original["UNetSpatioTemporalConditionModel"] self.pipeline = StableVideoDiffusionPipeline.from_pretrained(args.svd_model_path) unet = UNetSpatioTemporalConditionModel() unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False) self.pipeline.unet = unet self.unet = self.pipeline.unet self.vae = self.pipeline.vae self.image_encoder = self.pipeline.image_encoder self.scheduler = self.pipeline.scheduler self.vae.requires_grad_(False) self.image_encoder.requires_grad_(False) self.unet.requires_grad_(True) self.unet.enable_gradient_checkpointing() self.graph_conditioner = GraphConditioner(args) def encode_graph_condition(self, batch) -> torch.Tensor: return self.graph_conditioner(batch["graph_seq"]) @torch.no_grad() def encode_rgb_to_latents(self, rgb: torch.Tensor) -> torch.Tensor: """Encode RGB clips [B, T, 3, H, W] in [0,1] into VAE latents.""" device = self.unet.device rgb = rgb.to(device) bsz, num_frames, channels, height, width = rgb.shape flat_rgb = rgb.flatten(0, 1) flat_rgb = flat_rgb * 2.0 - 1.0 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.vae.to(dtype=torch.float32) flat_rgb = flat_rgb.to(torch.float32) else: flat_rgb = flat_rgb.to(self.vae.dtype) posterior = self.vae.encode(flat_rgb).latent_dist flat_latents = posterior.sample() * self.vae.config.scaling_factor if needs_upcasting: self.vae.to(dtype=self.unet.dtype) latents = flat_latents.reshape(bsz, num_frames, *flat_latents.shape[1:]) return latents.to(self.unet.dtype) def forward(self, batch): if "latent" in batch: latents = batch["latent"] elif "rgb" in batch: latents = self.encode_rgb_to_latents(batch["rgb"]) else: raise KeyError("Batch must contain either 'latent' or 'rgb'.") device = self.unet.device dtype = self.unet.dtype P_mean = 0.7 P_std = 1.6 noise_aug_strength = 0.0 num_history = self.args.num_history latents = latents.to(device) current_img = latents[:, num_history:(num_history + 1)] bsz, num_frames = latents.shape[:2] current_img = current_img[:, 0] sigma = torch.rand([bsz, 1, 1, 1], device=device) * 0.2 c_in = 1 / (sigma**2 + 1) ** 0.5 current_img = c_in * (current_img + torch.randn_like(current_img) * sigma) condition_latent = einops.repeat(current_img, "b c h w -> b f c h w", f=num_frames) if self.args.his_cond_zero: condition_latent[:, :num_history] = 0.0 graph_hidden = self.encode_graph_condition(batch).to(device=device, dtype=dtype) uncond_hidden_states = torch.zeros_like(graph_hidden) cond_mask = (torch.rand(graph_hidden.shape[0], device=device) > 0.05).view(-1, 1, 1, 1) graph_hidden = graph_hidden * cond_mask + uncond_hidden_states * (~cond_mask) rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device) sigma = (rnd_normal * P_std + P_mean).exp() c_skip = 1 / (sigma**2 + 1) c_out = -sigma / (sigma**2 + 1) ** 0.5 c_in = 1 / (sigma**2 + 1) ** 0.5 c_noise = (sigma.log() / 4).reshape([bsz]) loss_weight = (sigma**2 + 1) / sigma**2 noisy_latents = latents + torch.randn_like(latents) * sigma sigma_h = torch.randn([bsz, num_history, 1, 1, 1], device=device) * 0.3 history = latents[:, :num_history] noisy_history = 1 / (sigma_h**2 + 1) ** 0.5 * (history + sigma_h * torch.randn_like(history)) input_latents = torch.cat([noisy_history, c_in * noisy_latents[:, num_history:]], dim=1) input_latents = torch.cat([input_latents, condition_latent / self.vae.config.scaling_factor], dim=2) added_time_ids = self.pipeline._get_add_time_ids( self.args.fps, self.args.motion_bucket_id, noise_aug_strength, graph_hidden.dtype, bsz, 1, False, ).to(device) model_pred = self.unet( input_latents, c_noise, encoder_hidden_states=graph_hidden, added_time_ids=added_time_ids, frame_level_cond=self.args.frame_level_cond, ).sample predict_x0 = c_out * model_pred + c_skip * noisy_latents loss = ((predict_x0[:, num_history:] - latents[:, num_history:]) ** 2 * loss_weight).mean() return loss, torch.tensor(0.0, device=device, dtype=dtype)