Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import einops | |
| from graphwm.config_graph import GraphWMArgs | |
| from graphwm.models.graph_encoder_pyg import GraphSpatialEncoder | |
| from graphwm.models.graph_resampler import GraphResampler | |
| from graphwm.models.temporal_graph_conditioner import TemporalGraphConditioner | |
| from graphwm.original_ctrl_world import import_original_modules | |
class GraphConditioner(nn.Module):
    """Turn a sequence of per-frame graphs into temporal conditioning tokens.

    Each frame's graph is run through the PyG spatial encoder, the resulting
    node features are resampled into a fixed number of graph tokens per
    sample, and the per-frame token sets are stacked along a new time axis
    (dim 1) and passed through the temporal transformer.
    """

    def __init__(self, args: GraphWMArgs):
        super().__init__()
        hidden = args.graph_hidden_dim
        dropout = args.graph_dropout

        # Keep this construction order (spatial -> resampler -> temporal): it
        # fixes parameter registration order and the random-init sequence.
        self.spatial = GraphSpatialEncoder(
            node_in_dim=args.graph_in_dim,
            edge_in_dim=args.edge_in_dim,
            hidden_dim=hidden,
            num_layers=args.graph_num_layers,
            dropout=dropout,
            backbone=args.graph_backbone,
            num_heads=args.graph_num_heads,
        )
        self.resampler = GraphResampler(
            hidden_dim=hidden,
            num_tokens=args.graph_num_tokens,
            num_heads=args.graph_num_heads,
            dropout=dropout,
        )
        self.temporal = TemporalGraphConditioner(
            hidden_dim=hidden,
            cond_dim=args.graph_cond_dim,
            num_layers=args.graph_temporal_layers,
            num_heads=args.graph_temporal_heads,
            dropout=dropout,
        )

    def _encode_frame(self, graph_batch):
        """Encode one frame's batched graph into its resampled token set.

        ``graph_batch.batch`` is passed to the resampler alongside the node
        features — presumably the PyG node-to-graph assignment vector; confirm
        against GraphResampler.
        """
        nodes = self.spatial(graph_batch)
        return self.resampler(nodes, graph_batch.batch)

    def forward(self, graph_seq):
        """Encode an iterable of per-frame graph batches.

        Args:
            graph_seq: Iterable of graph batches, one per frame, in time order.

        Returns:
            The temporal conditioner's output for the per-frame tokens stacked
            along dim 1.
        """
        stacked = torch.stack([self._encode_frame(frame) for frame in graph_seq], dim=1)
        return self.temporal(stacked)
class CtrlWorldGraph(nn.Module):
    """Graph-conditioned wrapper around the original Ctrl-World backbone.

    Loads a Stable Video Diffusion pipeline, replaces its UNet with the
    Ctrl-World ``UNetSpatioTemporalConditionModel`` initialised from the
    pipeline's pretrained UNet weights, freezes the VAE and image encoder,
    and feeds the UNet graph tokens from :class:`GraphConditioner` as
    ``encoder_hidden_states``.
    """

    # Log-normal training noise schedule: sigma = exp(N(0, 1) * P_STD + P_MEAN).
    P_MEAN = 0.7
    P_STD = 1.6
    # Fraction of samples whose graph conditioning is replaced by zeros.
    COND_DROP_PROB = 0.05

    def __init__(self, args: GraphWMArgs):
        super().__init__()
        self.args = args

        original = import_original_modules(args.ctrl_world_root)
        StableVideoDiffusionPipeline = original["StableVideoDiffusionPipeline"]
        UNetSpatioTemporalConditionModel = original["UNetSpatioTemporalConditionModel"]

        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(args.svd_model_path)
        # Swap in the Ctrl-World UNet and copy the pretrained SVD weights into
        # it; strict=False ignores keys present in only one of the two models.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Only the UNet (and the graph conditioner below) are trained.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        self.graph_conditioner = GraphConditioner(args)

    def encode_graph_condition(self, batch) -> torch.Tensor:
        """Run the graph conditioner on ``batch["graph_seq"]``."""
        return self.graph_conditioner(batch["graph_seq"])

    def encode_rgb_to_latents(self, rgb: torch.Tensor) -> torch.Tensor:
        """Encode RGB clips ``[B, T, 3, H, W]`` in ``[0, 1]`` into VAE latents.

        Frames are mapped to ``[-1, 1]``, encoded with batch and time flattened
        together, sampled from the VAE posterior and multiplied by the VAE
        scaling factor.

        Returns:
            Latents shaped ``[B, T, *latent_shape]`` in the UNet's dtype.

        Raises:
            ValueError: If ``rgb`` is not 5-dimensional.
        """
        if rgb.dim() != 5:
            raise ValueError(f"Expected rgb of shape [B, T, C, H, W], got {tuple(rgb.shape)}.")
        device = self.unet.device
        rgb = rgb.to(device)
        bsz, num_frames = rgb.shape[:2]
        flat_rgb = rgb.flatten(0, 1) * 2.0 - 1.0

        vae_dtype = self.vae.dtype
        # An fp16 VAE whose config sets force_upcast is run in fp32 for encoding.
        needs_upcasting = vae_dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        try:
            flat_rgb = flat_rgb.to(self.vae.dtype)
            posterior = self.vae.encode(flat_rgb).latent_dist
            flat_latents = posterior.sample() * self.vae.config.scaling_factor
        finally:
            if needs_upcasting:
                # Restore the VAE's own dtype (fp16), not the UNet's, and do it
                # even if encoding raised.
                self.vae.to(dtype=vae_dtype)

        latents = flat_latents.reshape(bsz, num_frames, *flat_latents.shape[1:])
        return latents.to(self.unet.dtype)

    @staticmethod
    def _build_condition_latent(current_img, num_frames, num_history, zero_history):
        """Tile ``current_img`` ``[B, C, H, W]`` to ``[B, num_frames, C, H, W]``.

        The result is a freshly allocated tensor rather than a broadcast view,
        so zeroing the history frames cannot write through to the remaining
        frames or to ``current_img`` itself.
        """
        condition_latent = current_img.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        if zero_history:
            condition_latent[:, :num_history] = 0.0
        return condition_latent

    @staticmethod
    def _drop_condition(hidden, keep):
        """Zero ``hidden`` for samples where the boolean vector ``keep`` is False.

        The mask is reshaped to broadcast over every non-batch dimension of
        ``hidden``, whatever its rank.
        """
        keep = keep.view(-1, *([1] * (hidden.dim() - 1)))
        return torch.where(keep, hidden, torch.zeros_like(hidden))

    def forward(self, batch):
        """Compute the denoising training loss for one batch.

        Args:
            batch: Mapping with ``"latent"`` (pre-encoded latents ``[B, T, ...]``)
                or ``"rgb"`` (clips, see :meth:`encode_rgb_to_latents`), plus
                ``"graph_seq"`` for the graph conditioner.

        Returns:
            ``(loss, zero)`` where ``loss`` is the sigma-weighted MSE between the
            predicted and clean latents of the frames after the history, and
            ``zero`` is a constant ``0.0`` tensor (presumably a placeholder for
            an auxiliary loss — confirm with the training loop).

        Raises:
            KeyError: If ``batch`` has neither ``"latent"`` nor ``"rgb"``.
        """
        if "latent" in batch:
            latents = batch["latent"]
        elif "rgb" in batch:
            latents = self.encode_rgb_to_latents(batch["rgb"])
        else:
            raise KeyError("Batch must contain either 'latent' or 'rgb'.")

        device = self.unet.device
        dtype = self.unet.dtype
        noise_aug_strength = 0.0
        num_history = self.args.num_history

        latents = latents.to(device)
        bsz, num_frames = latents.shape[:2]

        # Conditioning image: first frame after the history, lightly noised.
        current_img = latents[:, num_history]
        cond_sigma = torch.rand([bsz, 1, 1, 1], device=device) * 0.2
        cond_c_in = 1 / (cond_sigma**2 + 1) ** 0.5
        current_img = cond_c_in * (current_img + torch.randn_like(current_img) * cond_sigma)
        condition_latent = self._build_condition_latent(
            current_img, num_frames, num_history, self.args.his_cond_zero
        )

        # Randomly drop the graph conditioning for a small fraction of samples
        # (presumably to allow classifier-free guidance at inference).
        graph_hidden = self.encode_graph_condition(batch).to(device=device, dtype=dtype)
        keep = torch.rand(graph_hidden.shape[0], device=device) > self.COND_DROP_PROB
        graph_hidden = self._drop_condition(graph_hidden, keep)

        # Per-sample noise level and the preconditioning coefficients.
        rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
        sigma = (rnd_normal * self.P_STD + self.P_MEAN).exp()
        c_skip = 1 / (sigma**2 + 1)
        c_out = -sigma / (sigma**2 + 1) ** 0.5
        c_in = 1 / (sigma**2 + 1) ** 0.5
        c_noise = (sigma.log() / 4).reshape([bsz])
        loss_weight = (sigma**2 + 1) / sigma**2
        noisy_latents = latents + torch.randn_like(latents) * sigma

        # History frames get their own, separately scaled noise.
        sigma_h = torch.randn([bsz, num_history, 1, 1, 1], device=device) * 0.3
        history = latents[:, :num_history]
        noisy_history = 1 / (sigma_h**2 + 1) ** 0.5 * (history + sigma_h * torch.randn_like(history))

        # UNet input: noised frames along time, conditioning latent along channels.
        input_latents = torch.cat([noisy_history, c_in * noisy_latents[:, num_history:]], dim=1)
        input_latents = torch.cat(
            [input_latents, condition_latent / self.vae.config.scaling_factor], dim=2
        )

        added_time_ids = self.pipeline._get_add_time_ids(
            self.args.fps,
            self.args.motion_bucket_id,
            noise_aug_strength,
            graph_hidden.dtype,
            bsz,
            1,
            False,
        ).to(device)

        model_pred = self.unet(
            input_latents,
            c_noise,
            encoder_hidden_states=graph_hidden,
            added_time_ids=added_time_ids,
            frame_level_cond=self.args.frame_level_cond,
        ).sample

        predict_x0 = c_out * model_pred + c_skip * noisy_latents
        loss = ((predict_x0[:, num_history:] - latents[:, num_history:]) ** 2 * loss_weight).mean()
        return loss, torch.tensor(0.0, device=device, dtype=dtype)