| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """WorldModel transformer for frame generation. |
| |
| Single-file model containing all building blocks: nn primitives, attention, |
| RoPE, quantization, inference caching, and the top-level WorldModel. |
| """ |
|
|
| import warnings |
|
|
| import einops as eo |
| import torch |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
| from tensordict import TensorDict |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| try: |
| from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling |
| import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter |
| HAS_FBGEMM = True |
| except ImportError: |
| HAS_FBGEMM = False |
|
|
|
|
| |
| |
| |
|
|
class NoCastModule(torch.nn.Module):
    """Module whose parameters and buffers keep their dtype across ``.to()`` calls.

    Device moves behave as usual. Any requested dtype change is dropped and a
    warning is emitted instead.
    """

    def _apply(self, fn):
        """Apply ``fn`` to every tensor, then restore each tensor's original dtype."""

        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype)

    def to(self, *args, **kwargs):
        """Move the module like ``nn.Module.to`` but ignore every dtype request.

        Accepts the same call forms as ``nn.Module.to``. A reference tensor as
        first argument contributes only its device; a dtype passed positionally
        or by keyword is stripped. A warning is issued whenever a dtype change
        was requested.

        Returns:
            Whatever ``nn.Module.to`` returns for the remaining arguments.
        """
        warn_cast = False

        if args and isinstance(args[0], torch.Tensor):
            ref, *rest = args
            args = (ref.device, *rest)
            # Compare against None explicitly: taking the truth value of a
            # tensor raises for multi-element tensors and is False for a
            # single zero-valued element.
            base = next(self.parameters(), None)
            if base is None:
                base = next(self.buffers(), None)
            if base is not None and ref.dtype is not base.dtype:
                warn_cast = True

        if kwargs.pop("dtype", None) is not None:
            warn_cast = True

        # A dtype passed positionally is ignored too, and reported the same
        # way as the keyword form.
        kept = tuple(a for a in args if not isinstance(a, torch.dtype))
        if len(kept) != len(args):
            warn_cast = True
        args = kept

        if warn_cast:
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )

        return super().to(*args, **kwargs)
|
|
|
|
| def rms_norm(x: torch.Tensor) -> torch.Tensor: |
| """Root mean square layer normalization.""" |
| return F.rms_norm(x, (x.size(-1),)) |
|
|
|
|
class MLP(nn.Module):
    """Two-layer, bias-free feed-forward network with a SiLU in between."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as-is.
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        """Return ``fc2(silu(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = F.silu(hidden)
        return self.fc2(hidden)
|
|
|
|
class AdaLN(nn.Module):
    """Adaptive normalization: RMS-norm followed by a conditioned scale and shift."""

    def __init__(self, dim):
        super().__init__()
        # One projection yields both the scale and the shift (2 * dim outputs).
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        """Modulate ``x`` with scale/shift derived from ``cond``.

        ``cond`` is unpacked as (b, n, d) and ``x`` as (b, n*m, d); each of the
        ``n`` conditioning rows is repeated across its ``m`` consecutive
        positions in ``x``.
        """
        batch, n_groups, dim = cond.shape
        _, n_tokens, _ = x.shape
        per_group = n_tokens // n_groups

        modulation = self.fc(F.silu(cond)).view(batch, n_groups, 1, 2 * dim)
        # Broadcast each group's modulation over its tokens, then flatten back.
        modulation = modulation.expand(-1, -1, per_group, -1)
        modulation = modulation.reshape(batch, n_tokens, 2 * dim)

        scale, shift = modulation.chunk(2, dim=-1)
        return rms_norm(x) * (1 + scale) + shift
|
|
|
|
def ada_rmsnorm(x, scale, bias):
    """RMS-normalize ``x`` and apply a per-group ``(1 + scale)`` and ``bias``.

    The token axis of ``x`` is split into ``scale.size(1)`` groups; each
    group's scale and bias are broadcast over the tokens of that group.
    """
    n_groups = scale.size(1)
    grouped = eo.rearrange(x, "b (n m) d -> b n m d", n=n_groups)
    modulated = rms_norm(grouped) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
    return eo.rearrange(modulated, "b n m d -> b (n m) d")
|
|
|
|
def ada_gate(x, gate):
    """Multiply ``x`` by a per-group gate.

    The token axis of ``x`` is split into ``gate.size(1)`` groups and each
    group's gate is broadcast over the tokens of that group.
    """
    n_groups = gate.size(1)
    grouped = eo.rearrange(x, "b (n m) d -> b n m d", n=n_groups)
    gated = grouped * gate.unsqueeze(2)
    return eo.rearrange(gated, "b n m d -> b (n m) d")
|
|
|
|
class NoiseConditioner(NoCastModule):
    """Embed a noise level: sigma * 1000 -> Fourier features -> MLP.

    The frequency table is a non-persistent float32 buffer of
    ``fourier_dim // 2`` log-spaced values running from ``base**0`` down to
    ``base**-1``.
    """

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        n_freqs = fourier_dim // 2
        freqs = torch.logspace(0, -1, steps=n_freqs, base=base, dtype=torch.float32)
        self.freq = nn.Buffer(freqs, persistent=False)
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        """Return an embedding shaped ``(*s.shape, dim)`` in ``s``'s dtype.

        ``eps`` is accepted but not used by the computation.
        """
        assert self.freq.dtype == torch.float32
        in_dtype = s.dtype
        in_shape = s.shape

        # The features are computed in float32 with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            scaled = s.reshape(-1).float() * 1000
            phase = scaled.unsqueeze(1) * self.freq.unsqueeze(0)
            features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
            features = features * 2**0.5
            emb = self.mlp(features)

        return emb.to(in_dtype).view(*in_shape, -1)
|
|
|
|
| |
| |
| |
|
|
class OrthoRoPEAngles(NoCastModule):
    """Builds the cos/sin rotation tables for RoPE on every forward pass.

    Of the ``d_head // 2`` rotation angles per token, ``d_head // 8`` come
    from the x position, ``d_head // 8`` from the y position and
    ``d_head // 4`` from the time position.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        head_dim = config.d_model // config.n_heads
        torch._assert(head_dim % 8 == 0, "d_head must be divisible by 8")
        spatial_dim = head_dim // 8
        temporal_dim = head_dim // 4

        # Spatial frequencies: linearly spaced up to a fraction of the grid's
        # smaller side, each repeated twice and truncated to spatial_dim.
        nyquist_frac = float(getattr(config, "rope_nyquist_frac", 0.8))
        max_freq = min(self.config.height, self.config.width) * nyquist_frac
        n_unique = (spatial_dim + 1) // 2
        spatial = torch.linspace(1.0, max_freq / 2, n_unique, dtype=torch.float32)
        spatial = (spatial * torch.pi).repeat_interleave(2)[:spatial_dim]

        # Temporal frequencies: inverse powers of theta, each repeated twice.
        theta = float(getattr(config, "rope_theta", 10000.0))
        exponents = torch.arange(0, temporal_dim, 2, dtype=torch.float32) / temporal_dim
        inv_t = (1.0 / (theta ** exponents)).repeat_interleave(2)

        self.register_buffer("xy", spatial, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        """Return ``(cos, sin)`` of the angles, each with a singleton axis at dim 1.

        ``pos_ids`` must provide ``"x_pos"``, ``"y_pos"`` and ``"t_pos"``.
        """
        height, width = self.config.height, self.config.width

        # The bounds check is skipped while torch.compile is tracing.
        if not torch.compiler.is_compiling():
            in_bounds = (pos_ids["y_pos"].max() < height) & (pos_ids["x_pos"].max() < width)
            torch._assert(in_bounds, f"pos_ids out of bounds, {height}, {width}")

        # Map integer grid coordinates to cell centres in (-1, 1).
        x_norm = (2.0 * pos_ids["x_pos"].float() + 1.0) / width - 1.0
        y_norm = (2.0 * pos_ids["y_pos"].float() + 1.0) / height - 1.0
        t = pos_ids["t_pos"].float()

        angles = torch.cat(
            (
                x_norm.unsqueeze(-1) * self.xy,
                y_norm.unsqueeze(-1) * self.xy,
                t.unsqueeze(-1) * self.inv_t,
            ),
            dim=-1,
        )
        return angles.cos()[:, None], angles.sin()[:, None]
|
|
|
|
class OrthoRoPE(NoCastModule):
    """Rotates adjacent feature pairs of a tensor by precomputed RoPE angles."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        """Rotate ``x`` in float32 and return the result in ``x``'s dtype.

        ``rope_angles`` is a ``(cos, sin)`` pair. The output holds the rotated
        even-indexed features in its first half and the rotated odd-indexed
        features in its second half.
        """
        cos, sin = rope_angles
        pairs = x.float().unfold(-1, 2, 2)
        even, odd = pairs.unbind(-1)
        rotated_even = even * cos - odd * sin
        rotated_odd = odd * cos + even * sin
        return torch.cat((rotated_even, rotated_odd), dim=-1).type_as(x)
|
|
|
|
class Attn(nn.Module):
    """Self-attention through a KV cache, with QK RMS-norm and RoPE.

    Optional features, all read from ``config``: grouped-query attention
    (``n_kv_heads``), value-residual mixing (``value_residual``) and a
    per-head sigmoid output gate (``gated_attn``).
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            # Learned interpolation weight between this layer's values and v1.
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0

        self.enable_gqa = self.n_heads != self.n_kv_heads

        d_model = config.d_model
        q_width = self.n_heads * self.d_head
        kv_width = self.n_kv_heads * self.d_head
        self.q_proj = nn.Linear(d_model, q_width, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.rope = OrthoRoPE(config)

        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            # Zero-initialised, so every gate starts at sigmoid(0) = 0.5.
            self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False)
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Attend over the cache and return ``(output, v1)``.

        ``v1`` is the value tensor carried between layers for the value
        residual. When it is ``None`` and the value residual is enabled, this
        layer's values become ``v1``.
        """
        from torch.nn.attention.flex_attention import flex_attention

        def split_heads(projected, n_heads):
            return eo.rearrange(
                projected, "b t (h d) -> b h t d", h=n_heads, d=self.d_head
            )

        q = split_heads(self.q_proj(x), self.n_heads)
        k = split_heads(self.k_proj(x), self.n_kv_heads)
        v = split_heads(self.v_proj(x), self.n_kv_heads)

        if self.value_residual:
            if v1 is None:
                v1 = v
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q = self.rope(rms_norm(q), rope_angles)
        k = self.rope(rms_norm(k), rope_angles)

        # The cache hands back the keys/values to attend over plus the mask.
        k, v, block_mask = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            # The gate reads only the first n_heads input features.
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = eo.rearrange(y, "b h t d -> b t (h d)")
        return self.out_proj(y), v1
|
|
|
|
class MergedQKVAttn(Attn):
    """``Attn`` variant whose q/k/v projections are fused into one linear layer.

    Built from an existing ``Attn``: its state is copied in, the three
    projection weights are concatenated row-wise into ``qkv_proj``, and the
    separate projections are then removed.
    """

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        src_weight = src.q_proj.weight
        self.to(device=src_weight.device, dtype=src_weight.dtype)
        self.load_state_dict(src.state_dict(), strict=False)
        self.train(src.training)

        # Output widths of the query block and of each key/value block.
        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head

        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            fused = torch.cat(
                [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
            )
            self.qkv_proj.weight.copy_(fused)

        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Same contract as ``Attn.forward``, using the fused projection."""
        from torch.nn.attention.flex_attention import flex_attention

        batch, seq = x.shape[:2]
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)

        def split_heads(t, n_heads):
            return t.reshape(batch, seq, n_heads, self.d_head).transpose(1, 2)

        q = split_heads(q, self.n_heads)
        k = split_heads(k, self.n_kv_heads)
        v = split_heads(v, self.n_kv_heads)

        if self.value_residual:
            if v1 is None:
                v1 = v
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q = self.rope(rms_norm(q), rope_angles)
        k = self.rope(rms_norm(k), rope_angles)

        k, v, block_mask = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(batch, seq, -1)
        return self.out_proj(y), v1
|
|
|
|
class CrossAttention(nn.Module):
    """Cross-attention from tokens to a context sequence (prompt conditioning).

    The head size matches the self-attention head size. The number of heads
    follows from ``context_dim`` (or ``d_model`` when it is not given). The
    output projection starts at zero, so the layer initially outputs zeros.
    """

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head

        ctx_width = context_dim or config.d_model
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(ctx_width, self.inner_dim, bias=False)
        self.v_proj = nn.Linear(ctx_width, self.inner_dim, bias=False)

        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        self.out_proj.weight.detach().zero_()

    def forward(self, x, context, context_pad_mask=None):
        """Attend from ``x`` to ``context`` and project back to ``d_model``.

        NOTE(review): ``context_pad_mask`` is accepted but never used, so every
        context position is attended to. Confirm whether masking is intended.
        """
        from torch.nn.attention.flex_attention import flex_attention

        def split_heads(t):
            return eo.rearrange(t, "b t (h d) -> b h t d", h=self.n_heads)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(context))
        v = split_heads(self.v_proj(context))

        attended = flex_attention(rms_norm(q), rms_norm(k), v)
        merged = attended.transpose(1, 2).contiguous()
        merged = merged.reshape(x.size(0), x.size(1), -1)
        return self.out_proj(merged)
|
|
|
|
| |
| |
| |
|
|
def _bf16_u16(x: Tensor) -> Tensor:
    """Return the raw 16-bit pattern of each element as an int32 in [0, 65535]."""
    signed_bits = x.contiguous().view(torch.int16)
    # Widen first, then mask, so negative int16 values map to 32768..65535.
    widened = signed_bits.to(torch.int32)
    return widened & 0xFFFF
|
|
|
|
class CachedDenoiseStepEmb(nn.Module):
    """Lookup-table replacement for a noise-embedding module.

    The wrapped module is evaluated once for every scheduler sigma. At run
    time a bf16 sigma is mapped through a 65536-entry table, keyed by its bit
    pattern, to the row holding its precomputed bf16 embedding.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        n_levels = bits.numel()
        if torch.unique(bits).numel() != n_levels:
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        with torch.no_grad():
            table = base(levels[:, None]).squeeze(1)
            table = table.to(torch.bfloat16).contiguous()

        # lut[bit pattern] -> row in `table`; -1 marks an unknown sigma.
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(n_levels, device=device, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # One past the last valid row, so unknown sigmas index out of range.
        self.register_buffer(
            "oob",
            torch.tensor(n_levels, device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        """Return the cached embedding for each bf16 ``sigma`` element.

        Raises:
            RuntimeError: if ``sigma`` is not bfloat16.
        """
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        rows = self.lut[_bf16_u16(sigma)]
        rows = torch.where(rows >= 0, rows, self.oob)
        return self.table[rows.to(torch.int64)]
|
|
|
|
class CachedCondHead(nn.Module):
    """Lookup-table replacement for a conditioning head.

    The wrapped head is evaluated once on every cached noise embedding. At
    run time a single feature of the bf16 conditioning vector serves as the
    lookup key; an unrecognised key maps to an out-of-range row index.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        n_levels, width = table.shape

        with torch.no_grad():
            outputs = base(table[:, None, :])
            cache = torch.stack([out.squeeze(1) for out in outputs], 0)
            cache = cache.to(torch.bfloat16).contiguous()

        # Pick the first feature whose bf16 bits differ across all embeddings.
        for d in range(min(width, max_key_dims)):
            candidate = _bf16_u16(table[:, d])
            if torch.unique(candidate).numel() == n_levels:
                key_dim, key_bits = d, candidate
                break
        else:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )

        # lut[bit pattern] -> row in `cache`; -1 marks an unknown key.
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(n_levels, device=table.device, dtype=torch.int32)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(n_levels, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        """Return the cached head outputs for ``cond`` as a tuple of tensors.

        Raises:
            RuntimeError: if ``cond`` is not bfloat16.
        """
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        key = _bf16_u16(cond[..., self.key_dim])
        rows = self.lut[key]
        rows = torch.where(rows >= 0, rows, self.oob)
        gathered = self.cache[:, rows.to(torch.int64)]
        return tuple(gathered.unbind(0))
|
|
|
|
| |
| |
| |
|
|
# Quantization modes advertised by this module; None means "no quantization".
QUANTS = [None]

# "nvfp4" is offered only when the optional flashinfer package is importable.
try:
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
except ImportError:
    pass
else:
    QUANTS.append("nvfp4")
|
|
|
|
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """Quantize the activations to NVFP4 and multiply by a pre-quantized weight.

    Registered as the custom op ``world_engine::fp4_linear``. The output is
    bfloat16.
    """
    act_fp4, act_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(
        act_fp4,
        b_fp4_T,
        act_sf,
        b_sf_T,
        alpha,
        out_dtype=torch.bfloat16,
        backend="cutlass",
    )
|
|
|
|
@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor, b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor, b_sf_T: torch.Tensor, alpha: torch.Tensor,
) -> torch.Tensor:
    """Fake implementation of ``fp4_linear``: an uninitialised bf16 output.

    The output has ``a_bf16.shape[0]`` rows and ``b_fp4_T.shape[1]`` columns.
    """
    n_rows = a_bf16.shape[0]
    n_cols = b_fp4_T.shape[1]
    return torch.empty((n_rows, n_cols), device=a_bf16.device, dtype=torch.bfloat16)
|
|
|
|
class FP4Linear(nn.Module):
    """Bias-free linear layer that runs its matmul through ``fp4_linear``.

    The weight of the source ``nn.Linear`` is quantized once at construction.
    A copy of the original weight is kept as ``self.weight``. The source
    layer's bias, if any, is not used.
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0

        self.weight = nn.Parameter(lin.weight.detach().clone())
        device = self.weight.device

        with torch.no_grad():
            # Activation global scale is fixed at 1.
            self._dummy_scale = torch.full((1,), 1.0, device=device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).to(device).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            self._weight_global_sf = (1.0) / weight_amax
            # alpha undoes the global scales on the matmul output.
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16,
                self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4,
                do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        assert self.weight.is_cuda
        # Run the op once at construction with a dummy single-row input.
        warmup = torch.zeros((1, lin.in_features), device=device, dtype=torch.bfloat16)
        fp4_linear(
            warmup,
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear map over the last dimension of ``x``."""
        rows = x.reshape(-1, x.shape[-1]).to(torch.bfloat16).contiguous()
        y = fp4_linear(
            rows,
            self._weight_fp4_T,
            self._dummy_scale,
            self._weight_scales_T,
            self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))
|
|
|
|
| class FP8W8A8Linear(nn.Module): |
| __constants__ = ("in_features", "out_features") |
|
|
| def __init__(self, lin: nn.Linear): |
| super().__init__() |
| self.in_features, self.out_features = lin.in_features, lin.out_features |
| f8 = torch.float8_e4m3fn |
| inv = 1.0 / float(torch.finfo(f8).max) |
| self._inv = inv |
| w = lin.weight.detach() |
| ws = (w.abs().amax() * inv).clamp_min(1e-8).float() |
| wf8 = (w / ws.to(w.dtype)).to(f8).contiguous() |
| self.register_buffer("wT", wf8.t()) |
| self.register_buffer("ws", ws) |
| if lin.bias is None: |
| self.bias = None |
| else: |
| self.register_buffer("bias", lin.bias.detach().to(torch.float16)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| s = x.shape |
| x2 = x.reshape(-1, s[-1]) |
| xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() |
| xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous() |
| y = torch._scaled_mm( |
| xf8, self.wT, xs, self.ws, |
| bias=self.bias, out_dtype=torch.float16, use_fast_accum=True, |
| ) |
| return y.reshape(*s[:-1], self.out_features).to(x.dtype) |
|
|
|
|
| class FP8Linear(nn.Module): |
| def __init__(self, lin: nn.Linear): |
| super().__init__() |
| self.in_features, self.out_features = lin.in_features, lin.out_features |
| self.bias = ( |
| nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn)) |
| if lin.bias is not None else None |
| ) |
| w_amax = lin.weight.data.abs().amax() |
| w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn) |
| self.register_buffer("w_amax", w_amax) |
| self.register_buffer("weightT", w.t()) |
| self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous() |
| result = torch._scaled_mm( |
| x_fp8, self.weightT, |
| bias=self.bias, scale_a=self.dummy_scale, scale_b=self.w_amax, |
| out_dtype=torch.bfloat16, use_fast_accum=True, |
| ) |
| return result.reshape(x.shape[:-1] + (-1,)) |
|
|
|
|
def quantize_model(model: nn.Module, quant: str):
    """Recursively replace eligible ``nn.Linear`` children with quantized layers.

    A child is eligible when it is an ``nn.Linear`` with a bfloat16 weight
    whose two dimensions are both multiples of 32. For ``"nvfp4"`` a layer
    with a bias is left unquantized, because ``FP4Linear`` does not apply a
    bias. Children that are not replaced are recursed into.

    Args:
        model: Module to modify in place.
        quant: ``"w8a8"``, ``"nvfp4"`` or ``"fp8"``; ``None`` is a no-op.

    Returns:
        ``model`` itself.

    Raises:
        KeyError: if ``quant`` is not one of the supported names.
    """
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        w = getattr(m, "weight", None)
        if not isinstance(m, nn.Linear):
            return False
        if getattr(w, "dtype", None) != torch.bfloat16:
            return False
        if quant == "nvfp4" and m.bias is not None:
            # Replacing this layer would silently drop its bias.
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]

    # Snapshot the children so replacing them does not disturb the iteration.
    for name, child in list(model.named_children()):
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model
|
|
|
|
| |
| |
| |
|
|
def patch_cached_noise_conditioning(model) -> None:
    """Swap the noise embedding and every block's conditioning heads for caches.

    The model is modified in place. The lookup tables are built from
    ``model.config.scheduler_sigmas``.
    """
    sigmas = model.config.scheduler_sigmas
    step_emb = CachedDenoiseStepEmb(model.denoise_step_emb, sigmas)
    model.denoise_step_emb = step_emb

    for block in model.transformer.blocks:
        for head_name in ("attn_cond_head", "mlp_cond_head"):
            original_head = getattr(block, head_name)
            setattr(block, head_name, CachedCondHead(original_head, step_emb))
|
|
|
|
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain ``Attn`` submodule of ``model`` with ``MergedQKVAttn``.

    Modules that are already ``MergedQKVAttn`` are left alone.
    """
    # Collect the targets first so the module tree is not edited mid-walk.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, Attn) and not isinstance(module, MergedQKVAttn)
    ]
    for name, attn in targets:
        model.set_submodule(name, MergedQKVAttn(attn, model.config))
|
|
|
|
def _apply_inference_patches(model) -> None:
    """Apply the inference patches in place: cached conditioning, then merged QKV."""
    for patch in (patch_cached_noise_conditioning, patch_Attn_merge_qkv):
        patch(model)
|
|
|
|
| |
| |
| |
|
|
| class CFG(nn.Module): |
| def __init__(self, d_model: int, dropout: float): |
| super().__init__() |
| self.dropout = dropout |
| self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model)) |
|
|
| def forward( |
| self, x: torch.Tensor, is_conditioned: bool | None = None |
| ) -> torch.Tensor: |
| B, L, _ = x.shape |
| null = self.null_emb.expand(B, L, -1) |
|
|
| if self.training or is_conditioned is None: |
| if self.dropout == 0.0: |
| return x |
| drop = torch.rand(B, 1, 1, device=x.device) < self.dropout |
| return torch.where(drop, null, x) |
|
|
| return x if is_conditioned else null |
|
|
|
|
class ControllerInputEmbedding(nn.Module):
    """Maps controller state (mouse, buttons, scroll) to a ``d_model`` vector.

    The three inputs are concatenated on the last axis; their combined width
    must equal ``n_buttons + 3``.
    """

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        input_width = n_buttons + 3
        self.mlp = MLP(input_width, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        """Embed the concatenated inputs; ``mouse`` must be 3-dimensional."""
        assert len(mouse.shape) == 3
        features = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(features)
|
|
|
|
class MLPFusion(nn.Module):
    """Mixes a per-group conditioning vector into that group's tokens.

    Tokens and conditioning are projected by separate linear layers, summed,
    passed through SiLU and projected once more.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` into ``x``; the token axis of ``x`` is split into ``cond.shape[1]`` groups."""
        batch, _, dim = x.shape
        n_groups = cond.shape[1]
        grouped = x.reshape(batch, n_groups, -1, dim)
        # One conditioning row per group, broadcast over the group's tokens.
        cond_term = self.fc1_c(cond).unsqueeze(2)
        hidden = F.silu(self.fc1_x(grouped) + cond_term)
        return self.fc2(hidden).flatten(1, 2)
|
|
|
|
| class MoEWithoutFBGEMM(nn.Module): |
| """MoE implementation using torch grouped_mm (no fbgemm dependency).""" |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.top_k = config.moe_top_k |
| moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k |
| d_intermediate = int(config.d_model * moe_mlp_ratio) |
| self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) |
| self.expert_in_proj = nn.Parameter( |
| torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model) |
| ) |
| self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate)) |
|
|
| def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: |
| if self.training or torch.is_grad_enabled(): |
| raise NotImplementedError("inference only") |
|
|
| orig_shape = x.shape |
| x = x.reshape(-1, orig_shape[-1]) |
| logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) |
|
|
| logits_fp32 = logits.float() |
| scores, expert = logits.topk(self.top_k, dim=-1, sorted=False) |
| weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype) |
|
|
| expert = expert.flatten() |
| expert_sorted, sort_idx = expert.sort() |
| expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype) |
| offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32) |
|
|
| src = sort_idx // self.top_k |
| x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0)) |
| h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets) |
| h[-1].zero_() |
|
|
| if self.config.gated_linear: |
| gate_act, up = h.chunk(2, dim=-1) |
| h = F.silu(gate_act) * up |
| else: |
| h = F.silu(h) |
|
|
| y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1] |
| y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1) |
| return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape) |
|
|
|
|
| class MoE(nn.Module): |
| """MoE implementation using fbgemm optimized kernels.""" |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| self.top_k = config.moe_top_k |
| moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k) |
| d_int = int(config.d_model * moe_mlp_ratio) |
|
|
| self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) |
| self.expert_in_proj = nn.Parameter( |
| torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model) |
| ) |
| self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int)) |
|
|
| def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: |
| if self.training or torch.is_grad_enabled(): |
| raise NotImplementedError("inference only") |
|
|
| orig = x.shape |
| x = x.reshape(-1, orig[-1]) |
| logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) |
|
|
| logits32 = logits.float() |
| token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k) |
|
|
| E = self.expert_in_proj.size(0) |
| offs = token_counts[:E].cumsum(0).to(torch.int32) |
|
|
| src = src.to(torch.long) |
| expert_sorted = expert_sorted.to(torch.long) |
| logZ = logits32.logsumexp(-1) |
| w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype) |
|
|
| xg = x.index_select(0, torch.cat((src, src[:1]), 0)) |
| h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs) |
| if self.config.gated_linear: |
| ga, up = h.chunk(2, -1) |
| h = F.silu(ga) * up |
| else: |
| h = F.silu(h) |
|
|
| yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1] |
| out = torch.zeros_like(x) |
| torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src) |
| return out.reshape(orig) |
|
|
|
|
class CondHead(nn.Module):
    """Per-layer conditioning head: optional bias -> SiLU -> ``n_cond`` linears.

    Each of the ``n_cond`` outputs comes from its own bias-free linear layer.
    The learned input bias exists only when ``noise_conditioning == "wan"``.
    """

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        projections = [nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)]
        self.cond_proj = nn.ModuleList(projections)

    def forward(self, cond):
        """Return a tuple with one projected tensor per conditioning output."""
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple(proj(activated) for proj in self.cond_proj)
|
|
|
|
| |
| |
| |
|
|
class WorldDiTBlock(nn.Module):
    """One transformer block: self-attention, optional conditioning, then MLP.

    Prompt cross-attention and controller fusion are created only on layers
    whose index is a multiple of their respective period.
    """

    def __init__(
        self, d_model, n_heads, mlp_ratio, layer_idx,
        prompt_conditioning, prompt_conditioning_period, prompt_embedding_dim,
        ctrl_conditioning_period, noise_conditioning, config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)

        if not getattr(config, "moe", False):
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        elif HAS_FBGEMM:
            self.dit_mlp = MoE(config)
        else:
            self.dit_mlp = MoEWithoutFBGEMM(config)

        # Each head yields a (scale, bias, gate) triple for its sub-layer.
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)

        use_prompt = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        if use_prompt:
            self.prompt_cross_attn = CrossAttention(config, prompt_embedding_dim)
        else:
            self.prompt_cross_attn = None

        use_ctrl = (
            ctrl_conditioning_period is not None
            and layer_idx % ctrl_conditioning_period == 0
        )
        self.ctrl_mlpfusion = MLPFusion(d_model) if use_ctrl else None

    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        """Run the block and return ``(x, v)``.

        ``v`` is the value tensor threaded through the attention layers;
        ``ctx`` supplies ``"prompt_emb"``, ``"prompt_pad_mask"`` and
        ``"ctrl_emb"`` for the optional conditioning paths.
        """
        attn_scale, attn_bias, attn_gate = self.attn_cond_head(cond)
        mlp_scale, mlp_bias, mlp_gate = self.mlp_cond_head(cond)

        # Self-attention sub-layer with a gated residual.
        attn_in = ada_rmsnorm(x, attn_scale, attn_bias)
        attn_out, v = self.attn(attn_in, pos_ids, rope_angles, v, kv_cache=kv_cache)
        x = ada_gate(attn_out, attn_gate) + x

        if self.prompt_cross_attn is not None:
            prompt_out = self.prompt_cross_attn(
                rms_norm(x),
                context=rms_norm(ctx["prompt_emb"]),
                context_pad_mask=ctx["prompt_pad_mask"],
            )
            x = prompt_out + x

        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP sub-layer with a gated residual.
        mlp_out = self.dit_mlp(ada_rmsnorm(x, mlp_scale, mlp_bias))
        x = ada_gate(mlp_out, mlp_gate) + x

        return x, v
|
|
|
|
class WorldDiT(nn.Module):
    """Stack of ``config.n_layers`` WorldDiTBlocks.

    The RoPE angles are computed once per forward pass and passed to every
    block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Arguments common to every block; only layer_idx differs.
        shared_kwargs = dict(
            d_model=config.d_model,
            n_heads=config.n_heads,
            mlp_ratio=config.mlp_ratio,
            prompt_conditioning=config.prompt_conditioning,
            prompt_conditioning_period=config.prompt_conditioning_period,
            prompt_embedding_dim=config.prompt_embedding_dim,
            ctrl_conditioning_period=config.ctrl_conditioning_period,
            noise_conditioning=config.noise_conditioning,
            config=config,
        )
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(layer_idx=layer_idx, **shared_kwargs)
                for layer_idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run ``x`` through every block in order and return the result."""
        rope_angles = self.rope_angles(pos_ids)
        # The value tensor starts as None and is threaded from block to block.
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x
|
|
|
|
| |
| |
| |
|
|
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises one frame conditioned on:
      - all previous frames (through the KV cache)
      - the prompt embedding
      - the controller input embedding
      - the current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        channels: int = 32,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: list[float] | None = [
            1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        # Conditioning embedders.
        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        cfg = self.config
        if cfg.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(cfg.d_model, cfg.ctrl_cond_dropout)
        if cfg.prompt_conditioning is not None:
            self.prompt_cfg = CFG(cfg.prompt_embedding_dim, cfg.prompt_cond_dropout)

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        # Non-overlapping patch embedding and its transposed inverse.
        self.patchify = nn.Conv2d(
            channels, d_model, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            d_model, channels, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)

        # Position ids for the tokens of one frame. The time ids are
        # overwritten on every forward call; x/y are fixed row-major grid
        # coordinates for a grid `width` tokens wide.
        token_idx = torch.arange(tokens_per_frame, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f",
            torch.empty(tokens_per_frame, dtype=torch.long),
            persistent=False,
        )
        self.register_buffer(
            "_y_pos_1f",
            token_idx.div(width, rounding_mode="floor"),
            persistent=False,
        )
        self.register_buffer(
            "_x_pos_1f", token_idx.remainder(width), persistent=False
        )

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ):
        """Run one denoising pass over a single frame.

        ``x`` is unpacked as (B, N, C, H, W) and only B == 1, N == 1 is
        supported. The result has the same shape as ``x``. ``button`` is
        required despite its ``None`` default.
        """
        B, N, C, H, W = x.shape
        patch_h, patch_w = self.patch
        assert (H % patch_h == 0) and (W % patch_w == 0), "H, W must be divisible by patch"
        grid_h, grid_w = H // patch_h, W // patch_w
        torch._assert(
            grid_h * grid_w == self.config.tokens_per_frame,
            f"{grid_h} * {grid_w} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )

        # Every token of the frame shares the frame's timestamp.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        frame_pos = frame_timestamp if frame_idx is None else frame_idx
        pos_ids = TensorDict(
            {
                "f_pos": frame_pos[0, 0].expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        d_model = self.config.d_model
        # Patchify each frame, then flatten frames and patches into one axis.
        tokens = self.patchify(x.reshape(B * N, C, H, W))
        tokens = eo.rearrange(
            tokens.view(B, N, d_model, grid_h, grid_w),
            "b n d hp wp -> b (n hp wp) d",
        )
        tokens = self.transformer(tokens, pos_ids, cond, ctx, kv_cache)
        tokens = F.silu(self.out_norm(tokens, cond))
        tokens = eo.rearrange(
            tokens, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=grid_h, wp=grid_w
        )
        out = self.unpatchify(tokens)
        return out.view(B, N, C, H, W)

    def get_active_parameters(self) -> int:
        """Return the parameter count, excluding experts a token is not routed to.

        Without MoE this is simply the total parameter count.
        """
        cfg = self.config
        n_params = sum(p.numel() for p in self.parameters())
        if not getattr(cfg, "moe", False):
            return n_params

        ratio = getattr(cfg, "moe_mlp_ratio", None) or cfg.mlp_ratio / cfg.moe_top_k
        hidden = int(cfg.d_model * ratio)
        top_k = min(cfg.moe_top_k, cfg.moe_n_experts)
        # Weight matrices per expert: 3 when gated (gate, up, down), else 2.
        mats_per_expert = 3 if cfg.gated_linear else 2
        inactive_experts = cfg.moe_n_experts - top_k
        n_params -= inactive_experts * cfg.n_layers * cfg.d_model * hidden * mats_per_expert
        return n_params

    def quantize(self, quant_type: str):
        """Quantize eligible linear layers in place (see ``quantize_model``)."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Apply the inference-time patches to this model in place."""
        _apply_inference_patches(self)
|
|