File size: 40,737 Bytes
064b963 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 | # Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""WorldModel transformer for frame generation.
Single-file model containing all building blocks: nn primitives, attention,
RoPE, quantization, inference caching, and the top-level WorldModel.
"""
import warnings
import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
try:
from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling
import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa
HAS_FBGEMM = True
except ImportError:
HAS_FBGEMM = False
# ---------------------------------------------------------------------------
# NN primitives
# ---------------------------------------------------------------------------
class NoCastModule(torch.nn.Module):
    """Module whose parameters and buffers keep their dtype across casts.

    Device moves work normally, but any dtype change requested through
    ``.to()``, ``.half()``/``.double()`` or other ``_apply``-based calls is
    reverted (with a warning), so tensors keep the dtype they were created with.
    """

    def _apply(self, fn, *args, **kwargs):
        # Extra arguments (e.g. ``recurse=`` passed by ``Module.to_empty``) are
        # forwarded unchanged to ``nn.Module._apply``.
        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype, *args, **kwargs)

    def to(self, *args, **kwargs):
        """Like ``nn.Module.to`` but drops every dtype request; device is honoured.

        A warning is emitted when a dtype requested via a reference tensor, a
        positional ``torch.dtype`` or ``dtype=`` differs from the module's
        current dtype (that of its first parameter, else its first buffer).
        """
        requested = []
        if args and isinstance(args[0], torch.Tensor):
            ref, *rest = args
            args = (ref.device, *rest)
            requested.append(ref.dtype)
        dtype_kwarg = kwargs.pop("dtype", None)
        if dtype_kwarg is not None:
            requested.append(dtype_kwarg)
        # Positional dtypes are dropped too, so they must also be reported.
        requested.extend(a for a in args if isinstance(a, torch.dtype))
        args = tuple(a for a in args if not isinstance(a, torch.dtype))
        # ``param or buffer`` would call Tensor.__bool__, which raises for
        # multi-element tensors; test for None explicitly instead.
        base = next(self.parameters(), None)
        if base is None:
            base = next(self.buffers(), None)
        if base is not None and any(d is not base.dtype for d in requested):
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )
        return super().to(*args, **kwargs)
def rms_norm(x: torch.Tensor) -> torch.Tensor:
"""Root mean square layer normalization."""
return F.rms_norm(x, (x.size(-1),))
class MLP(nn.Module):
    """Two bias-free linear layers with SiLU in between: ``fc2(silu(fc1(x)))``."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        hidden = F.silu(self.fc1(x))
        return self.fc2(hidden)
class AdaLN(nn.Module):
    """RMS norm modulated by a per-group scale/shift predicted from ``cond``.

    ``x`` is ``(b, n*m, d)`` and ``cond`` is ``(b, n, d)``: every one of the
    ``m`` consecutive tokens in group ``i`` uses the scale/shift computed from
    ``cond[:, i]``. Output: ``rms_norm(x) * (1 + scale) + shift``.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        _, n_groups, _ = cond.shape
        _, n_tokens, _ = x.shape
        tokens_per_group = n_tokens // n_groups
        # Repeat each group's [scale | shift] row for all tokens of that group.
        params = self.fc(F.silu(cond)).repeat_interleave(tokens_per_group, dim=1)
        scale, shift = params.chunk(2, dim=-1)
        return rms_norm(x) * (1 + scale) + shift
def ada_rmsnorm(x, scale, bias):
    """RMS-normalise ``x`` and apply per-frame ``(1 + scale)`` and ``bias``.

    ``x`` is ``(b, n*m, d)`` with ``n = scale.size(1)``; row ``i`` of
    ``scale``/``bias`` (along dim 1) applies to the ``m`` tokens of frame ``i``.
    """
    n_frames = scale.size(1)
    frames = eo.rearrange(x, "b (n m) d -> b n m d", n=n_frames)
    modulated = rms_norm(frames) * (1 + scale[:, :, None]) + bias[:, :, None]
    return eo.rearrange(modulated, "b n m d -> b (n m) d")
def ada_gate(x, gate):
    """Multiply every token of frame ``i`` in ``x`` by ``gate[:, i]``.

    ``x`` is ``(b, n*m, d)`` with ``n = gate.size(1)``.
    """
    n_frames = gate.size(1)
    frames = eo.rearrange(x, "b (n m) d -> b n m d", n=n_frames)
    gated = frames * gate[:, :, None]
    return eo.rearrange(gated, "b n m d -> b (n m) d")
class NoiseConditioner(NoCastModule):
    """Embeds noise levels: ``sigma * 1000`` -> sin/cos Fourier features -> MLP.

    The frequency buffer is float32 and the computation runs with CUDA autocast
    disabled; the embedding is cast back to the input dtype and reshaped to
    ``(*s.shape, dim)``.
    """

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        n_freqs = fourier_dim // 2
        # Geometrically spaced frequencies from 1 down to 1 / base.
        self.freq = nn.Buffer(
            torch.logspace(0, -1, steps=n_freqs, base=base, dtype=torch.float32),
            persistent=False,
        )
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        # ``eps`` is accepted for interface compatibility but is not used.
        assert self.freq.dtype == torch.float32
        orig_dtype, shape = s.dtype, s.shape
        with torch.autocast("cuda", enabled=False):
            scaled = s.reshape(-1).float() * 1000
            phase = scaled[:, None] * self.freq[None, :]
            # sqrt(2) keeps the sin/cos features at unit mean square.
            features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1) * 2**0.5
            emb = self.mlp(features)
            return emb.to(orig_dtype).view(*shape, -1)
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class OrthoRoPEAngles(NoCastModule):
    """Builds RoPE cos/sin tables from per-token (x, y, t) positions each call.

    The ``d_head / 2`` rotation angles are split as ``d_head / 8`` for x,
    ``d_head / 8`` for y and ``d_head / 4`` for t.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_dim = config.d_model // config.n_heads
        torch._assert(head_dim % 8 == 0, "d_head must be divisible by 8")
        n_spatial, n_temporal = head_dim // 8, head_dim // 4
        # Spatial: linear ramp of frequencies (times pi), each used for a pair.
        nyquist_frac = float(getattr(config, "rope_nyquist_frac", 0.8))
        top_freq = min(self.config.height, self.config.width) * nyquist_frac
        ramp = torch.linspace(1.0, top_freq / 2, (n_spatial + 1) // 2, dtype=torch.float32)
        xy = (ramp * torch.pi).repeat_interleave(2)[:n_spatial]
        # Temporal: geometric schedule theta ** (-2i / d_t), each used for a pair.
        theta = float(getattr(config, "rope_theta", 10000.0))
        exponents = torch.arange(0, n_temporal, 2, dtype=torch.float32) / n_temporal
        inv_t = (1.0 / (theta ** exponents)).repeat_interleave(2)
        self.register_buffer("xy", xy, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        height, width = self.config.height, self.config.width
        if not torch.compiler.is_compiling():
            torch._assert(
                (pos_ids["y_pos"].max() < height) & (pos_ids["x_pos"].max() < width),
                f"pos_ids out of bounds, {height}, {width}"
            )
        # Integer grid indices -> cell centres inside (-1, 1).
        x = (2.0 * pos_ids["x_pos"].float() + 1.0) / width - 1.0
        y = (2.0 * pos_ids["y_pos"].float() + 1.0) / height - 1.0
        t = pos_ids["t_pos"].float()
        angles = torch.cat(
            (x[..., None] * self.xy, y[..., None] * self.xy, t[..., None] * self.inv_t),
            dim=-1,
        )
        # Singleton axis at dim 1 so the tables broadcast over attention heads.
        return angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)
class OrthoRoPE(NoCastModule):
    """Rotates adjacent feature pairs ``(2i, 2i+1)`` of ``x`` by RoPE angles.

    The result is de-interleaved: all rotated even-index features come first,
    followed by all rotated odd-index features. It is computed in float32 and
    cast back to ``x``'s dtype.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        cos, sin = rope_angles
        pairs = x.float().unfold(-1, 2, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        rotated_even = even * cos - odd * sin
        rotated_odd = odd * cos + even * sin
        return torch.cat((rotated_even, rotated_odd), dim=-1).type_as(x)
class Attn(nn.Module):
    """Multi-head self-attention with RoPE, a KV cache and optional extras.

    Options read from ``config``:
      * ``n_kv_heads``: grouped-query attention when it differs from ``n_heads``.
      * ``value_residual``: values are lerped towards ``v1`` (this layer's own
        values when ``v1`` is None) with the learned weight ``v_lamb``.
      * ``gated_attn``: sigmoid per-head output gate computed from the first
        ``n_heads`` channels of the input.
    Keys, values and the block mask are obtained from ``kv_cache.upsert``.
    Returns ``(output, v1)``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))
        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0
        self.enable_gqa = self.n_heads != self.n_kv_heads
        d_model = config.d_model
        kv_width = self.n_kv_heads * self.d_head
        self.q_proj = nn.Linear(d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = OrthoRoPE(config)
        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False)
            # Zero weights => every gate starts at sigmoid(0) = 0.5.
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        from torch.nn.attention.flex_attention import flex_attention

        def split_heads(t, n_heads):
            return eo.rearrange(t, "b t (h d) -> b h t d", h=n_heads, d=self.d_head)

        q = split_heads(self.q_proj(x), self.n_heads)
        k = split_heads(self.k_proj(x), self.n_kv_heads)
        v = split_heads(self.v_proj(x), self.n_kv_heads)
        if self.value_residual:
            if v1 is None:
                v1 = v
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)
        # QK-norm, then rotary embedding.
        q = self.rope(rms_norm(q), rope_angles)
        k = self.rope(rms_norm(k), rope_angles)
        k, v, block_mask = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=self.enable_gqa)
        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.transpose(1, 2)[..., None]  # (b, t, h) -> (b, h, t, 1)
        merged = eo.rearrange(y, "b h t d -> b t (h d)")
        return self.out_proj(merged), v1
class MergedQKVAttn(Attn):
    """Inference form of ``Attn`` whose q/k/v projections are fused into one matmul.

    Built from an existing ``Attn``. All of its weights are copied in, then
    ``q_proj``/``k_proj``/``v_proj`` are concatenated row-wise (q, k, v) into
    ``qkv_proj``, and the three separate layers are deleted.

    Raises:
        RuntimeError: if ``src`` does not provide every tensor this module
            expects. Without this check the module would silently keep its
            fresh random initialisation.
    """

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        # strict=False tolerates extra keys in ``src``, but it would also hide
        # missing ones, so check those explicitly instead of ignoring the result.
        load_result = self.load_state_dict(src.state_dict(), strict=False)
        if load_result.missing_keys:
            raise RuntimeError(
                "MergedQKVAttn: source attention is missing weights "
                f"{sorted(load_result.missing_keys)}; cannot merge q/k/v"
            )
        self.train(src.training)
        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head
        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )
        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        from torch.nn.attention.flex_attention import flex_attention
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1
class CrossAttention(nn.Module):
    """Cross-attention from ``x`` onto a context sequence (prompt conditioning).

    Queries come from ``x`` (width ``d_model``); keys/values come from
    ``context`` (width ``context_dim``, or ``d_model`` if not given). QK are
    RMS-normalised, and the output projection starts at zero.
    """

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_head = config.d_model // config.n_heads
        kv_in = context_dim or config.d_model
        self.inner_dim = kv_in
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(kv_in, self.inner_dim, bias=False)
        self.v_proj = nn.Linear(kv_in, self.inner_dim, bias=False)
        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        nn.init.zeros_(self.out_proj.weight)

    def forward(self, x, context, context_pad_mask=None):
        from torch.nn.attention.flex_attention import flex_attention
        # NOTE(review): context_pad_mask is accepted but never applied, so padded
        # context tokens are attended to like real ones. Confirm this matches
        # training before adding a mask (its polarity is not visible here).
        to_heads = "b t (h d) -> b h t d"
        q = eo.rearrange(self.q_proj(x), to_heads, h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), to_heads, h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), to_heads, h=self.n_heads)
        attended = flex_attention(rms_norm(q), rms_norm(k), v)
        batch, n_query = x.size(0), x.size(1)
        merged = attended.transpose(1, 2).contiguous().reshape(batch, n_query, -1)
        return self.out_proj(merged)
# ---------------------------------------------------------------------------
# Inference caching
# ---------------------------------------------------------------------------
def _bf16_u16(x: Tensor) -> Tensor:
    """Reinterpret each 16-bit element's raw bits as an int32 in ``[0, 65535]``."""
    signed_bits = x.contiguous().view(torch.int16).to(torch.int32)
    return torch.bitwise_and(signed_bits, 0xFFFF)
class CachedDenoiseStepEmb(nn.Module):
    """Lookup-table replacement for a sigma embedder over a fixed sigma set.

    Each sigma in ``sigmas`` is embedded once by ``base`` (bf16 in, bf16 out).
    At call time, the raw bf16 bits of each sigma index a 65536-entry LUT that
    gives its table row. Sigmas outside the set map to row ``S`` (one past the
    end), so the final lookup fails instead of returning a wrong embedding.

    Raises:
        ValueError: if two sigmas round to the same bf16 value.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device
        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        level_bits = _bf16_u16(levels)
        n_levels = level_bits.numel()
        if torch.unique(level_bits).numel() != n_levels:
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )
        with torch.no_grad():
            table = base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[level_bits] = torch.arange(n_levels, device=device, dtype=torch.int32)
        oob = torch.tensor(n_levels, device=device, dtype=torch.int32)
        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer("oob", oob, persistent=False)

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        row = self.lut[_bf16_u16(sigma)]
        row = torch.where(row < 0, self.oob, row)  # unknown sigma -> out of range
        return self.table[row.long()]
class CachedCondHead(nn.Module):
    """Precomputed replacement for a conditioning head over the cached sigmas.

    ``base`` is run once on every row of the cached denoise-step table. Its
    outputs are stacked into ``cache``, with one leading slice per head output.
    At call time the row behind each ``cond`` vector is recovered from the bf16
    bits of a single key feature: the first of the leading ``max_key_dims``
    features whose bits are unique across the table. A ``cond`` that is not a
    cached row maps to index ``S`` and the lookup fails out of range.

    Raises:
        ValueError: if no unique key feature exists among the first dims.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        n_levels, width = table.shape
        with torch.no_grad():
            head_outputs = base(table[:, None, :])
            cache = (
                torch.stack([out.squeeze(1) for out in head_outputs], 0)
                .to(torch.bfloat16)
                .contiguous()
            )
        key_dim, key_bits = None, None
        for dim in range(min(width, max_key_dims)):
            bits = _bf16_u16(table[:, dim])
            if torch.unique(bits).numel() == n_levels:
                key_dim, key_bits = dim, bits
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )
        device = table.device
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[key_bits] = torch.arange(n_levels, device=device, dtype=torch.int32)
        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob", torch.tensor(n_levels, device=device, dtype=torch.int32), persistent=False
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        row = self.lut[_bf16_u16(cond[..., self.key_dim])]
        row = torch.where(row < 0, self.oob, row)
        return tuple(self.cache[:, row.long()].unbind(0))
# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------
QUANTS = [None]
try:
from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
QUANTS.append("nvfp4")
except ImportError:
pass
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """NVFP4-quantize ``a_bf16`` on the fly and multiply by a pre-quantized weight.

    Wrapped as the custom op ``world_engine::fp4_linear`` so tracing treats the
    FlashInfer calls as opaque. Returns a bf16 ``(rows of a, cols of b_fp4_T)``
    result, scaled by ``alpha``.
    """
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16,
        a_global_sf,
        sfLayout=SfLayout.layout_128x4,
        do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha,
        out_dtype=torch.bfloat16,
        backend="cutlass",
    )
@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor, b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor, b_sf_T: torch.Tensor, alpha: torch.Tensor,
) -> torch.Tensor:
    """Shape/dtype-only implementation of ``fp4_linear`` used during tracing."""
    n_rows, n_cols = a_bf16.shape[0], b_fp4_T.shape[1]
    return a_bf16.new_empty((n_rows, n_cols), dtype=torch.bfloat16)
class FP4Linear(nn.Module):
    """NVFP4 linear layer backed by FlashInfer (via ``fp4_linear``).

    The weight is quantized once at construction. Activations are quantized on
    every call with a global scale of 1.0. A copy of the original weight is
    kept in ``self.weight``. Requires a CUDA weight; the output is always bf16.
    """

    # FP8-E4M3 max (448) x FP4-E2M1 max (6): maps the weight's absolute max onto
    # the full range of the FP8 per-block scale factors. With 1.0 the block
    # scales were confined to <= 1/6 and underflowed for small-magnitude blocks.
    _WEIGHT_GLOBAL_SF_NUMERATOR = 448.0 * 6.0

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0
        self.weight = nn.Parameter(lin.weight.detach().clone())
        # Checked before quantizing so a CPU weight fails with a clear message.
        assert self.weight.is_cuda, "FP4Linear requires a CUDA weight"
        device = self.weight.device
        with torch.no_grad():
            self._dummy_scale = torch.full((1,), 1.0, device=device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).contiguous()
            # clamp_min guards an all-zero weight (amax 0 -> infinite scale -> NaN).
            weight_amax = weight_bf16.float().abs().nan_to_num().max().clamp_min(1e-12)
            self._weight_global_sf = self._WEIGHT_GLOBAL_SF_NUMERATOR / weight_amax
            # Output rescale: 1 / (weight global scale * activation global scale).
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16, self._weight_global_sf,
                sfLayout=SfLayout.layout_128x4, do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()
        # Warm-up call on a dummy row (presumably so kernel setup happens here
        # rather than on first real use).
        lazy_x = torch.zeros((1, lin.in_features), device=device, dtype=torch.bfloat16)
        fp4_linear(lazy_x, self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))
class FP8W8A8Linear(nn.Module):
__constants__ = ("in_features", "out_features")
def __init__(self, lin: nn.Linear):
super().__init__()
self.in_features, self.out_features = lin.in_features, lin.out_features
f8 = torch.float8_e4m3fn
inv = 1.0 / float(torch.finfo(f8).max)
self._inv = inv
w = lin.weight.detach()
ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()
self.register_buffer("wT", wf8.t())
self.register_buffer("ws", ws)
if lin.bias is None:
self.bias = None
else:
self.register_buffer("bias", lin.bias.detach().to(torch.float16))
def forward(self, x: torch.Tensor) -> torch.Tensor:
s = x.shape
x2 = x.reshape(-1, s[-1])
xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
y = torch._scaled_mm(
xf8, self.wT, xs, self.ws,
bias=self.bias, out_dtype=torch.float16, use_fast_accum=True,
)
return y.reshape(*s[:-1], self.out_features).to(x.dtype)
class FP8Linear(nn.Module):
def __init__(self, lin: nn.Linear):
super().__init__()
self.in_features, self.out_features = lin.in_features, lin.out_features
self.bias = (
nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
if lin.bias is not None else None
)
w_amax = lin.weight.data.abs().amax()
w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
self.register_buffer("w_amax", w_amax)
self.register_buffer("weightT", w.t())
self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
result = torch._scaled_mm(
x_fp8, self.weightT,
bias=self.bias, scale_a=self.dummy_scale, scale_b=self.w_amax,
out_dtype=torch.bfloat16, use_fast_accum=True,
)
return result.reshape(x.shape[:-1] + (-1,))
def quantize_model(model: nn.Module, quant: str):
if quant is None:
return model
def eligible(m: nn.Module) -> bool:
w = getattr(m, "weight", None)
if not isinstance(m, nn.Linear):
return False
if getattr(w, "dtype", None) != torch.bfloat16:
return False
o, k = w.shape
return (o % 32 == 0) and (k % 32 == 0)
new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]
for name, child in model.named_children():
setattr(model, name, new_linear(child)) if eligible(child) else quantize_model(child, quant)
return model
# ---------------------------------------------------------------------------
# Inference patches
# ---------------------------------------------------------------------------
def patch_cached_noise_conditioning(model) -> None:
    """Swap the noise embedder and each block's cond heads for LUT-backed caches.

    ``model.denoise_step_emb`` becomes a ``CachedDenoiseStepEmb`` over
    ``model.config.scheduler_sigmas``. Then, for every block, the attention
    and MLP cond heads (in that order) are wrapped in ``CachedCondHead``
    built on that table.
    """
    step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = step_emb
    for block in model.transformer.blocks:
        for attr in ("attn_cond_head", "mlp_cond_head"):
            setattr(block, attr, CachedCondHead(getattr(block, attr), step_emb))
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain ``Attn`` submodule of ``model`` with a ``MergedQKVAttn``."""
    # Snapshot the targets first so the module tree is not mutated mid-walk.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, Attn) and not isinstance(module, MergedQKVAttn)
    ]
    for name, module in targets:
        model.set_submodule(name, MergedQKVAttn(module, model.config))
def _apply_inference_patches(model) -> None:
    """Apply inference-only rewrites in order: cached conditioning, then fused QKV."""
    for patch in (patch_cached_noise_conditioning, patch_Attn_merge_qkv):
        patch(model)
# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------
class CFG(nn.Module):
def __init__(self, d_model: int, dropout: float):
super().__init__()
self.dropout = dropout
self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))
def forward(
self, x: torch.Tensor, is_conditioned: bool | None = None
) -> torch.Tensor:
B, L, _ = x.shape
null = self.null_emb.expand(B, L, -1)
if self.training or is_conditioned is None:
if self.dropout == 0.0:
return x
drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
return torch.where(drop, null, x)
return x if is_conditioned else null
class ControllerInputEmbedding(nn.Module):
    """Embeds concatenated (mouse, button, scroll) controller features via an MLP.

    The MLP input width is ``n_buttons + 3``. This presumably means 2 mouse
    features plus 1 scroll feature besides the buttons; confirm with the
    caller. ``mouse`` must be 3-D.
    """

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        in_features = n_buttons + 3
        self.mlp = MLP(in_features, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        assert mouse.dim() == 3
        features = torch.cat([mouse, button, scroll], dim=-1)
        return self.mlp(features)
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens: ``fc2(silu(fc1_x(x) + fc1_c(c)))``.

    ``x`` is ``(B, N, D)`` and ``cond`` is ``(B, L, D)``. The ``N`` tokens are
    split into ``L`` consecutive groups, and group ``i`` uses ``cond[:, i]``.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        batch, _, dim = x.shape
        n_groups = cond.shape[1]
        grouped = x.reshape(batch, n_groups, -1, dim)
        hidden = self.fc1_x(grouped) + self.fc1_c(cond)[:, :, None]
        return self.fc2(F.silu(hidden)).flatten(1, 2)
class MoEWithoutFBGEMM(nn.Module):
"""MoE implementation using torch grouped_mm (no fbgemm dependency)."""
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.moe_top_k
moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k
d_intermediate = int(config.d_model * moe_mlp_ratio)
self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
self.expert_in_proj = nn.Parameter(
torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model)
)
self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate))
def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
if self.training or torch.is_grad_enabled():
raise NotImplementedError("inference only")
orig_shape = x.shape
x = x.reshape(-1, orig_shape[-1])
logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
logits_fp32 = logits.float()
scores, expert = logits.topk(self.top_k, dim=-1, sorted=False)
weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype)
expert = expert.flatten()
expert_sorted, sort_idx = expert.sort()
expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype)
offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32)
src = sort_idx // self.top_k
x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0))
h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets)
h[-1].zero_()
if self.config.gated_linear:
gate_act, up = h.chunk(2, dim=-1)
h = F.silu(gate_act) * up
else:
h = F.silu(h)
y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1]
y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1)
return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape)
class MoE(nn.Module):
"""MoE implementation using fbgemm optimized kernels."""
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.moe_top_k
moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k)
d_int = int(config.d_model * moe_mlp_ratio)
self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
self.expert_in_proj = nn.Parameter(
torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model)
)
self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int))
def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
if self.training or torch.is_grad_enabled():
raise NotImplementedError("inference only")
orig = x.shape
x = x.reshape(-1, orig[-1])
logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))
logits32 = logits.float()
token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k)
E = self.expert_in_proj.size(0)
offs = token_counts[:E].cumsum(0).to(torch.int32)
src = src.to(torch.long)
expert_sorted = expert_sorted.to(torch.long)
logZ = logits32.logsumexp(-1)
w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype)
xg = x.index_select(0, torch.cat((src, src[:1]), 0))
h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs)
if self.config.gated_linear:
ga, up = h.chunk(2, -1)
h = F.silu(ga) * up
else:
h = F.silu(h)
yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1]
out = torch.zeros_like(x)
torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src)
return out.reshape(orig)
class CondHead(nn.Module):
    """Per-layer conditioning head: ``[p(silu(cond + bias_in)) for p in cond_proj]``.

    ``bias_in`` (zero-initialised) exists only when ``noise_conditioning ==
    "wan"``. Returns a tuple of ``n_cond`` projections.
    """

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        self.bias_in = (
            nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
        )
        self.cond_proj = nn.ModuleList(
            nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)
        )

    def forward(self, cond):
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple(proj(activated) for proj in self.cond_proj)
# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------
class WorldDiTBlock(nn.Module):
    """Transformer block: adaptive self-attention, optional prompt cross-attention,
    optional controller fusion, then an adaptive MLP (dense or MoE).

    Cross-attention is present every ``prompt_conditioning_period`` layers when
    ``prompt_conditioning`` is set. Controller fusion is present every
    ``ctrl_conditioning_period`` layers when that period is not None.
    ``n_heads`` is accepted but unused here; attention reads ``config``.
    """

    def __init__(
        self, d_model, n_heads, mlp_ratio, layer_idx,
        prompt_conditioning, prompt_conditioning_period, prompt_embedding_dim,
        ctrl_conditioning_period, noise_conditioning, config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        if getattr(config, "moe", False):
            moe_cls = MoE if HAS_FBGEMM else MoEWithoutFBGEMM
            self.dit_mlp = moe_cls(config)
        else:
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        has_prompt_xattn = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if has_prompt_xattn else None
        )
        has_ctrl_fusion = (
            ctrl_conditioning_period is not None
            and layer_idx % ctrl_conditioning_period == 0
        )
        self.ctrl_mlpfusion = MLPFusion(d_model) if has_ctrl_fusion else None

    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        attn_scale, attn_bias, attn_gate = self.attn_cond_head(cond)
        mlp_scale, mlp_bias, mlp_gate = self.mlp_cond_head(cond)

        # Self-attention: adaptive pre-norm, gated residual.
        attn_out, v = self.attn(
            ada_rmsnorm(x, attn_scale, attn_bias), pos_ids, rope_angles, v, kv_cache=kv_cache
        )
        x = ada_gate(attn_out, attn_gate) + x

        # Optional prompt cross-attention: plain residual.
        if self.prompt_cross_attn is not None:
            prompt_out = self.prompt_cross_attn(
                rms_norm(x),
                context=rms_norm(ctx["prompt_emb"]),
                context_pad_mask=ctx["prompt_pad_mask"],
            )
            x = prompt_out + x

        # Optional controller fusion: plain residual.
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP / MoE: adaptive pre-norm, gated residual.
        mlp_out = self.dit_mlp(ada_rmsnorm(x, mlp_scale, mlp_bias))
        return ada_gate(mlp_out, mlp_gate) + x, v
class WorldDiT(nn.Module):
    """Stack of ``config.n_layers`` independent ``WorldDiTBlock`` layers.

    Every block is constructed separately and owns its own parameters. What
    is shared across the stack is:

    * the RoPE angles, computed once per forward pass from ``pos_ids`` by a
      single ``OrthoRoPEAngles`` module and handed to every block;
    * the ``v`` tensor each block returns, which is fed into the next block
      (presumably the value-residual path controlled by
      ``config.value_residual`` — confirm in ``Attn``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run ``x`` through every block in order.

        Args:
            x: token sequence for the stack.
            pos_ids: position ids; used to compute the RoPE angles and also
                passed to each block.
            cond: conditioning tensor passed unchanged to every block.
            ctx: context dict passed unchanged to every block.
            kv_cache: optional cache passed unchanged to every block.

        Returns:
            The hidden state produced by the last block. The final ``v`` is
            discarded.
        """
        # The angles depend only on pos_ids, so compute them once, not per block.
        rope_angles = self.rope_angles(pos_ids)
        # No value tensor exists before the first block; each block returns the
        # one handed to its successor.
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level

    What ``forward`` does:
    1. Patchify the latent with a strided convolution into a token sequence.
    2. Run the tokens through ``WorldDiT``, conditioned on the embedded noise
       level.
    3. Apply an adaptive output norm followed by SiLU.
    4. Unpatchify back to the input shape.
    """

    # diffusers ModelMixin flag.
    # NOTE(review): the config also carries mlp_/block_gradient_checkpointing;
    # they are not read in this class — confirm where (if anywhere) they apply.
    _supports_gradient_checkpointing = False
    # Submodule names diffusers keeps in fp32 when loading in reduced precision.
    # Presumably the noise embedding and RoPE angles are precision-sensitive.
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        channels: int = 32,
        # height/width: the defaults satisfy height * width == tokens_per_frame,
        # and width is used below as the token-grid row length.
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        # NOTE(review): mutable list default. It is shared across calls, which
        # is safe only as long as nothing mutates the config value in place.
        scheduler_sigmas: list[float] | None = [
            1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        # Embeds the noise level ``sigma`` into the conditioning tensor ``cond``.
        self.denoise_step_emb = NoiseConditioner(d_model)
        # Built unconditionally: forward() always computes ctx["ctrl_emb"].
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
        # CFG modules are created only when the matching conditioning is
        # enabled. They are not called in this class's forward().
        # NOTE(review): presumably used by training/sampling code for
        # conditioning dropout — confirm.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )
        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)
        C, D = channels, d_model
        # Non-overlapping patch embedding (kernel_size == stride == patch), plus
        # its transposed-conv inverse mapping D model dims back to C channels.
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            D, C, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)
        # Per-frame token position tables. persistent=False keeps them out of
        # the state dict. Token i sits at row i // width, column i % width.
        # NOTE(review): this assumes ``width`` is the token-grid width, i.e.
        # W // patch[1] of the input latent. The defaults are consistent with
        # that (16 * 16 == 256 tokens_per_frame) — confirm for other configs.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        # Allocated uninitialised (torch.empty); forward() overwrites it in
        # place from frame_timestamp on every call.
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ) -> Tensor:
        """Run one denoising forward pass on a single noisy frame.

        Args:
            x: 5-D tensor unpacked as (B, N, C, H, W). Only B == 1 and N == 1
                are currently supported. H and W must be divisible by
                ``self.patch``, and (H // ph) * (W // pw) must equal
                ``config.tokens_per_frame``.
            sigma: noise level; embedded by ``denoise_step_emb`` into ``cond``.
            frame_timestamp: read at ``[0, 0]``. Its value fills ``t_pos`` for
                every token, and also ``f_pos`` when ``frame_idx`` is None.
            frame_idx: optional override for ``f_pos`` (also read at ``[0, 0]``).
            prompt_emb, prompt_pad_mask: placed in ``ctx`` for the prompt
                cross-attention layers.
            mouse, button, scroll: controller inputs for ``ctrl_emb``.
                ``button`` must not be None.
            kv_cache: forwarded to the transformer stack.

        Returns:
            Tensor viewed back to the input shape (B, N, C, H, W).

        Note:
            Writes into the ``_t_pos_1f`` buffer in place, so every call
            mutates module state.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        # NOTE(review): plain ``assert`` is removed under ``python -O``.
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )
        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Broadcast this frame's timestamp to every token, reusing the buffer.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        # Per-token position ids with a leading batch dimension of 1.
        pos_ids = TensorDict(
            {
                "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)
        # NOTE(review): only ``button`` is checked; ``mouse``/``scroll`` are
        # passed through as-is. This check is also removed under ``python -O``.
        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }
        D = self.config.d_model
        # (B, N, C, H, W) -> (B*N, D, Hp, Wp) -> (B, N*Hp*Wp, D) token sequence.
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        # Undo the token layout above, then map back to C channels at H x W.
        x = eo.rearrange(x, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=Hp, wp=Wp)
        x = self.unpatchify(x)
        x = x.view(B, N, C, H, W)
        return x

    def get_active_parameters(self) -> int:
        """Return the total parameter count, adjusted for MoE sparsity.

        Dense models: the plain total of ``self.parameters()``.

        MoE models: subtracts, for each of ``n_layers`` layers, the weights of
        ``moe_n_experts - moe_top_k`` experts. These are presumably the experts
        not routed to for a given token. ``moe_top_k`` is clamped to
        ``moe_n_experts``.

        NOTE(review): each expert is approximated as
        ``d_model * hidden * (3 if gated_linear else 2)`` weights (matrices
        only; no biases or router) — confirm against the MoE implementation.
        """
        total = sum(p.numel() for p in self.parameters())
        c = self.config
        if getattr(c, "moe", False):
            # The default per-expert ratio is mlp_ratio / top_k, so the top_k
            # active experts together match the dense MLP's hidden width.
            moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k
            hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts)
            total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2)
        return total

    def quantize(self, quant_type: str) -> None:
        # Delegates to the module-level ``quantize_model`` helper; its return
        # value (if any) is discarded.
        quantize_model(self, quant_type)

    def apply_inference_patches(self) -> None:
        # Delegates to the module-level ``_apply_inference_patches`` helper.
        _apply_inference_patches(self)
|