# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""WorldModel transformer for frame generation.

Single-file model containing all building blocks: nn primitives, attention,
RoPE, quantization, inference caching, and the top-level WorldModel.
"""

import warnings

import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

try:
    from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling
    import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter  # noqa

    HAS_FBGEMM = True
except ImportError:
    HAS_FBGEMM = False


# ---------------------------------------------------------------------------
# NN primitives
# ---------------------------------------------------------------------------


class NoCastModule(torch.nn.Module):
    """Module that prevents dtype casting during .to() calls.

    Device moves are honoured, but every parameter/buffer keeps the dtype it
    was created with. A requested dtype change is dropped with a warning.
    """

    def _apply(self, fn, recurse=True):
        # `recurse` mirrors nn.Module._apply so that callers passing it
        # (e.g. nn.Module.to_empty(..., recurse=...)) keep working.
        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype, recurse=recurse)

    def to(self, *args, **kwargs):
        """Move to a device while ignoring (and warning about) any dtype request."""
        warn_cast = False
        if args and isinstance(args[0], torch.Tensor):
            # `.to(other_tensor)`: keep only the reference tensor's device.
            ref, *rest = args
            args = (ref.device, *rest)
            # Look the two iterators up separately: `a or b` would evaluate a
            # Tensor's truth value, which raises for multi-element tensors.
            base = next(self.parameters(), None)
            if base is None:
                base = next(self.buffers(), None)
            if base is not None and ref.dtype is not base.dtype:
                warn_cast = True
        if kwargs.pop("dtype", None) is not None:
            warn_cast = True
        if any(isinstance(a, torch.dtype) for a in args):
            # A positional dtype (e.g. `.to(torch.bfloat16)`) is also ignored.
            warn_cast = True
            args = tuple(a for a in args if not isinstance(a, torch.dtype))
        if warn_cast:
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )
        return super().to(*args, **kwargs)


def rms_norm(x: torch.Tensor) -> torch.Tensor:
    """Root mean square normalization over the last dim (no learnable weight)."""
    return F.rms_norm(x, (x.size(-1),))


class MLP(nn.Module):
    """Simple bias-free two-layer MLP with SiLU activation."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        return self.fc2(F.silu(self.fc1(x)))


class AdaLN(nn.Module):
    """Adaptive Layer Normalization.

    Each of the ``n`` conditioning vectors in ``cond`` ([b, n, d]) produces a
    scale/shift that is applied to ``m = nm // n`` consecutive tokens of ``x``
    ([b, nm, d]).
    """

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        b, n, d = cond.shape
        _, nm, _ = x.shape
        m = nm // n

        y = F.silu(cond)
        ab = self.fc(y)  # [b, n, 2d]
        ab = ab.view(b, n, 1, 2 * d)  # [b, n, 1, 2d]
        ab = ab.expand(-1, -1, m, -1)  # [b, n, m, 2d]
        ab = ab.reshape(b, nm, 2 * d)  # [b, nm, 2d]

        a, b_ = ab.chunk(2, dim=-1)  # [b, nm, d] each
        x = rms_norm(x) * (1 + a) + b_
        return x


def ada_rmsnorm(x, scale, bias):
    """Adaptive RMS normalization with scale and bias.

    ``x`` is ``[b, n*m, d]``; ``scale``/``bias`` carry one vector per group of
    ``m`` tokens (``n = scale.size(1)``) and are broadcast over the group.
    """
    x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=scale.size(1))
    y4 = rms_norm(x4) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
    return eo.rearrange(y4, "b n m d -> b (n m) d")


def ada_gate(x, gate):
    """Apply gating to x with per-frame gates (one gate vector per token group)."""
    x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=gate.size(1))
    return eo.rearrange(x4 * gate.unsqueeze(2), "b n m d -> b (n m) d")


class NoiseConditioner(NoCastModule):
    """Noise level sigma -> Fourier features -> dense embedding.

    Sigma is scaled by 1000 before the Fourier projection. The computation
    runs in float32 with CUDA autocast disabled and the result is returned in
    the input's dtype with shape ``(*sigma.shape, dim)``.
    """

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        half = fourier_dim // 2
        # Frequencies from base**0 down to base**-1, log-spaced.
        self.freq = nn.Buffer(
            torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
            persistent=False,
        )
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        # `eps` is unused; kept for interface compatibility.
        assert self.freq.dtype == torch.float32
        orig_dtype, shape = s.dtype, s.shape

        with torch.autocast("cuda", enabled=False):
            s = s.reshape(-1).float()
            s = s * 1000

            phase = s[:, None] * self.freq[None, :]
            emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
            # sin^2 + cos^2 = 1 per pair, so scaling by sqrt(2) gives unit RMS.
            emb = emb * 2**0.5
            emb = self.mlp(emb)

        return emb.to(orig_dtype).view(*shape, -1)


# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------


class OrthoRoPEAngles(NoCastModule):
    """Computes RoPE angles on the fly each forward pass.

    Each token gets ``d_head // 2`` angles: ``d_head // 8`` for x,
    ``d_head // 8`` for y (both linearly spaced up to a Nyquist fraction of
    the grid size) and ``d_head // 4`` for time (geometric, ``rope_theta``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_head = config.d_model // config.n_heads
        torch._assert(d_head % 8 == 0, "d_head must be divisible by 8")
        d_xy, d_t = d_head // 8, d_head // 4
        nyq = float(getattr(config, "rope_nyquist_frac", 0.8))
        max_freq = min(self.config.height, self.config.width) * nyq
        n = (d_xy + 1) // 2
        xy = (
            torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi
        ).repeat_interleave(2)[:d_xy]
        theta = float(getattr(config, "rope_theta", 10000.0))
        inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t))
        inv_t = inv_t.repeat_interleave(2)
        self.register_buffer("xy", xy, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        """Return ``(cos, sin)`` with a singleton head axis inserted at dim 1.

        ``pos_ids`` must provide ``x_pos``, ``y_pos`` and ``t_pos``; presumably
        each is ``[B, T]`` (as built by WorldModel.forward) — confirm for other
        callers.
        """
        if not torch.compiler.is_compiling():
            torch._assert(
                (pos_ids["y_pos"].max() < self.config.height)
                & (pos_ids["x_pos"].max() < self.config.width),
                f"pos_ids out of bounds, {self.config.height}, {self.config.width}",
            )
        # Map integer grid positions to cell centres in (-1, 1).
        x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0
        y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0
        t = pos_ids["t_pos"].float()
        freqs = torch.cat(
            (
                x.unsqueeze(-1) * self.xy,
                y.unsqueeze(-1) * self.xy,
                t.unsqueeze(-1) * self.inv_t,
            ),
            dim=-1,
        )
        return freqs.cos()[:, None], freqs.sin()[:, None]


class OrthoRoPE(NoCastModule):
    """Applies precomputed RoPE angles to input tensors.

    Adjacent channel pairs ``(x[2i], x[2i+1])`` are rotated; the result is
    written as ``[all first components, all second components]``. The channel
    order therefore differs from the input, but q and k receive the same
    permutation, so q·k dot products are unaffected.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        cos, sin = rope_angles
        x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
        y0 = x0 * cos - x1 * sin
        y1 = x1 * cos + x0 * sin
        return torch.cat((y0, y1), dim=-1).type_as(x)


class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention.

    * value residual: every layer's V is lerped toward the V of the first
      layer (``v1``) with a learned weight ``v_lamb``.
    * gated attention: per-head sigmoid gates computed from the first
      ``n_heads`` channels of the input.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0

        self.enable_gqa = self.n_heads != self.n_kv_heads

        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.rope = OrthoRoPE(config)

        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False)
            # Zero init => gates start at sigmoid(0) = 0.5 for every head.
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Attend over this step's tokens plus whatever ``kv_cache`` holds.

        ``kv_cache.upsert`` presumably stores this step's K/V for
        ``layer_idx`` and returns the K/V to attend over plus a flex-attention
        block mask — confirm against the cache implementation. Returns
        ``(output, v1)`` so the first layer's V can be threaded to later layers.
        """
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1


class MergedQKVAttn(Attn):
    """Attn with q/k/v projections fused into a single Linear.

    Built from an existing ``Attn`` (weights, device, dtype and train mode
    are copied); the separate q/k/v projections are deleted afterwards.
    """

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        self.load_state_dict(src.state_dict(), strict=False)
        self.train(src.training)

        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head

        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            # Rows are laid out [q | k | v], matching the split in forward.
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )

        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        from torch.nn.attention.flex_attention import flex_attention

        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)

        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1


class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning.

    The output projection is zero-initialised, so the block contributes
    nothing until trained.
    """

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head

        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.v_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )

        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        self.out_proj.weight.detach().zero_()

    def forward(self, x, context, context_pad_mask=None):
        # NOTE(review): `context_pad_mask` is accepted but not applied, so padded
        # context tokens are attended to — confirm this is intended.
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        q, k = rms_norm(q), rms_norm(k)
        out = flex_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)


# ---------------------------------------------------------------------------
# Inference caching
# ---------------------------------------------------------------------------


def _bf16_u16(x: Tensor) -> Tensor:
    """Reinterpret the raw 16-bit pattern of ``x`` (bf16) as 0..65535 in int32."""
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT.

    Precomputes ``base`` on every scheduler sigma and maps each sigma's bf16
    bit pattern to its table row. A sigma that is not one of the cached levels
    maps to row ``len(sigmas)``, which is out of range for ``table`` and
    therefore makes the lookup fail instead of silently returning a wrong row.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        with torch.no_grad():
            table = base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()

        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)
        return self.table[idx.to(torch.int64)]


class CachedCondHead(nn.Module):
    """bf16 cond -> cached conditioning; invalid cond => OOB index error.

    Precomputes ``base`` (a CondHead) for every row of the cached denoise-step
    table. At lookup time, a single embedding channel whose bf16 values are
    unique across rows (searched among the first ``max_key_dims`` channels)
    identifies the row.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        S, D = table.shape
        with torch.no_grad():
            emb = table[:, None, :]
            # [n_cond, S, D]: one slab per output of the conditioning head.
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )

        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )

        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)
        g = self.cache[:, idx.to(torch.int64)]
        return tuple(g.unbind(0))


# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------

QUANTS = [None]
try:
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout

    QUANTS.append("nvfp4")
except ImportError:
    pass


@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Quantize ``a_bf16`` to NVFP4 and multiply by a pre-quantized weight.

    Requires flashinfer (``nvfp4_quantize``, ``mm_fp4``).
    """
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
    )


@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Shape/dtype-only implementation of ``fp4_linear`` for tracing/compile."""
    return torch.empty(
        (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
    )


class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization.

    The bias of the source layer is not carried over. The original bf16
    weight is kept as ``self.weight``; the quantized copies are stored as
    plain attributes (not buffers).
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0

        self.weight = nn.Parameter(lin.weight.detach().clone())

        self._weight_fp4_T = None
        self._weight_scales_T = None
        self._alpha = None
        self._dummy_scale = None
        self._weight_global_sf = None

        with torch.no_grad():
            self._dummy_scale = torch.full(
                (1,), 1.0, device=self.weight.device, dtype=torch.float32
            )
            weight_bf16 = self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            self._weight_global_sf = (1.0) / weight_amax
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)

            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16,
                self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4,
                do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        assert self.weight.is_cuda
        # Run the op once on a dummy input — presumably to trigger any lazy
        # kernel initialisation up front; the result is discarded.
        lazy_x = torch.zeros(
            (1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16
        )
        fp4_linear(
            lazy_x, self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))


class FP8W8A8Linear(nn.Module):
    """FP8 (e4m3) weight + activation linear with per-tensor scales.

    The weight scale is fixed at construction; the activation scale is
    recomputed from each input's absolute max. The matmul outputs fp16, which
    is then cast back to the input dtype.
    """

    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features

        f8 = torch.float8_e4m3fn
        inv = 1.0 / float(torch.finfo(f8).max)
        self._inv = inv

        w = lin.weight.detach()
        ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
        wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()

        self.register_buffer("wT", wf8.t())
        self.register_buffer("ws", ws)

        if lin.bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", lin.bias.detach().to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.shape
        x2 = x.reshape(-1, s[-1])

        xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
        xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()

        y = torch._scaled_mm(
            xf8,
            self.wT,
            xs,
            self.ws,
            bias=self.bias,
            out_dtype=torch.float16,
            use_fast_accum=True,
        )
        return y.reshape(*s[:-1], self.out_features).to(x.dtype)


class FP8Linear(nn.Module):
    """FP8 (e4m3) linear with a per-tensor weight scale and unit activation scale.

    Weights are divided by their absolute max before the float8 cast; the
    product is rescaled by ``w_amax`` inside ``torch._scaled_mm`` and returned
    in bf16.
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        # _scaled_mm only accepts a bf16/fp16 bias (bf16 for a bf16 output),
        # so the bias is kept in bf16 rather than float8.
        self.bias = (
            nn.Parameter(lin.bias.data.clone().to(torch.bfloat16))
            if lin.bias is not None
            else None
        )
        w_amax = lin.weight.data.abs().amax()
        w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
        # _scaled_mm requires float32 scale tensors.
        self.register_buffer("w_amax", w_amax.float())
        self.register_buffer("weightT", w.t())
        # Registered as a (non-persistent) buffer so .to()/.cuda() move it
        # together with the weights.
        self.register_buffer(
            "dummy_scale",
            torch.ones((), device=lin.weight.device, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): activations are cast to float8 without scaling or
        # clamping — confirm they stay inside the e4m3 range.
        x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
        result = torch._scaled_mm(
            x_fp8,
            self.weightT,
            bias=self.bias,
            scale_a=self.dummy_scale,
            scale_b=self.w_amax,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )
        return result.reshape(x.shape[:-1] + (-1,))


def quantize_model(model: nn.Module, quant: str):
    """Recursively replace eligible ``nn.Linear`` layers with quantized ones.

    A layer is eligible when it is an ``nn.Linear`` with a bf16 weight whose
    both dimensions are multiples of 32.

    Args:
        model: module to modify in place.
        quant: ``None`` (no-op), ``"w8a8"``, ``"nvfp4"`` or ``"fp8"``.

    Returns:
        ``model`` itself.

    Raises:
        ImportError: ``"nvfp4"`` requested but flashinfer is not importable.
        KeyError: unknown ``quant`` value.
    """
    if quant is None:
        return model
    if quant == "nvfp4" and "nvfp4" not in QUANTS:
        raise ImportError("quant='nvfp4' requires flashinfer (nvfp4_quantize, mm_fp4)")

    def eligible(m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        w = m.weight
        if getattr(w, "dtype", None) != torch.bfloat16:
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]

    for name, child in model.named_children():
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model


# ---------------------------------------------------------------------------
# Inference patches
# ---------------------------------------------------------------------------


def patch_cached_noise_conditioning(model) -> None:
    """Swap the noise embedding and every block's cond heads for LUT-cached versions."""
    cached_denoise_step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached_denoise_step_emb
    for blk in model.transformer.blocks:
        blk.attn_cond_head = CachedCondHead(blk.attn_cond_head, cached_denoise_step_emb)
        blk.mlp_cond_head = CachedCondHead(blk.mlp_cond_head, cached_denoise_step_emb)


def patch_Attn_merge_qkv(model) -> None:
    """Replace every (non-merged) ``Attn`` submodule with a ``MergedQKVAttn``."""
    for name, mod in list(model.named_modules()):
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
            model.set_submodule(name, MergedQKVAttn(mod, model.config))


def _apply_inference_patches(model) -> None:
    patch_cached_noise_conditioning(model)
    patch_Attn_merge_qkv(model)


# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------


class CFG(nn.Module):
    """Classifier-free-guidance switch between an input and a learned null embedding.

    In training mode, or when ``is_conditioned`` is None, each batch element
    is replaced by the null embedding with probability ``dropout``. Otherwise
    the input (``is_conditioned`` truthy) or the null embedding is returned.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: bool | None = None
    ) -> torch.Tensor:
        B, L, _ = x.shape
        null = self.null_emb.expand(B, L, -1)

        if self.training or is_conditioned is None:
            if self.dropout == 0.0:
                return x
            drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
            return torch.where(drop, null, x)

        return x if is_conditioned else null


class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into model dimension.

    Input width is ``n_buttons + 3``; presumably 2 mouse channels and 1
    scroll channel — confirm against the data pipeline.
    """

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        assert len(mouse.shape) == 3
        x = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(x)


class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens via split linear projections.

    ``x`` is ``[B, L*m, D]`` and ``cond`` is ``[B, L, D]``; each conditioning
    vector is added (after projection) to its ``m`` tokens.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        x = x.reshape(B, L, -1, D)
        return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(
            1, 2
        )


class MoEWithoutFBGEMM(nn.Module):
    """MoE implementation using torch grouped_mm (no fbgemm dependency).

    Inference only. Each token is routed to its ``top_k`` experts; expert
    outputs are weighted by the softmax probability of that expert over all
    experts (not renormalised over the top-k).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k
        d_intermediate = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        self.expert_in_proj = nn.Parameter(
            torch.empty(
                config.moe_n_experts,
                d_intermediate * (2 if config.gated_linear else 1),
                config.d_model,
            )
        )
        self.expert_out_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, config.d_model, d_intermediate)
        )

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig_shape = x.shape
        x = x.reshape(-1, orig_shape[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
        logits_fp32 = logits.float()
        scores, expert = logits.topk(self.top_k, dim=-1, sorted=False)
        weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype)

        # Sort the (token, slot) assignments by expert so each expert's rows
        # are contiguous; `offsets` holds the cumulative end row per expert.
        expert = expert.flatten()
        expert_sorted, sort_idx = expert.sort()
        expert_ids = torch.arange(
            self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype
        )
        offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32)
        src = sort_idx // self.top_k

        # One extra row is appended past the last offset, so it belongs to no
        # expert group; it is zeroed here and dropped after the second matmul.
        x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0))
        h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets)
        h[-1].zero_()
        if self.config.gated_linear:
            gate_act, up = h.chunk(2, dim=-1)
            h = F.silu(gate_act) * up
        else:
            h = F.silu(h)
        y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1]

        # Undo the sort, then combine each token's top_k expert outputs.
        y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(
            x.size(0), self.top_k, -1
        )
        return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape)


class MoE(nn.Module):
    """MoE implementation using fbgemm optimized kernels.

    Inference only; same routing/weighting as ``MoEWithoutFBGEMM`` but uses
    fbgemm's ``index_shuffling`` and ``scatter_add_dense_tokens``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (
            config.mlp_ratio / config.moe_top_k
        )
        d_int = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        self.expert_in_proj = nn.Parameter(
            torch.empty(
                config.moe_n_experts,
                d_int * (2 if config.gated_linear else 1),
                config.d_model,
            )
        )
        self.expert_out_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, config.d_model, d_int)
        )

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig = x.shape
        x = x.reshape(-1, orig[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
        logits32 = logits.float()

        token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k)
        E = self.expert_in_proj.size(0)
        offs = token_counts[:E].cumsum(0).to(torch.int32)

        src = src.to(torch.long)
        expert_sorted = expert_sorted.to(torch.long)

        # Softmax probability of the selected expert for each routed row.
        logZ = logits32.logsumexp(-1)
        w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype)

        # Extra trailing row lies outside every expert group; dropped below.
        xg = x.index_select(0, torch.cat((src, src[:1]), 0))
        h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs)
        if self.config.gated_linear:
            ga, up = h.chunk(2, -1)
            h = F.silu(ga) * up
        else:
            h = F.silu(h)
        yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1]

        out = torch.zeros_like(x)
        torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src)
        return out.reshape(orig)


class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond).

    Returns a tuple of ``n_cond`` projections of the (optionally biased)
    conditioning vector.
    """

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        self.bias_in = (
            nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
        )
        self.cond_proj = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)]
        )

    def forward(self, cond):
        cond = cond + self.bias_in if self.bias_in is not None else cond
        h = F.silu(cond)
        return tuple(p(h) for p in self.cond_proj)


# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------


class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Self-attention and the MLP are wrapped in adaptive RMS norm + gating
    driven by the noise conditioning. Prompt cross-attention and controller
    MLP fusion are enabled only on layers whose index is a multiple of the
    respective period.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        mlp_ratio,
        layer_idx,
        prompt_conditioning,
        prompt_conditioning_period,
        prompt_embedding_dim,
        ctrl_conditioning_period,
        noise_conditioning,
        config,
    ):
        # `n_heads` is accepted for signature compatibility; Attn reads the
        # head count from `config`.
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        if getattr(config, "moe", False):
            self.dit_mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config)
        else:
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)

        # Each head yields (scale, bias, gate) for its sub-layer.
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)

        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        do_ctrl_cond = (
            ctrl_conditioning_period is not None
            and layer_idx % ctrl_conditioning_period == 0
        )
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        s0, b0, g0 = self.attn_cond_head(cond)
        s1, b1, g1 = self.mlp_cond_head(cond)

        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        x = ada_gate(self.dit_mlp(ada_rmsnorm(x, s1, b1)), g1) + x

        return x, v


class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks.

    Each block has its own parameters; the RoPE angles are computed once per
    forward and shared by all blocks, and the value-residual tensor is
    threaded from block to block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        rope_angles = self.rope_angles(pos_ids)
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x


# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------


class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        channels: int = 32,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        # List default is only read in this file (never mutated).
        scheduler_sigmas: list[float] | None = [
            1.0,
            0.8609585762023926,
            0.729332447052002,
            0.3205108940601349,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            D, C, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)

        # Per-frame token positions: row-major (y, x) over a grid `width`
        # tokens wide; `_t_pos_1f` is filled with the frame timestamp in forward.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ):
        """Denoise one frame.

        Args:
            x: latent frame(s), ``[B, N, C, H, W]``; only ``B == N == 1`` is
                supported.
            sigma: noise level(s); presumably ``[B, N]`` so the conditioning
                has one vector per frame — confirm.
            frame_timestamp: indexed as ``[0, 0]`` for the temporal position.
            frame_idx: optional override for the ``f_pos`` position id.
            prompt_emb / prompt_pad_mask: prompt conditioning inputs.
            mouse / button / scroll: controller inputs (``button`` required).
            kv_cache: cache object passed to every attention layer.

        Returns:
            Tensor of shape ``[B, N, C, H, W]``.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0]
                .expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )

        cond = self.denoise_step_emb(sigma)

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        D = self.config.d_model
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")

        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)

        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(x, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=Hp, wp=Wp)
        x = self.unpatchify(x)
        x = x.view(B, N, C, H, W)
        return x

    def get_active_parameters(self) -> int:
        """Parameter count with inactive MoE expert weights excluded.

        For MoE configs, the weights of ``moe_n_experts - top_k`` experts per
        layer (in and out projections) are subtracted from the total.
        """
        total = sum(p.numel() for p in self.parameters())
        c = self.config
        if getattr(c, "moe", False):
            moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k
            hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts)
            total -= (
                (c.moe_n_experts - top_k)
                * c.n_layers
                * c.d_model
                * hidden
                * (3 if c.gated_linear else 2)
            )
        return total

    def quantize(self, quant_type: str):
        """Quantize eligible Linear layers in place (see ``quantize_model``)."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Install cached noise conditioning and fused-QKV attention in place."""
        _apply_inference_patches(self)