lapp0's picture
Add diffusers support (#1)
064b963
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""WorldModel transformer for frame generation.
Single-file model containing all building blocks: nn primitives, attention,
RoPE, quantization, inference caching, and the top-level WorldModel.
"""
import warnings
import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
try:
from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling
import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa
HAS_FBGEMM = True
except ImportError:
HAS_FBGEMM = False
# ---------------------------------------------------------------------------
# NN primitives
# ---------------------------------------------------------------------------
class NoCastModule(torch.nn.Module):
    """Module that prevents dtype casting during .to() calls.

    Device moves are honored. Dtype changes are ignored with a warning, and the
    module's tensors keep the dtypes they were created with. This covers:

    * a ``dtype`` passed to ``.to()``, positionally or by keyword;
    * a reference tensor of a different dtype passed to ``.to()``;
    * ``_apply``-based helpers such as ``.half()`` and ``.float()``.
    """

    def _apply(self, fn, *args, **kwargs):
        # Extra arguments (e.g. ``recurse=`` passed by ``nn.Module.to_empty``)
        # are forwarded untouched to ``nn.Module._apply``.
        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype, *args, **kwargs)

    def to(self, *args, **kwargs):
        """Move to a device while refusing any dtype change (warns instead)."""
        warn_cast = False
        if args and isinstance(args[0], torch.Tensor):
            # ``.to(other)`` means "other's device and dtype"; keep only the device.
            ref, *rest = args
            args = (ref.device, *rest)
            # Explicit None checks: ``a or b`` would call bool() on a tensor,
            # which raises for tensors with more than one element.
            base = next(self.parameters(), None)
            if base is None:
                base = next(self.buffers(), None)
            if base is not None and ref.dtype is not base.dtype:
                warn_cast = True
        if kwargs.pop("dtype", None) is not None:
            warn_cast = True
        # Positional dtypes (``.to(torch.float16)``, ``.to("cuda", torch.half)``)
        # are dropped as well, and deserve the same warning as the keyword form.
        kept = tuple(a for a in args if not isinstance(a, torch.dtype))
        if len(kept) != len(args):
            warn_cast = True
        args = kept
        if warn_cast:
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )
        return super().to(*args, **kwargs)
def rms_norm(x: torch.Tensor) -> torch.Tensor:
"""Root mean square layer normalization."""
return F.rms_norm(x, (x.size(-1),))
class MLP(nn.Module):
    """Bias-free two-layer feed-forward network computing ``fc2(silu(fc1(x)))``."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        # The attribute names define the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        hidden = self.fc1(x)
        activated = F.silu(hidden)
        return self.fc2(activated)
class AdaLN(nn.Module):
    """Adaptive Layer Normalization.

    Projects ``silu(cond)`` (``[b, n, d]``) to a scale and a shift for each
    conditioning row. It then returns ``rms_norm(x) * (1 + scale) + shift``,
    where ``x`` is ``[b, n*m, d]`` and each of the ``n`` rows modulates ``m``
    consecutive tokens.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        batch, n_rows, dim = cond.shape
        _, n_tokens, _ = x.shape
        per_row = n_tokens // n_rows
        modulation = self.fc(F.silu(cond))
        # Broadcast each row's modulation over its tokens: [b, n, 2d] -> [b, n*m, 2d].
        modulation = modulation.view(batch, n_rows, 1, 2 * dim)
        modulation = modulation.expand(batch, n_rows, per_row, 2 * dim)
        modulation = modulation.reshape(batch, n_tokens, 2 * dim)
        scale, shift = modulation.chunk(2, dim=-1)
        return rms_norm(x) * (1 + scale) + shift
def ada_rmsnorm(x, scale, bias):
    """RMS-normalize ``x`` and apply a separate affine modulation to each token group.

    ``x`` is ``[b, n*m, d]``. ``scale`` and ``bias`` carry one row per group
    along dim 1 (``n = scale.size(1)``) and are broadcast over the ``m`` tokens
    of that group. The result is ``rms_norm(x) * (1 + scale) + bias``.
    """
    n_groups = scale.size(1)
    grouped = eo.rearrange(x, "b (n m) d -> b n m d", n=n_groups)
    modulated = rms_norm(grouped) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
    return eo.rearrange(modulated, "b n m d -> b (n m) d")
def ada_gate(x, gate):
    """Multiply each token group of ``x`` (``[b, n*m, d]``) by that group's gate row.

    ``gate`` has one row per group along dim 1 (``n = gate.size(1)``). Each row
    is broadcast over the ``m`` tokens of its group.
    """
    grouped = eo.rearrange(x, "b (n m) d -> b n m d", n=gate.size(1))
    gated = grouped * gate.unsqueeze(2)
    return eo.rearrange(gated, "b n m d -> b (n m) d")
class NoiseConditioner(NoCastModule):
    """Sigma -> scaled sigma -> Fourier features -> dense embedding.

    No logSNR transform is applied. The pipeline is:

    1. Multiply sigma by 1000.
    2. Project it onto ``fourier_dim // 2`` frequencies, log-spaced from 1 down
       to ``1 / base``.
    3. Encode as concatenated ``[sin, cos]`` features scaled by ``sqrt(2)``.
    4. Pass through ``MLP(fourier_dim, 4 * dim, dim)``.

    The computation is done in float32 with CUDA autocast disabled. The
    ``NoCastModule`` base keeps the frequency buffer and MLP weights in the
    dtype they were created with.
    """

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        half = fourier_dim // 2
        # base**0 .. base**-1, i.e. 1 .. 1/base; float32 and not saved in state_dict.
        self.freq = nn.Buffer(
            torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
            persistent=False,
        )
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        """Embed noise levels ``s`` (any shape) into a tensor of shape ``s.shape + (dim,)``.

        ``eps`` is unused; it is kept so existing call signatures keep working.
        The result is cast back to ``s.dtype``.
        """
        assert self.freq.dtype == torch.float32
        orig_dtype, shape = s.dtype, s.shape
        with torch.autocast("cuda", enabled=False):
            s = s.reshape(-1).float()
            # Presumably sigma lies in [0, 1]; scaling by 1000 spreads it across the frequency bank.
            s = s * 1000
            phase = s[:, None] * self.freq[None, :]
            emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
            # sin^2 + cos^2 = 1, so scaling by sqrt(2) gives the features unit mean square.
            emb = emb * 2**0.5
            emb = self.mlp(emb)
        return emb.to(orig_dtype).view(*shape, -1)
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class OrthoRoPEAngles(NoCastModule):
    """Computes RoPE angles on the fly each forward pass.

    Per token, the output holds ``d_head // 2`` angles:

    * ``d_head // 8`` for the x axis and ``d_head // 8`` for the y axis, using
      linear frequencies shared by both axes;
    * ``d_head // 4`` for time, using geometric ``rope_theta`` frequencies.

    This matches the ``d_head // 2`` feature pairs that OrthoRoPE rotates.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        d_head = config.d_model // config.n_heads
        # Required so d_xy and d_t below are whole, even counts.
        torch._assert(d_head % 8 == 0, "d_head must be divisible by 8")
        d_xy, d_t = d_head // 8, d_head // 4
        # Highest spatial frequency, as a fraction of the smaller grid dimension.
        nyq = float(getattr(config, "rope_nyquist_frac", 0.8))
        max_freq = min(self.config.height, self.config.width) * nyq
        n = (d_xy + 1) // 2
        # Linear frequencies pi .. pi*max_freq/2, each repeated twice, truncated to d_xy.
        xy = (torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy]
        theta = float(getattr(config, "rope_theta", 10000.0))
        # Standard RoPE inverse frequencies for time, each repeated twice -> d_t values.
        inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t))
        inv_t = inv_t.repeat_interleave(2)
        self.register_buffer("xy", xy, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)
    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        """Return ``(cos, sin)`` of the angles for ``pos_ids``.

        ``pos_ids`` is indexed with the keys "x_pos", "y_pos" and "t_pos". The
        outputs have a singleton axis inserted at dim 1, presumably the head
        axis for broadcasting against ``[b, h, t, d]`` queries/keys (TODO
        confirm against callers).
        """
        # Bounds check skipped under torch.compile (data-dependent assert).
        if not torch.compiler.is_compiling():
            torch._assert(
                (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width),
                f"pos_ids out of bounds, {self.config.height}, {self.config.width}"
            )
        # Map integer grid indices to cell-centre coordinates in (-1, 1).
        x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0
        y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0
        # Time positions are used unnormalized.
        t = pos_ids["t_pos"].float()
        freqs = torch.cat(
            (x.unsqueeze(-1) * self.xy, y.unsqueeze(-1) * self.xy, t.unsqueeze(-1) * self.inv_t),
            dim=-1,
        )
        return freqs.cos()[:, None], freqs.sin()[:, None]
class OrthoRoPE(NoCastModule):
    """Applies precomputed RoPE angles to input tensors.

    Adjacent feature pairs ``(x[2i], x[2i+1])`` are rotated by the given angles
    in float32. The rotated components are returned in half-split order: all
    first components, then all second components. The output dtype matches
    the input dtype.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Configurations with has_audio set are rejected by this layer.
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        cos, sin = rope_angles
        pairs = x.float().unfold(-1, 2, 2)
        first, second = pairs.unbind(-1)
        rotated_first = first * cos - second * sin
        rotated_second = second * cos + first * sin
        rotated = torch.cat((rotated_first, rotated_second), dim=-1)
        return rotated.type_as(x)
class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention.

    Processing order:

    1. Queries and keys are RMS-normalized and rotated with OrthoRoPE.
    2. Keys and values go through ``kv_cache.upsert``, which returns the
       keys/values to attend over and a flex-attention block mask.
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Value residual: interpolate this layer's values toward the v1 tensor
        # threaded in from the caller.
        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            # Weight for torch.lerp(v, v1, v_lamb); initialized to 0.5.
            self.v_lamb = nn.Parameter(torch.tensor(0.5))
        self.n_heads = config.n_heads
        # Falls back to n_heads (plain multi-head attention) when n_kv_heads is unset/None.
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0
        # Grouped-query attention whenever keys/values use fewer heads than queries.
        self.enable_gqa = self.n_heads != self.n_kv_heads
        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.rope = OrthoRoPE(config)
        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            # Per-head output gate computed from the first n_heads channels of x.
            # Zero init makes every gate sigmoid(0) = 0.5 at the start.
            self.gate_proj = nn.Linear(
                self.n_heads, self.n_heads, bias=False
            )
            nn.init.zeros_(self.gate_proj.weight)
    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Run attention for tokens ``x`` (``[b, t, d_model]``).

        Args:
            x: input tokens, ``[b, t, d_model]``.
            pos_ids: position ids, forwarded to ``kv_cache.upsert``.
            rope_angles: ``(cos, sin)`` pair consumed by OrthoRoPE.
            v1: values from an earlier layer for the value residual, or None.
            kv_cache: object whose ``upsert(k, v, pos_ids, layer_idx)`` returns
                ``(k, v, block_mask)``. NOTE(review): it is called
                unconditionally, so None is not accepted here.

        Returns:
            ``(y, v1)``: output ``[b, t, d_model]`` and the value-residual
            tensor. When value_residual is on and v1 was None, it becomes this
            layer's values.
        """
        from torch.nn.attention.flex_attention import flex_attention
        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        if self.value_residual:
            # First call: v1 is None, so this layer's values become v1 (and lerp(v, v, w) == v).
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)
        # QK-norm, then rotary embedding.
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
        if self.gated_attn:
            # gates: [b, t, h] -> [b, h, t, 1], broadcast over head features.
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1
class MergedQKVAttn(Attn):
    """Attn variant that fuses the q/k/v projections into one ``qkv_proj`` matmul.

    Construction from an existing ``Attn``:

    1. Copy its state (including ``v_lamb`` / ``gate_proj`` when present).
    2. Concatenate the q, k and v weights row-wise into ``qkv_proj``.
    3. Delete the separate projections.

    ``forward`` performs the same computation as ``Attn.forward`` using the
    single fused projection.
    """
    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        # Match the source's device/dtype before copying its tensors in.
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        self.load_state_dict(
            src.state_dict(), strict=False
        )
        self.train(src.training)
        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head
        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            # Output rows are laid out as [q | k | v]; forward splits them in that order.
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )
        del self.q_proj, self.k_proj, self.v_proj
    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Same contract as ``Attn.forward``; returns ``(y, v1)``."""
        from torch.nn.attention.flex_attention import flex_attention
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
        B, T = x.shape[:2]
        # [B, T, heads*d_head] -> [B, heads, T, d_head]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1
class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning.

    Queries come from ``x`` (width ``config.d_model``). Keys and values come
    from ``context`` (width ``context_dim``, defaulting to ``config.d_model``).
    The head size equals that of self-attention (``d_model // n_heads``), so
    the number of heads here is ``inner_dim // d_head``. The output projection
    is zero-initialized, so a freshly built layer outputs zeros.
    """
    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.v_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        # Zero-init the output projection (in place, outside autograd).
        self.out_proj.weight.detach().zero_()
    def forward(self, x, context, context_pad_mask=None):
        """Attend from ``x`` (``[b, t, d_model]``) to ``context`` (``[b, s, context_dim]``).

        NOTE(review): ``context_pad_mask`` is accepted but never used. Padded
        context tokens are attended to like real ones. Confirm whether a mask
        (e.g. a flex_attention ``score_mod``/``block_mask``) should be applied.
        """
        from torch.nn.attention.flex_attention import flex_attention
        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        # QK-norm; no positional encoding is applied to cross-attention.
        q, k = rms_norm(q), rms_norm(k)
        out = flex_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)
# ---------------------------------------------------------------------------
# Inference caching
# ---------------------------------------------------------------------------
def _bf16_u16(x: Tensor) -> Tensor:
return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT.

    ``base`` is evaluated once for every scheduler sigma. At call time, each
    bf16 sigma is looked up by its raw bit pattern in a 65536-entry table.
    A sigma outside the schedule maps to row ``len(sigmas)``, which is past
    the end of ``table``, so indexing fails.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device
        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        n_levels = bits.numel()
        # Two sigmas rounding to the same bf16 value would make the lookup ambiguous.
        if torch.unique(bits).numel() != n_levels:
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )
        with torch.no_grad():
            embedded = base(levels[:, None])
            table = embedded.squeeze(1).to(torch.bfloat16).contiguous()
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(n_levels, device=device, dtype=torch.int32)
        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(n_levels, device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        slot = self.lut[_bf16_u16(sigma)]
        # Unknown sigmas (-1) are redirected to the out-of-range row.
        slot = torch.where(slot < 0, self.oob, slot)
        return self.table[slot.long()]
class CachedCondHead(nn.Module):
    """bf16 cond -> cached conditioning; invalid cond => OOB index error.

    ``base`` is precomputed on every row of a CachedDenoiseStepEmb table. At
    call time, the row is recovered from the raw bf16 bits of a single
    feature, ``key_dim``. That feature is the first among the leading
    ``max_key_dims`` features whose bit patterns are unique across rows.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        n_rows, width = table.shape
        with torch.no_grad():
            head_outputs = base(table[:, None, :])
            cache = torch.stack([out.squeeze(1) for out in head_outputs], 0)
            cache = cache.to(torch.bfloat16).contiguous()
        # Lazily scan candidate feature indices; stop at the first distinguishing one.
        candidates = (
            (d, _bf16_u16(table[:, d])) for d in range(min(width, max_key_dims))
        )
        found = next(
            ((d, bits) for d, bits in candidates if torch.unique(bits).numel() == n_rows),
            None,
        )
        if found is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )
        key_dim, key_bits = found
        device = table.device
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[key_bits] = torch.arange(n_rows, device=device, dtype=torch.int32)
        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(n_rows, device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        row = self.lut[_bf16_u16(cond[..., self.key_dim])]
        row = torch.where(row < 0, self.oob, row)
        picked = self.cache[:, row.long()]
        return tuple(picked.unbind(0))
# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------
# Presumably the list of quantization modes available in this environment;
# None means "no quantization". "nvfp4" is added only when FlashInfer imports.
# NOTE(review): quantize_model also accepts "w8a8" and "fp8", which are not
# listed here. Confirm whether callers expect this list to be exhaustive.
QUANTS = [None]
try:
    # FlashInfer supplies the NVFP4 quantize/matmul kernels used by fp4_linear.
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
    QUANTS.append("nvfp4")
except ImportError:
    pass
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """NVFP4 matmul, registered as the torch custom op ``world_engine::fp4_linear``.

    Steps:

    1. Quantize the bf16 activations ``a_bf16`` to NVFP4 using global scale
       ``a_global_sf``, with the ``SfLayout.layout_128x4`` scale-factor layout
       and ``do_shuffle=False``.
    2. Multiply them with the pre-quantized, transposed FP4 weight ``b_fp4_T``
       and its scale factors ``b_sf_T``, using FlashInfer's ``mm_fp4`` on the
       CUTLASS backend with rescale factor ``alpha``.

    Returns a bf16 tensor. Needs the optional FlashInfer import above; without
    it, calling this raises NameError.
    """
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
    )
@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor, b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor, b_sf_T: torch.Tensor, alpha: torch.Tensor,
) -> torch.Tensor:
    """Fake (shape/dtype-only) implementation of ``fp4_linear`` for tracing.

    Returns an uninitialized bf16 tensor of shape
    ``[a_bf16.shape[0], b_fp4_T.shape[1]]`` on ``a_bf16``'s device.
    """
    return torch.empty(
        (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
    )
class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization.

    The weight is quantized once, at construction, with global scale
    ``1 / amax(|W|)``. On every call, activations are quantized with a fixed
    global scale of 1.0 and multiplied against the cached FP4 weight through
    ``fp4_linear``. The output is always bf16.

    The source layer's bias, if any, is kept in bf16 and added to the output.
    Previously it was silently dropped.

    Requires a CUDA weight and the optional FlashInfer import.
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0
        self.weight = nn.Parameter(lin.weight.detach().clone())
        self.bias = (
            nn.Parameter(lin.bias.detach().clone().to(torch.bfloat16))
            if lin.bias is not None else None
        )
        # Fail before touching the CUDA-only kernels.
        assert self.weight.is_cuda
        device = self.weight.device
        with torch.no_grad():
            # Activation global scale: a constant 1.0.
            self._dummy_scale = torch.full((1,), 1.0, device=device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).contiguous()
            # Clamp so an all-zero weight cannot produce an infinite global scale.
            weight_amax = weight_bf16.float().abs().nan_to_num().max().clamp_min(1e-12)
            self._weight_global_sf = 1.0 / weight_amax
            # Undo both global scales after the FP4 matmul.
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16, self._weight_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()
            # Warm-up call so the kernel is set up before the first real forward.
            lazy_x = torch.zeros((1, self.in_features), device=device, dtype=torch.bfloat16)
            fp4_linear(lazy_x, self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FP4 matmul to the last dim of ``x``; returns bf16 ``[..., out_features]``."""
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha,
        )
        if self.bias is not None:
            y = y + self.bias
        return y.reshape(x.shape[:-1] + (-1,))
class FP8W8A8Linear(nn.Module):
__constants__ = ("in_features", "out_features")
def __init__(self, lin: nn.Linear):
super().__init__()
self.in_features, self.out_features = lin.in_features, lin.out_features
f8 = torch.float8_e4m3fn
inv = 1.0 / float(torch.finfo(f8).max)
self._inv = inv
w = lin.weight.detach()
ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()
self.register_buffer("wT", wf8.t())
self.register_buffer("ws", ws)
if lin.bias is None:
self.bias = None
else:
self.register_buffer("bias", lin.bias.detach().to(torch.float16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
s = x.shape
x2 = x.reshape(-1, s[-1])
xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
y = torch._scaled_mm(
xf8, self.wT, xs, self.ws,
bias=self.bias, out_dtype=torch.float16, use_fast_accum=True,
)
return y.reshape(*s[:-1], self.out_features).to(x.dtype)
class FP8Linear(nn.Module):
def __init__(self, lin: nn.Linear):
super().__init__()
self.in_features, self.out_features = lin.in_features, lin.out_features
self.bias = (
nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
if lin.bias is not None else None
)
w_amax = lin.weight.data.abs().amax()
w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
self.register_buffer("w_amax", w_amax)
self.register_buffer("weightT", w.t())
self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
result = torch._scaled_mm(
x_fp8, self.weightT,
bias=self.bias, scale_a=self.dummy_scale, scale_b=self.w_amax,
out_dtype=torch.bfloat16, use_fast_accum=True,
)
return result.reshape(x.shape[:-1] + (-1,))
def quantize_model(model: nn.Module, quant: str):
if quant is None:
return model
def eligible(m: nn.Module) -> bool:
w = getattr(m, "weight", None)
if not isinstance(m, nn.Linear):
return False
if getattr(w, "dtype", None) != torch.bfloat16:
return False
o, k = w.shape
return (o % 32 == 0) and (k % 32 == 0)
new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]
for name, child in model.named_children():
setattr(model, name, new_linear(child)) if eligible(child) else quantize_model(child, quant)
return model
# ---------------------------------------------------------------------------
# Inference patches
# ---------------------------------------------------------------------------
def patch_cached_noise_conditioning(model) -> None:
    """Replace the noise embedding and every block's cond heads with bf16 lookup tables.

    One CachedDenoiseStepEmb is built over ``model.config.scheduler_sigmas``.
    Its table is then reused to precompute each block's attention and MLP
    cond heads.
    """
    step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = step_emb
    for block in model.transformer.blocks:
        for head_name in ("attn_cond_head", "mlp_cond_head"):
            cached = CachedCondHead(getattr(block, head_name), step_emb)
            setattr(block, head_name, cached)
def patch_Attn_merge_qkv(model) -> None:
    """Swap every plain ``Attn`` submodule for an equivalent fused-QKV ``MergedQKVAttn``."""
    # Collect first: the module tree must not be mutated while named_modules() walks it.
    targets = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, Attn) and not isinstance(module, MergedQKVAttn)
    ]
    for path, module in targets:
        model.set_submodule(path, MergedQKVAttn(module, model.config))
def _apply_inference_patches(model) -> None:
    """Apply all inference-time rewrites in order: cached noise conditioning, then fused QKV."""
    for patch in (patch_cached_noise_conditioning, patch_Attn_merge_qkv):
        patch(model)
# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------
class CFG(nn.Module):
def __init__(self, d_model: int, dropout: float):
super().__init__()
self.dropout = dropout
self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))
def forward(
self, x: torch.Tensor, is_conditioned: bool | None = None
) -> torch.Tensor:
B, L, _ = x.shape
null = self.null_emb.expand(B, L, -1)
if self.training or is_conditioned is None:
if self.dropout == 0.0:
return x
drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
return torch.where(drop, null, x)
return x if is_conditioned else null
class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into model dimension.

    Mouse, button and scroll features are concatenated on the last axis. The
    MLP input width is ``n_buttons + 3``, presumably 2 mouse features plus
    1 scroll feature (TODO confirm). The result goes through an
    ``MLP(n_buttons + 3, d_model * mlp_ratio, d_model)``.
    """

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        in_features = n_buttons + 3
        self.mlp = MLP(in_features, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        assert mouse.dim() == 3
        features = torch.cat([mouse, button, scroll], dim=-1)
        return self.mlp(features)
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens via split linear projections.

    ``x`` (``[B, L*m, D]``) is split into ``L`` groups, one per row of ``cond``
    (``[B, L, D]``). The output is ``fc2(silu(fc1_x(x) + fc1_c(cond)))``, with
    each cond row broadcast over its group, flattened back to ``[B, L*m, D]``.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        batch, _, dim = x.shape
        n_groups = cond.shape[1]
        grouped = x.reshape(batch, n_groups, -1, dim)
        token_term = self.fc1_x(grouped)
        cond_term = self.fc1_c(cond).unsqueeze(2)
        fused = self.fc2(F.silu(token_term + cond_term))
        return fused.flatten(1, 2)
class MoEWithoutFBGEMM(nn.Module):
    """MoE implementation using torch grouped_mm (no fbgemm dependency).

    Inference only. Each token is routed to ``moe_top_k`` experts. Expert
    outputs are weighted by the full-softmax router probability of that
    expert (not renormalized over the top-k) and summed.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        # Default per-expert width keeps the routed compute close to a dense MLP of mlp_ratio.
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k
        d_intermediate = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        # Uninitialized (torch.empty); presumably filled from a checkpoint.
        # With gated_linear, each expert's in-projection holds [gate | up] halves.
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate))
    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        """Route ``x`` (``[..., d_model]``) through the experts; output has ``x``'s shape.

        ``gate``, when given, is used directly as the router logits instead of ``self.router(x)``.

        Raises:
            NotImplementedError: in training mode or when grad is enabled.
        """
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")
        orig_shape = x.shape
        x = x.reshape(-1, orig_shape[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
        logits_fp32 = logits.float()
        scores, expert = logits.topk(self.top_k, dim=-1, sorted=False)
        # exp(score - logsumexp) = softmax probability of each selected expert.
        weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype)
        expert = expert.flatten()
        # Group the (token, slot) assignments by expert id.
        expert_sorted, sort_idx = expert.sort()
        expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype)
        # Cumulative end offset of each expert's group, as grouped_mm expects.
        offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32)
        # Flat assignment index -> source token row.
        src = sort_idx // self.top_k
        # One extra padding row (a copy of the first routed token) is appended;
        # it is zeroed below and dropped by the [:-1] slice.
        # NOTE(review): the need for the padding row is not evident here — confirm.
        x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0))
        h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets)
        h[-1].zero_()
        if self.config.gated_linear:
            gate_act, up = h.chunk(2, dim=-1)
            h = F.silu(gate_act) * up
        else:
            h = F.silu(h)
        y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1]
        # Scatter outputs back from expert-sorted order to [tokens, top_k, d_model].
        y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1)
        return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape)
class MoE(nn.Module):
    """MoE implementation using fbgemm optimized kernels.

    Same parameters and routing weights as MoEWithoutFBGEMM (full-softmax
    probability of each selected expert). Token shuffling and the final
    scatter-add use fbgemm ops. Inference only.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k)
        d_int = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        # Uninitialized (torch.empty); presumably filled from a checkpoint.
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int))
    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        """Route ``x`` (``[..., d_model]``) through the experts; output has ``x``'s shape.

        ``gate``, when given, is used directly as the router logits.

        Raises:
            NotImplementedError: in training mode or when grad is enabled.
        """
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")
        orig = x.shape
        x = x.reshape(-1, orig[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
        logits32 = logits.float()
        # Assumed to return per-expert token counts, expert ids in sorted order,
        # and the source token of each assignment (TODO confirm against fbgemm docs).
        token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k)
        E = self.expert_in_proj.size(0)
        offs = token_counts[:E].cumsum(0).to(torch.int32)
        src = src.to(torch.long)
        expert_sorted = expert_sorted.to(torch.long)
        logZ = logits32.logsumexp(-1)
        # Softmax probability of each (token, expert) assignment.
        w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype)
        # Extra padding row appended; its output is dropped by the [:-1] slice below.
        xg = x.index_select(0, torch.cat((src, src[:1]), 0))
        h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs)
        if self.config.gated_linear:
            ga, up = h.chunk(2, -1)
            h = F.silu(ga) * up
        else:
            h = F.silu(h)
        yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1]
        out = torch.zeros_like(x)
        # Accumulate weighted expert outputs back into their source token rows.
        torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src)
        return out.reshape(orig)
class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond).

    A learnable input bias is used only when ``noise_conditioning == "wan"``.
    Returns a tuple of ``n_cond`` independent bias-free projections of the
    activated conditioning.
    """

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        self.cond_proj = nn.ModuleList(
            nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)
        )

    def forward(self, cond):
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple(proj(activated) for proj in self.cond_proj)
# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------
class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Self-attention and the MLP are each wrapped in adaptive RMS norm
    (scale/bias) plus an output gate. Both are driven by ``cond`` through
    their own CondHead.

    Optional sub-layers:

    * Prompt cross-attention, built when prompt conditioning is enabled and
      ``layer_idx`` is a multiple of ``prompt_conditioning_period``.
    * Controller MLP fusion, built when ``layer_idx`` is a multiple of
      ``ctrl_conditioning_period``.
    """
    def __init__(
        self, d_model, n_heads, mlp_ratio, layer_idx,
        prompt_conditioning, prompt_conditioning_period, prompt_embedding_dim,
        ctrl_conditioning_period, noise_conditioning, config,
    ):
        super().__init__()
        self.config = config
        # NOTE(review): n_heads is unused here; Attn reads head counts from config.
        self.attn = Attn(config, layer_idx)
        if getattr(config, "moe", False):
            self.dit_mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config)
        else:
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        # Each head yields (scale, bias, gate) for its sub-layer.
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        do_ctrl_cond = ctrl_conditioning_period is not None and layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None
    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        """Apply the block to tokens ``x``.

        ``ctx`` is a dict read with the keys "prompt_emb", "prompt_pad_mask"
        and "ctrl_emb"; only the keys of the optional sub-layers actually
        built are accessed.

        Returns ``(x, v)``: the updated tokens and the value-residual tensor
        for the next block.
        """
        s0, b0, g0 = self.attn_cond_head(cond)
        s1, b1, g1 = self.mlp_cond_head(cond)
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual
        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x
        x = ada_gate(self.dit_mlp(ada_rmsnorm(x, s1, b1)), g1) + x
        return x, v
class WorldDiT(nn.Module):
    """Stack of ``config.n_layers`` WorldDiTBlocks, each with its own parameters.

    The RoPE angles are computed once per forward from ``pos_ids`` and shared
    by every block. The value-residual tensor ``v`` is threaded from block to
    block, starting as None.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run every block in order and return the final tokens (same shape as ``x``)."""
        rope_angles = self.rope_angles(pos_ids)
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.
    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """
    _supports_gradient_checkpointing = False
    # Kept in float32 when loaded in a lower precision (noise embedding and RoPE angle tables).
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]
    # NOTE(review): scheduler_sigmas below uses a mutable list default shared across
    # instances; safe only while nothing mutates it in place.
    @register_to_config
    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        channels: int = 32,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: list[float] | None = [
            1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
        # NOTE(review): ctrl_cfg / prompt_cfg are built here but never used in forward().
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )
        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)
        C, D = channels, d_model
        # Non-overlapping patch embedding (stride == kernel) and its inverse.
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            D, C, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        # Per-token positions for a single frame, in row-major order over a grid
        # `width` tokens wide. _t_pos_1f is overwritten on every forward().
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)
    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ):
        """Denoise one frame ``x`` of shape ``[B, N, C, H, W]`` (only B == N == 1 is supported).

        Args:
            x: latent frame, ``[B, N, C, H, W]``; ``H`` and ``W`` must be
                divisible by the patch size, and the patch grid must contain
                ``tokens_per_frame`` tokens.
            sigma: noise level; presumably ``[B, N]`` so that the embedding
                matches AdaLN's ``[b, n, d]`` cond (TODO confirm).
            frame_timestamp: ``[0, 0]`` entry gives the frame's temporal position.
            frame_idx: optional; overrides ``frame_timestamp`` for the "f_pos" ids.
            prompt_emb, prompt_pad_mask: prompt context, used only by blocks
                that have cross-attention.
            mouse, button, scroll: controller inputs. ``button`` is asserted
                non-None. NOTE(review): ``mouse`` and ``scroll`` are also
                concatenated unconditionally, so they are effectively required.
            kv_cache: passed down to every attention layer.
                NOTE(review): Attn calls ``kv_cache.upsert`` unconditionally,
                so None appears unsupported.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        # NOTE(review): only Hp * Wp is checked. The x/y position buffers assume
        # Wp == config.width — confirm that this always holds.
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )
        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Write in place into the buffer (its storage stays fixed across calls).
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)
        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }
        D = self.config.d_model
        # [B, N, C, H, W] -> tokens [B, N*Hp*Wp, D]
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        # tokens -> [B, N, C, H, W]
        x = eo.rearrange(x, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=Hp, wp=Wp)
        x = self.unpatchify(x)
        x = x.view(B, N, C, H, W)
        return x
    def get_active_parameters(self) -> int:
        """Return the parameter count with only ``top_k`` experts per MoE layer counted.

        For MoE configs, subtract the parameters of the inactive experts. Each
        expert holds ``(3 if gated_linear else 2) * d_model * hidden``
        parameters across its in/out projections.
        """
        total = sum(p.numel() for p in self.parameters())
        c = self.config
        if getattr(c, "moe", False):
            moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k
            hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts)
            total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2)
        return total
    def quantize(self, quant_type: str):
        """Quantize eligible linear layers in place (see ``quantize_model``)."""
        quantize_model(self, quant_type)
    def apply_inference_patches(self):
        """Apply the inference-only rewrites: cached noise conditioning and fused QKV."""
        _apply_inference_patches(self)