# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""WorldModel transformer for frame generation."""

from typing import Optional, List
import math

import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .attn import Attn, MergedQKVAttn, CrossAttention
from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm
from .quantize import quantize_model
from .cache import CachedDenoiseStepEmb, CachedCondHead


def patch_cached_noise_conditioning(model) -> None:
    # Call AFTER: model.to(device="cuda", dtype=torch.bfloat16).eval()
    cached_denoise_step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached_denoise_step_emb
    for blk in model.transformer.blocks:
        blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)


def patch_Attn_merge_qkv(model) -> None:
    for name, mod in list(model.named_modules()):
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
            model.set_submodule(name, MergedQKVAttn(mod, model.config))


def patch_MLPFusion_split(model) -> None:
    for name, mod in list(model.named_modules()):
        if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion):
            model.set_submodule(name, SplitMLPFusion(mod))


def _apply_inference_patches(model) -> None:
    patch_cached_noise_conditioning(model)
    patch_Attn_merge_qkv(model)
    patch_MLPFusion_split(model)


class CFG(nn.Module):
    """Classifier-free guidance dropout: swaps conditioning for a learned null embedding."""

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: Optional[bool] = None
    ) -> torch.Tensor:
        """
        x: [B, L, D]
        is_conditioned:
          - None: training-style random dropout
          - bool: whole batch conditioned / unconditioned at sampling
        """
        B, L, _ = x.shape
        null = self.null_emb.expand(B, L, -1)

        # training-style dropout OR unspecified
        if self.training or is_conditioned is None:
            if self.dropout == 0.0:
                return x
            drop = torch.rand(B, 1, 1, device=x.device) < self.dropout  # [B, 1, 1]
            return torch.where(drop, null, x)

        # sampling-time switch
        return x if is_conditioned else null


class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into the model dimension."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        # +3 extra inputs: mouse velocity (x, y) + scroll sign
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        assert mouse.ndim == 3
        x = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(x)


class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond])."""

    def __init__(self, d_model: int):
        super().__init__()
        self.mlp = MLP(2 * d_model, d_model, d_model)

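    # fc1 packs the token and conditioning projections into one [D, 2D]
    # weight, so forward() can chunk it and broadcast cond over each group's
    # tokens: numerically equivalent to running the MLP on cat([x, cond])
    # without ever materialising the concatenation.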
    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1)  # each [D, D]
        x = x.view(B, L, -1, D)
        # broadcast cond over each group's tokens: no repeat/cat needed
        h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(2)
        h = F.silu(h)
        y = F.linear(h, self.mlp.fc2.weight)
        return y.flatten(1, 2)


class SplitMLPFusion(nn.Module):
    """Packed MLPFusion -> split linears (no cat, quant-friendly)."""

    def __init__(self, src: MLPFusion):
        super().__init__()
        D = src.mlp.fc2.in_features
        dev, dt = src.mlp.fc2.weight.device, src.mlp.fc2.weight.dtype
        self.fc1_x = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        self.fc1_c = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        self.fc2 = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
        with torch.no_grad():
            Wx, Wc = src.mlp.fc1.weight.chunk(2, dim=1)
            self.fc1_x.weight.copy_(Wx)
            self.fc1_c.weight.copy_(Wc)
            self.fc2.weight.copy_(src.mlp.fc2.weight)
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        x = x.reshape(B, L, -1, D)
        h = F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))
        return self.fc2(h).flatten(1, 2)


class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond)."""

    n_cond = 6

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
        super().__init__()
        self.bias_in = (
            nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
        )
        self.cond_proj = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)]
        )

    def forward(self, cond):
        cond = cond + self.bias_in if self.bias_in is not None else cond
        h = F.silu(cond)
        return tuple(p(h) for p in self.cond_proj)


class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        layer_idx: int,
        prompt_conditioning: Optional[str],
        prompt_conditioning_period: int,
        prompt_embedding_dim: int,
        ctrl_conditioning_period: int,
        noise_conditioning: str,
        config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.cond_head = CondHead(d_model, noise_conditioning)
        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal frame attention
        1) Frame -> CTX cross attention
        2) MLP
        """
        s0, b0, g0, s1, b1, g1 = self.cond_head(cond)

        # Self / causal attention
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        # Cross-attention prompt conditioning
        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        # MLPFusion controller conditioning
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP
        x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x
        return x, v


class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with shared parameters."""

    def __init__(self, config):
        super().__init__()
        self.config = config
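        # Blocks are constructed independently; their cond_head projections
        # (for the "dit_air"/"wan" noise-conditioning schemes) and RoPE
        # buffers are then re-pointed at block 0's below, so those parameters
        # are shared across every layer.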
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )

        if config.noise_conditioning in ("dit_air", "wan"):
            ref_proj = self.blocks[0].cond_head.cond_proj
            for blk in self.blocks[1:]:
                for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
                    blk_mod.weight = ref_mod.weight

        # Shared RoPE buffers
        ref_rope = self.blocks[0].attn.rope
        for blk in self.blocks[1:]:
            blk.attn.rope = ref_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
        return x


class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via the KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    @register_to_config
    def __init__(
        self,
        # Model architecture
        d_model: int = 2560,
        n_heads: int = 40,
        n_kv_heads: Optional[int] = 20,
        n_layers: int = 22,
        mlp_ratio: int = 5,
        channels: int = 16,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 512,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = -1,
        value_residual: bool = False,
        gated_attn: bool = True,
        n_buttons: int = 256,
        ctrl_conditioning: Optional[str] = "mlp_fusion",
        ctrl_conditioning_period: int = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: Optional[str] = "cross_attention",
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: Optional[List[float]] = [
            1.0,
            0.9483006596565247,
            0.8379597067832947,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
    ):
        super().__init__()
        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )
        self.transformer = WorldDiT(self.config)

        self.patch = tuple(patch)
        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(d_model)

        # Cached 1-frame pos_ids (buffers + cached TensorDict view)
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)
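        # _t_pos_1f is overwritten each forward() from frame_timestamp (the
        # temporal position advances per frame); _y_pos_1f / _x_pos_1f hold
        # the fixed raster coordinates of the per-frame token grid.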

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
    ):
        """
        Args:
            x: [B, N, C, H, W] - latent frames
            sigma: [B, N] - noise levels
            frame_timestamp: [B, N] - frame indices
            prompt_emb: [B, P, D] - prompt embeddings
            prompt_pad_mask: [B, P] - padding mask for prompts
            mouse: [B, N, 2] - mouse velocity
            button: [B, N, n_buttons] - button states
            scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
            kv_cache: StaticKVCache instance
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )
        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )

        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )

        cond = self.denoise_step_emb(sigma)  # [B, N, d]

        assert mouse is not None and button is not None and scroll is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        D = self.unpatchify.in_features
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")

        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)

        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(
            self.unpatchify(x),
            "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
            n=N,
            hp=Hp,
            wp=Wp,
            ph=ph,
            pw=pw,
        )
        return x

    def quantize(self, quant_type: str):
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        _apply_inference_patches(self)
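
# Illustrative single-step inference flow (a sketch, not part of the module;
# the checkpoint path is a placeholder):
#
#   model = WorldModel.from_pretrained("<checkpoint>")  # diffusers ModelMixin API
#   model = model.to(device="cuda", dtype=torch.bfloat16).eval()
#   model.apply_inference_patches()  # must follow .to(...).eval(); see patch_cached_noise_conditioning
#   pred = model(
#       x, sigma, frame_timestamp,  # x: [1, 1, C, H, W]; forward asserts B == N == 1
#       prompt_emb=prompt_emb, prompt_pad_mask=prompt_pad_mask,
#       mouse=mouse, button=button, scroll=scroll, kv_cache=kv_cache,
#   )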