|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""WorldModel transformer for frame generation.""" |
|
|
|
|
|
from typing import Optional, List |
|
|
import math |
|
|
|
|
|
import einops as eo |
|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
import torch.nn.functional as F |
|
|
from tensordict import TensorDict |
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
|
from .attn import Attn, MergedQKVAttn, CrossAttention |
|
|
from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm |
|
|
from .quantize import quantize_model |
|
|
from .cache import CachedDenoiseStepEmb, CachedCondHead |
|
|
|
|
|
|
|
|
def patch_cached_noise_conditioning(model) -> None:
    """Swap in cached noise-conditioning modules for inference.

    Wraps the model's denoise-step embedding in a CachedDenoiseStepEmb keyed
    by the scheduler sigmas, then wraps every block's cond_head in a
    CachedCondHead backed by the same cache.
    """
    cached = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached
    for block in model.transformer.blocks:
        block.cond_head = CachedCondHead(block.cond_head, cached)
|
|
|
|
|
|
|
|
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain Attn submodule with a merged-QKV variant.

    Modules that are already MergedQKVAttn are left untouched. The candidate
    list is collected first so the module tree is not mutated mid-iteration.
    """
    candidates = [
        (name, mod)
        for name, mod in model.named_modules()
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn)
    ]
    for name, mod in candidates:
        model.set_submodule(name, MergedQKVAttn(mod, model.config))
|
|
|
|
|
|
|
|
def patch_MLPFusion_split(model) -> None:
    """Replace every packed MLPFusion with its split-linear equivalent.

    SplitMLPFusion copies the packed weights, so the swap is numerically
    identical but quantization-friendly. Already-split modules are skipped.
    """
    candidates = [
        (name, mod)
        for name, mod in model.named_modules()
        if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion)
    ]
    for name, mod in candidates:
        model.set_submodule(name, SplitMLPFusion(mod))
|
|
|
|
|
|
|
|
def _apply_inference_patches(model) -> None:
    """Apply all in-place inference optimizations to the model, in order."""
    patches = (
        patch_cached_noise_conditioning,
        patch_Attn_merge_qkv,
        patch_MLPFusion_split,
    )
    for patch in patches:
        patch(model)
|
|
|
|
|
|
|
|
class CFG(nn.Module):
    """Classifier-free-guidance conditioning dropout.

    Holds a learned null embedding that replaces the real conditioning either
    randomly (training-style dropout) or wholesale (unconditioned sampling).
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        # Learned "no conditioning" token, broadcast over batch and length.
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: Optional[bool] = None
    ) -> torch.Tensor:
        """
        x: [B, L, D]
        is_conditioned:
            - None: training-style random dropout
            - bool: whole batch conditioned / unconditioned at sampling
        """
        B, L, _ = x.shape
        null = self.null_emb.expand(B, L, -1)

        # Sampling path with an explicit flag: all-or-nothing replacement.
        if not (self.training or is_conditioned is None):
            return x if is_conditioned else null

        # Training-style path: per-sample Bernoulli dropout to the null token.
        if self.dropout == 0.0:
            return x
        drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
        return torch.where(drop, null, x)
|
|
|
|
|
|
|
|
class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into model dimension."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        # Input width: n_buttons button states + 3 extra features
        # (presumably 2 mouse axes + 1 scroll — confirm against callers).
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        # Expect a 3-D [batch, frames, features] layout for the streams.
        assert mouse.dim() == 3
        features = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(features)
|
|
|
|
|
|
|
|
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond]).

    The forward pass bypasses self.mlp's own forward: it splits fc1's packed
    weight into an x-half and a cond-half so the concatenation is never
    materialized (Linear(cat([x, c])) == Linear_x(x) + Linear_c(c)).
    NOTE(review): this assumes .nn.MLP's fc1/fc2 have no bias (or that the
    bias is intentionally skipped) and that its activation is SiLU — confirm
    against the MLP definition.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # fc1 consumes cat([x, cond]) -> input width is 2 * d_model.
        self.mlp = MLP(2 * d_model, d_model, d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [B, L * k, D] tokens; cond: [B, L, D] — one conditioning vector
        # per group of k tokens (k inferred by the view below) — TODO confirm.
        B, _, D = x.shape
        L = cond.shape[1]

        # Split the packed fc1 weight into the x-half and the cond-half.
        Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1)

        # Group tokens so each group broadcasts against its cond vector.
        x = x.view(B, L, -1, D)
        h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(
            2
        )
        h = F.silu(h)
        # fc2 applied without bias — NOTE(review): any fc2 bias is dropped here.
        y = F.linear(h, self.mlp.fc2.weight)
        # Flatten groups back to the original [B, L * k, D] token layout.
        return y.flatten(1, 2)
|
|
|
|
|
|
|
|
class SplitMLPFusion(nn.Module):
    """Packed MLPFusion -> split linears (no cat, quant-friendly).

    Numerically identical to the source MLPFusion: the packed fc1 weight is
    split into separate x / cond linears and fc2 is copied verbatim.
    """

    def __init__(self, src: MLPFusion):
        super().__init__()
        dim = src.mlp.fc2.in_features
        ref = src.mlp.fc2.weight
        factory = dict(bias=False, device=ref.device, dtype=ref.dtype)

        self.fc1_x = nn.Linear(dim, dim, **factory)
        self.fc1_c = nn.Linear(dim, dim, **factory)
        self.fc2 = nn.Linear(dim, dim, **factory)

        # Copy the two halves of the packed fc1 weight, plus fc2, from src.
        with torch.no_grad():
            w_x, w_c = src.mlp.fc1.weight.chunk(2, dim=1)
            self.fc1_x.weight.copy_(w_x)
            self.fc1_c.weight.copy_(w_c)
            self.fc2.weight.copy_(ref)

        # Preserve the source module's train/eval mode.
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        groups = cond.shape[1]
        # Group tokens so each group broadcasts against its cond vector.
        tokens = x.reshape(B, groups, -1, D)
        h = F.silu(self.fc1_x(tokens) + self.fc1_c(cond).unsqueeze(2))
        return self.fc2(h).flatten(1, 2)
|
|
|
|
|
|
|
|
class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond)."""

    # Number of modulation tensors produced per call.
    n_cond = 6

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
        super().__init__()
        # "wan"-style conditioning learns an additive input bias; other
        # modes feed the conditioning through unchanged.
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        projections = [
            nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)
        ]
        self.cond_proj = nn.ModuleList(projections)

    def forward(self, cond):
        # Shared trunk: optional bias then SiLU, fanned out to n_cond heads.
        if self.bias_in is not None:
            cond = cond + self.bias_in
        shared = F.silu(cond)
        return tuple(proj(shared) for proj in self.cond_proj)
|
|
|
|
|
|
|
|
class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Conditioning layout per block:
    - adaptive RMSNorm + gating around attention and MLP (from CondHead),
    - prompt cross-attention on every prompt_conditioning_period-th layer,
    - controller MLP-fusion on every ctrl_conditioning_period-th layer.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        layer_idx: int,
        prompt_conditioning: Optional[str],
        prompt_conditioning_period: int,
        prompt_embedding_dim: int,
        ctrl_conditioning_period: int,
        noise_conditioning: str,
        config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        # Produces the 6 per-layer modulation tensors consumed in forward().
        self.cond_head = CondHead(d_model, noise_conditioning)

        # Prompt cross-attention only on periodic layers (and only when
        # prompt conditioning is enabled at all).
        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        # Controller fusion only on periodic layers.
        do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal Frame Attention
        1) Frame->CTX Cross Attention
        2) MLP

        ctx carries "prompt_emb", "prompt_pad_mask", and "ctrl_emb"; v is the
        value tensor threaded between layers (None at the first layer) —
        presumably a value-residual scheme, confirm against Attn.
        """
        # Modulation for attention (s0, b0, g0) and for the MLP (s1, b1, g1).
        s0, b0, g0, s1, b1, g1 = self.cond_head(cond)

        # Self-attention sub-block: adaptive norm -> attn -> gated residual.
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        # Optional prompt cross-attention (residual connection).
        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        # Optional controller-input fusion (residual connection).
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP sub-block: adaptive norm -> MLP -> gated residual.
        x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x

        return x, v
|
|
|
|
|
|
|
|
class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with shared parameters.

    After construction, the per-block conditioning projections (for
    "dit_air"/"wan" modes) and the RoPE module are tied to block 0's,
    so they are shared across all layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )

        # Tie every block's cond_proj weights to block 0's by reassigning the
        # Parameter objects — all layers then share one set of projections.
        if config.noise_conditioning in ("dit_air", "wan"):
            ref_proj = self.blocks[0].cond_head.cond_proj
            for blk in self.blocks[1:]:
                for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
                    blk_mod.weight = ref_mod.weight

        # Share a single RoPE module instance across all layers.
        ref_rope = self.blocks[0].attn.rope
        for blk in self.blocks[1:]:
            blk.attn.rope = ref_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        # v threads the attention value tensor from layer to layer
        # (None at the first layer; see WorldDiTBlock.forward).
        v = None
        for i, block in enumerate(self.blocks):
            x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
        return x
|
|
|
|
|
|
|
|
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    # Modules kept in fp32 regardless of the model's compute dtype.
    _keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2560,
        n_heads: int = 40,
        n_kv_heads: Optional[int] = 20,
        n_layers: int = 22,
        mlp_ratio: int = 5,
        channels: int = 16,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 512,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = -1,
        value_residual: bool = False,
        gated_attn: bool = True,
        n_buttons: int = 256,
        ctrl_conditioning: Optional[str] = "mlp_fusion",
        ctrl_conditioning_period: int = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: Optional[str] = "cross_attention",
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: Optional[List[float]] = [
            1.0,
            0.9483006596565247,
            0.8379597067832947,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
    ):
        # NOTE: register_to_config captures every argument into self.config;
        # many arguments (e.g. local_window, n_frames) are consumed by the
        # submodules via config rather than read here directly.
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        # Classifier-free-guidance dropout modules. NOTE(review): neither is
        # invoked in forward() below — presumably applied by the training
        # loop / sampler before calling forward; confirm with callers.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        # Patch embed / unembed: the Conv2d folds each (ph, pw) latent patch
        # into one d_model token; the Linear maps a token back to
        # channels * ph * pw latent values.
        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(d_model)

        # Single-frame position-id buffers (non-persistent). _t_pos_1f is
        # overwritten each forward() with the frame timestamp; _y/_x are the
        # fixed row/column indices of each token within a frame.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
    ):
        """
        Denoise one latent frame (currently restricted to B == 1, N == 1).

        Args:
            x: [B, N, C, H, W] - latent frames
            sigma: [B, N] - noise levels
            frame_timestamp: [B, N] - frame indices
            prompt_emb: [B, P, D] - prompt embeddings
            prompt_pad_mask: [B, P] - padding mask for prompts
            mouse: [B, N, 2] - mouse velocity
            button: [B, N, n_buttons] - button states
            scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
            kv_cache: StaticKVCache instance

        Returns:
            [B, N, C, H, W] denoised latent frame prediction.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Broadcast the single frame's timestamp over all of its tokens.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        # Per-frame noise-level conditioning vector.
        cond = self.denoise_step_emb(sigma)

        # Controller inputs are mandatory; prompt inputs may be None when the
        # blocks have no prompt cross-attention.
        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Patchify: [B*N, C, H, W] -> conv tokens -> [B, N*Hp*Wp, D].
        D = self.unpatchify.in_features
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        # Final adaptive norm + SiLU before projecting back to latent space.
        x = F.silu(self.out_norm(x, cond))
        # Unpatchify: token features back to [B, N, C, H, W] latents.
        x = eo.rearrange(
            self.unpatchify(x),
            "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
            n=N,
            hp=Hp,
            wp=Wp,
            ph=ph,
            pw=pw,
        )

        return x

    def quantize(self, quant_type: str):
        """Quantize this model in place via the project quantizer."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Apply in-place inference optimizations (cached noise conditioning,
        merged QKV attention, split MLP fusion)."""
        _apply_inference_patches(self)
|
|
|