# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""WorldModel transformer for frame generation."""

from typing import Optional, List
import math
import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attn import Attn, MergedQKVAttn, CrossAttention
from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm
from .quantize import quantize_model
from .cache import CachedDenoiseStepEmb, CachedCondHead


def patch_cached_noise_conditioning(model) -> None:
    """Replace the noise conditioner and per-block cond heads with cached variants.

    Call AFTER ``model.to(device="cuda", dtype=torch.bfloat16).eval()`` so the
    cached embeddings are built on the final device and dtype.
    """
cached_denoise_step_emb = CachedDenoiseStepEmb(
model.denoise_step_emb, model.config.scheduler_sigmas
)
model.denoise_step_emb = cached_denoise_step_emb
for blk in model.transformer.blocks:
blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)


def patch_Attn_merge_qkv(model) -> None:
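    """Swap every plain ``Attn`` for a ``MergedQKVAttn`` with a merged Q/K/V projection."""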
for name, mod in list(model.named_modules()):
if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
model.set_submodule(name, MergedQKVAttn(mod, model.config))


def patch_MLPFusion_split(model) -> None:
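    """Swap every packed ``MLPFusion`` for an equivalent split-linear ``SplitMLPFusion``."""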
for name, mod in list(model.named_modules()):
if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion):
model.set_submodule(name, SplitMLPFusion(mod))


def _apply_inference_patches(model) -> None:
patch_cached_noise_conditioning(model)
patch_Attn_merge_qkv(model)
patch_MLPFusion_split(model)
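

# Usage sketch (the checkpoint path is hypothetical; the ordering requirement
# comes from patch_cached_noise_conditioning above):
#
#   model = WorldModel.from_pretrained("path/to/checkpoint")  # hypothetical path
#   model = model.to(device="cuda", dtype=torch.bfloat16).eval()
#   model.apply_inference_patches()  # wraps _apply_inference_patches(model)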


class CFG(nn.Module):
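    """Classifier-free guidance dropout: replaces conditioning with a learned null embedding."""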
def __init__(self, d_model: int, dropout: float):
super().__init__()
self.dropout = dropout
self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
self, x: torch.Tensor, is_conditioned: Optional[bool] = None
) -> torch.Tensor:
"""
x: [B, L, D]
is_conditioned:
- None: training-style random dropout
- bool: whole batch conditioned / unconditioned at sampling
"""
B, L, _ = x.shape
null = self.null_emb.expand(B, L, -1)
# training-style dropout OR unspecified
if self.training or is_conditioned is None:
if self.dropout == 0.0:
return x
drop = torch.rand(B, 1, 1, device=x.device) < self.dropout # [B,1,1]
return torch.where(drop, null, x)
# sampling-time switch
return x if is_conditioned else null
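

# Sampling-time sketch: CFG pairs are formed by running the same module twice,
# once conditioned and once nulled (variable names here are illustrative):
#
#   cond = cfg_module(ctrl_emb, is_conditioned=True)     # passthrough
#   uncond = cfg_module(ctrl_emb, is_conditioned=False)  # learned null_emb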


class ControllerInputEmbedding(nn.Module):
"""Embeds controller inputs (mouse + buttons) into model dimension."""
def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
super().__init__()
        # Input features: mouse velocity (x, y) + scroll sign + n_buttons button states.
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
assert len(mouse.shape) == 3
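        # Expected shapes (see WorldModel.forward): mouse [B, N, 2],
        # button [B, N, n_buttons], scroll [B, N, 1] -> cat [B, N, n_buttons + 3].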
x = torch.cat((mouse, button, scroll), dim=-1)
return self.mlp(x)


class MLPFusion(nn.Module):
"""Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond])."""
def __init__(self, d_model: int):
super().__init__()
self.mlp = MLP(2 * d_model, d_model, d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
B, _, D = x.shape
L = cond.shape[1]
Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1) # each [D, D]
x = x.view(B, L, -1, D)
        # Broadcast cond across the per-group token axis; no repeat/cat needed.
        h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(2)
h = F.silu(h)
y = F.linear(h, self.mlp.fc2.weight)
return y.flatten(1, 2)
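

# Equivalence note: since fc1 acts on cat([x, cond], dim=-1), chunking its
# weight along dim=1 gives F.linear(cat([x, c], -1), W1) == F.linear(x, Wx) +
# F.linear(c, Wc), so the broadcast-add above matches the packed MLP without
# materializing the concatenation.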


class SplitMLPFusion(nn.Module):
"""Packed MLPFusion -> split linears (no cat, quant-friendly)."""
def __init__(self, src: MLPFusion):
super().__init__()
D = src.mlp.fc2.in_features
dev, dt = src.mlp.fc2.weight.device, src.mlp.fc2.weight.dtype
self.fc1_x = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
self.fc1_c = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
self.fc2 = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
with torch.no_grad():
Wx, Wc = src.mlp.fc1.weight.chunk(2, dim=1)
self.fc1_x.weight.copy_(Wx)
self.fc1_c.weight.copy_(Wc)
self.fc2.weight.copy_(src.mlp.fc2.weight)
self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
B, _, D = x.shape
L = cond.shape[1]
x = x.reshape(B, L, -1, D)
return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(
1, 2
)
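

# Parity sketch: a SplitMLPFusion built from an MLPFusion should match it
# numerically; both run the same linear algebra (x is [B, L*T, D] grouped
# tokens, cond is [B, L, D]; the sizes below are illustrative):
#
#   fused = MLPFusion(64)
#   split = SplitMLPFusion(fused)
#   x, c = torch.randn(2, 8 * 4, 64), torch.randn(2, 8, 64)
#   torch.testing.assert_close(fused(x, c), split(x, c))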


class CondHead(nn.Module):
    """Per-layer conditioning head: add bias_in -> SiLU -> n_cond parallel Linears."""
n_cond = 6
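    # 6 = (scale, bias, gate) for the attention branch + (scale, bias, gate) for the MLP.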

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
super().__init__()
self.bias_in = (
nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
)
self.cond_proj = nn.ModuleList(
[nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)]
)

    def forward(self, cond):
cond = cond + self.bias_in if self.bias_in is not None else cond
h = F.silu(cond)
return tuple(p(h) for p in self.cond_proj)


class WorldDiTBlock(nn.Module):
"""Single transformer block with self-attention, optional cross-attention, and MLP."""

    def __init__(
self,
d_model: int,
n_heads: int,
mlp_ratio: int,
layer_idx: int,
prompt_conditioning: Optional[str],
prompt_conditioning_period: int,
prompt_embedding_dim: int,
ctrl_conditioning_period: int,
noise_conditioning: str,
config,
):
super().__init__()
self.config = config
self.attn = Attn(config, layer_idx)
self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
self.cond_head = CondHead(d_model, noise_conditioning)
do_prompt_cond = (
prompt_conditioning is not None
and layer_idx % prompt_conditioning_period == 0
)
self.prompt_cross_attn = (
CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
)
do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
"""
0) Causal Frame Attention
1) Frame->CTX Cross Attention
2) MLP
"""
s0, b0, g0, s1, b1, g1 = self.cond_head(cond)
# Self / Causal Attention
residual = x
x = ada_rmsnorm(x, s0, b0)
x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
x = ada_gate(x, g0) + residual
# Cross Attention Prompt Conditioning
if self.prompt_cross_attn is not None:
x = (
self.prompt_cross_attn(
rms_norm(x),
context=rms_norm(ctx["prompt_emb"]),
context_pad_mask=ctx["prompt_pad_mask"],
)
+ x
)
# MLPFusion Controller Conditioning
if self.ctrl_mlpfusion is not None:
x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x
# MLP
x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x
return x, v


class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks; RoPE buffers are shared and, for "wan"/"dit_air", cond projections are tied."""

    def __init__(self, config):
super().__init__()
self.config = config
self.blocks = nn.ModuleList(
[
WorldDiTBlock(
d_model=config.d_model,
n_heads=config.n_heads,
mlp_ratio=config.mlp_ratio,
layer_idx=idx,
prompt_conditioning=config.prompt_conditioning,
prompt_conditioning_period=config.prompt_conditioning_period,
prompt_embedding_dim=config.prompt_embedding_dim,
ctrl_conditioning_period=config.ctrl_conditioning_period,
noise_conditioning=config.noise_conditioning,
config=config,
)
for idx in range(config.n_layers)
]
)
if config.noise_conditioning in ("dit_air", "wan"):
ref_proj = self.blocks[0].cond_head.cond_proj
for blk in self.blocks[1:]:
for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
blk_mod.weight = ref_mod.weight
# Shared RoPE buffers
ref_rope = self.blocks[0].attn.rope
for blk in self.blocks[1:]:
blk.attn.rope = ref_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
v = None
for i, block in enumerate(self.blocks):
x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
return x
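

# Tying sketch: before inference patches are applied, the weight sharing set up
# in WorldDiT.__init__ can be verified on a fresh stack (``config`` here stands
# for any config with noise_conditioning in ("wan", "dit_air")):
#
#   dit = WorldDiT(config)
#   assert (dit.blocks[1].cond_head.cond_proj[0].weight.data_ptr()
#           == dit.blocks[0].cond_head.cond_proj[0].weight.data_ptr())
#   assert dit.blocks[1].attn.rope is dit.blocks[0].attn.rope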


class WorldModel(ModelMixin, ConfigMixin):
"""
WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.
Denoises a frame given:
- All previous frames (via KV cache)
- The prompt embedding
- The controller input embedding
- The current noise level
"""
_supports_gradient_checkpointing = False
_keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    @register_to_config
def __init__(
self,
# Model architecture
d_model: int = 2560,
n_heads: int = 40,
n_kv_heads: Optional[int] = 20,
n_layers: int = 22,
mlp_ratio: int = 5,
channels: int = 16,
height: int = 16,
width: int = 16,
patch: tuple = (2, 2),
tokens_per_frame: int = 256,
n_frames: int = 512,
local_window: int = 16,
global_window: int = 128,
global_attn_period: int = 4,
global_pinned_dilation: int = 8,
global_attn_offset: int = -1,
value_residual: bool = False,
gated_attn: bool = True,
n_buttons: int = 256,
ctrl_conditioning: Optional[str] = "mlp_fusion",
ctrl_conditioning_period: int = 3,
ctrl_cond_dropout: float = 0.0,
prompt_conditioning: Optional[str] = "cross_attention",
prompt_conditioning_period: int = 3,
prompt_embedding_dim: int = 2048,
prompt_cond_dropout: float = 0.0,
noise_conditioning: str = "wan",
scheduler_sigmas: Optional[List[float]] = [
1.0,
0.9483006596565247,
0.8379597067832947,
0.0,
],
base_fps: int = 60,
causal: bool = True,
mlp_gradient_checkpointing: bool = True,
block_gradient_checkpointing: bool = True,
rope_impl: str = "ortho",
):
super().__init__()
self.denoise_step_emb = NoiseConditioner(d_model)
self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
if self.config.ctrl_conditioning is not None:
self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
if self.config.prompt_conditioning is not None:
self.prompt_cfg = CFG(
self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
)
self.transformer = WorldDiT(self.config)
self.patch = tuple(patch)
C, D = channels, d_model
self.patchify = nn.Conv2d(
C, D, kernel_size=self.patch, stride=self.patch, bias=False
)
self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
self.out_norm = AdaLN(d_model)
# Cached 1-frame pos_ids (buffers + cached TensorDict view)
T = tokens_per_frame
idx = torch.arange(T, dtype=torch.long)
self.register_buffer(
"_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
)
self.register_buffer(
"_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
)
self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
self,
x: Tensor,
sigma: Tensor,
frame_timestamp: Tensor,
prompt_emb: Optional[Tensor] = None,
prompt_pad_mask: Optional[Tensor] = None,
mouse: Optional[Tensor] = None,
button: Optional[Tensor] = None,
scroll: Optional[Tensor] = None,
kv_cache=None,
):
"""
Args:
x: [B, N, C, H, W] - latent frames
sigma: [B, N] - noise levels
frame_timestamp: [B, N] - frame indices
prompt_emb: [B, P, D] - prompt embeddings
prompt_pad_mask: [B, P] - padding mask for prompts
mouse: [B, N, 2] - mouse velocity
button: [B, N, n_buttons] - button states
scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
kv_cache: StaticKVCache instance
"""
B, N, C, H, W = x.shape
ph, pw = self.patch
assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
Hp, Wp = H // ph, W // pw
torch._assert(
Hp * Wp == self.config.tokens_per_frame,
f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
)
torch._assert(
B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
)
self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
pos_ids = TensorDict(
{
"t_pos": self._t_pos_1f[None],
"y_pos": self._y_pos_1f[None],
"x_pos": self._x_pos_1f[None],
},
batch_size=[1, self._t_pos_1f.numel()],
)
cond = self.denoise_step_emb(sigma) # [B, N, d]
        assert mouse is not None and button is not None and scroll is not None
ctx = {
"ctrl_emb": self.ctrl_emb(mouse, button, scroll),
"prompt_emb": prompt_emb,
"prompt_pad_mask": prompt_pad_mask,
}
D = self.unpatchify.in_features
x = self.patchify(x.reshape(B * N, C, H, W))
x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
x = F.silu(self.out_norm(x, cond))
x = eo.rearrange(
self.unpatchify(x),
"b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
n=N,
hp=Hp,
wp=Wp,
ph=ph,
pw=pw,
)
return x

    def quantize(self, quant_type: str):
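        """Quantize this model's weights in place via ``quantize_model``."""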
quantize_model(self, quant_type)

    def apply_inference_patches(self):
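        """Apply inference-time patches: cached noise conditioning, merged QKV, split MLPFusion."""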
_apply_inference_patches(self)
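

# End-to-end sketch: one denoising step with dummy inputs shaped per the
# forward() docstring (B == N == 1 as asserted there). With the default config,
# patch (2, 2) and tokens_per_frame 256 imply a 32x32 latent frame. The 77-token
# prompt length and the missing KV cache are illustrative assumptions:
#
#   model = WorldModel().eval()
#   B, N, C, H, W = 1, 1, 16, 32, 32
#   with torch.no_grad():
#       denoised = model(
#           x=torch.randn(B, N, C, H, W),
#           sigma=torch.full((B, N), 0.9483),
#           frame_timestamp=torch.zeros(B, N, dtype=torch.long),
#           prompt_emb=torch.randn(B, 77, 2048),
#           prompt_pad_mask=torch.zeros(B, 77, dtype=torch.bool),
#           mouse=torch.zeros(B, N, 2),
#           button=torch.zeros(B, N, 256),
#           scroll=torch.zeros(B, N, 1),
#       )  # -> [B, N, C, H, W]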