# W1-4B-dLLM-Base / core/model.py
"""
Language Diffusion Transformer (DiT for text).
Public open-source version:
- pure PyTorch only
- state_dict key layout kept compatible with the internal model
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._dynamo import disable as dynamo_disable
class FallbackRMSNorm(nn.Module):
"""Minimal RMSNorm implementation used by the public model."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
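        # Compute the RMS statistic in float32 for numerical stability, then cast back.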
x_float = x.float()
norm = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
out = x_float * norm
return (out.to(dtype=x_dtype) * self.weight).to(dtype=x_dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def _rope_cos_sin(
seqlen: int,
dim: int,
theta: float,
*,
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
inv_freq = 1.0 / (
theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=dtype)[None, :, None, :]
sin = emb.sin().to(dtype=dtype)[None, :, None, :]
return cos, sin
class TokenEmbedding(nn.Module):
"""Token embedding (untied from output)."""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.embed = nn.Embedding(vocab_size, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.embed(x)
class TimestepEmbedding(nn.Module):
"""Sinusoidal timestep embedding -> conditioning vector."""
def __init__(self, cond_dim: int, freq_dim: int = 256):
super().__init__()
self.freq_dim = freq_dim
self.mlp = nn.Sequential(
nn.Linear(freq_dim, cond_dim, bias=True),
nn.SiLU(),
nn.Linear(cond_dim, cond_dim, bias=True),
)
def forward(self, t: torch.Tensor) -> torch.Tensor:
if t.ndim == 2 and t.shape[1] == 1:
t = t.squeeze(-1)
half = self.freq_dim // 2
freqs = torch.exp(
-math.log(10000) * torch.arange(half, device=t.device, dtype=torch.float32) / half
)
args = t[:, None].to(dtype=torch.float32) * freqs[None]
embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
embed = embed.to(dtype=self.mlp[0].weight.dtype)
return self.mlp(embed)
class RotaryEmbedding(nn.Module):
"""Pure PyTorch RoPE implementation."""
def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0):
super().__init__()
self.dim = int(dim)
self.theta = float(theta)
self.max_seq_len = max_seq_len
    def apply_bshd(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to q and k laid out as (batch, seq, heads, head_dim)."""
cos, sin = _rope_cos_sin(
q.shape[1],
q.shape[-1],
self.theta,
device=q.device,
dtype=q.dtype,
)
q = (q * cos) + (_rotate_half(q) * sin)
k = (k * cos) + (_rotate_half(k) * sin)
return q, k
class Attention(nn.Module):
"""
Multi-head attention with expanded attention dimension.
hidden_size -> attn_dim for Q,K,V -> hidden_size
"""
def __init__(
self,
hidden_size: int,
attn_dim: int,
num_heads: int,
head_dim: int = 128,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.attn_dim = attn_dim
self.attn_drop = float(attn_drop)
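        # The reshape in forward() assumes attn_dim == num_heads * head_dim
        # (e.g. 24 heads * 128 head_dim = 3072 with the default configuration).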
self.qkv = nn.Linear(hidden_size, attn_dim * 3, bias=False)
self.proj = nn.Linear(attn_dim, hidden_size, bias=False)
self.proj_drop = nn.Dropout(proj_drop)
@dynamo_disable
def forward(
self,
x: torch.Tensor,
rope: Optional[RotaryEmbedding] = None,
pack: Optional[object] = None,
) -> torch.Tensor:
if pack is not None:
raise RuntimeError("Packed attention is not included in the public torch-only model.")
bsz, seqlen, _ = x.shape
qkv = self.qkv(x).reshape(bsz, seqlen, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
if rope is not None:
q, k = rope.apply_bshd(q, k)
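        # scaled_dot_product_attention expects (batch, heads, seq, head_dim).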
qh = q.permute(0, 2, 1, 3).contiguous()
kh = k.permute(0, 2, 1, 3).contiguous()
vh = v.permute(0, 2, 1, 3).contiguous()
out = F.scaled_dot_product_attention(
qh,
kh,
vh,
dropout_p=self.attn_drop if self.training else 0.0,
is_causal=False,
)
x = out.permute(0, 2, 1, 3).contiguous().reshape(bsz, seqlen, self.attn_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwiGLU(nn.Module):
"""SwiGLU FFN with custom intermediate dimension."""
def __init__(self, hidden_size: int, ffn_dim: int, dropout: float = 0.0):
super().__init__()
self.w12 = nn.Linear(hidden_size, 2 * ffn_dim, bias=False)
self.w3 = nn.Linear(ffn_dim, hidden_size, bias=False)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = self.w12(x).chunk(2, dim=-1)
return self.w3(self.drop(F.silu(x1) * x2))
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""AdaLN modulation."""
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class DiTBlock(nn.Module):
"""
DiT transformer block with AdaLN-Zero modulation.
Pre-norm: RMSNorm -> modulate -> Attention/FFN -> gate
"""
def __init__(
self,
hidden_size: int,
attn_dim: int,
ffn_dim: int,
num_heads: int,
head_dim: int = 128,
cond_dim: int = 256,
attn_drop: float = 0.0,
drop: float = 0.0,
):
super().__init__()
self.norm1 = FallbackRMSNorm(hidden_size, eps=1e-6)
self.attn = Attention(hidden_size, attn_dim, num_heads, head_dim, attn_drop, drop)
self.norm2 = FallbackRMSNorm(hidden_size, eps=1e-6)
self.mlp = SwiGLU(hidden_size, ffn_dim, dropout=drop)
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, 6 * hidden_size, bias=True),
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
rope: Optional[RotaryEmbedding] = None,
pack: Optional[object] = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN(c).chunk(6, dim=-1)
x = x + gate_msa.unsqueeze(1) * self.attn(
modulate(self.norm1(x), shift_msa, scale_msa),
rope,
pack=pack,
)
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FinalAdaLN(nn.Module):
"""Final AdaLN block that produces prelogits."""
def __init__(self, hidden_size: int, cond_dim: int):
super().__init__()
self.norm = FallbackRMSNorm(hidden_size, eps=1e-6)
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, 2 * hidden_size, bias=True),
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN(c).chunk(2, dim=-1)
return modulate(self.norm(x), shift, scale)
class LangDiT(nn.Module):
"""Language Diffusion Transformer."""
def __init__(
self,
vocab_size: int = 64512,
hidden_size: int = 2048,
attn_dim: int = 3072,
ffn_dim: int = 7168,
depth: int = 48,
num_heads: int = 24,
head_dim: int = 128,
max_seq_len: int = 4096,
timestep_freq_dim: int = 256,
rope_theta: float = 10000.0,
cond_dim: int = 256,
dropout: float = 0.0,
attn_dropout: float = 0.0,
):
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.token_embed = TokenEmbedding(vocab_size, hidden_size)
self.time_embed = TimestepEmbedding(cond_dim, freq_dim=timestep_freq_dim)
self.rope = RotaryEmbedding(head_dim, max_seq_len, theta=rope_theta)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size=hidden_size,
attn_dim=attn_dim,
ffn_dim=ffn_dim,
num_heads=num_heads,
head_dim=head_dim,
cond_dim=cond_dim,
attn_drop=attn_dropout,
drop=dropout,
)
for _ in range(depth)
]
)
self.final_ada = FinalAdaLN(hidden_size, cond_dim)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._init_weights()
def _init_weights(self):
def init_fn(module: nn.Module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=0.02)
self.apply(init_fn)
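        # AdaLN-Zero: zero-init the modulation projections so shift, scale and gate
        # start at zero and every block initially acts as an identity mapping.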
for block in self.blocks:
nn.init.zeros_(block.adaLN[-1].weight)
nn.init.zeros_(block.adaLN[-1].bias)
nn.init.zeros_(self.final_ada.adaLN[-1].weight)
nn.init.zeros_(self.final_ada.adaLN[-1].bias)
def forward_hidden(
self,
input_ids: torch.Tensor,
timesteps: torch.Tensor,
*,
pack: Optional[object] = None,
) -> torch.Tensor:
if pack is not None:
raise RuntimeError("Packed attention is not included in the public torch-only model.")
x = self.token_embed(input_ids)
c = self.time_embed(timesteps)
for block in self.blocks:
x = block(x, c, self.rope, pack=None)
return self.final_ada(x, c)
def logits_from_hidden(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.lm_head(hidden_states)
def forward(
self,
input_ids: torch.Tensor,
timesteps: torch.Tensor,
*,
pack: Optional[object] = None,
return_hidden: bool = False,
) -> torch.Tensor:
hidden = self.forward_hidden(input_ids, timesteps, pack=pack)
if return_hidden:
return hidden
return self.logits_from_hidden(hidden)
def create_model(config: dict) -> LangDiT:
"""Create model from config dict."""
model_cfg = config["model"]
return LangDiT(
vocab_size=model_cfg["vocab_size"],
hidden_size=model_cfg["hidden_size"],
attn_dim=model_cfg["attn_dim"],
ffn_dim=model_cfg["ffn_dim"],
depth=model_cfg["depth"],
num_heads=model_cfg["num_heads"],
head_dim=model_cfg["head_dim"],
max_seq_len=model_cfg["max_seq_len"],
timestep_freq_dim=model_cfg.get("timestep_freq_dim", 256),
rope_theta=model_cfg.get("rope_theta", 10000.0),
cond_dim=model_cfg["cond_dim"],
dropout=model_cfg.get("dropout", 0.0),
attn_dropout=model_cfg.get("attn_dropout", 0.0),
)
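

# Minimal usage sketch (illustrative only): the hyperparameters below are small
# example values chosen to run quickly, not the released W1-4B configuration.
if __name__ == "__main__":
    example_config = {
        "model": {
            "vocab_size": 1000,
            "hidden_size": 256,
            "attn_dim": 256,  # must equal num_heads * head_dim
            "ffn_dim": 512,
            "depth": 2,
            "num_heads": 4,
            "head_dim": 64,
            "max_seq_len": 128,
            "cond_dim": 256,
        }
    }
    model = create_model(example_config)
    input_ids = torch.randint(0, 1000, (2, 16))
    # Diffusion timesteps; the exact range/schedule depends on the sampler used with the model.
    timesteps = torch.rand(2)
    logits = model(input_ids, timesteps)
    print(logits.shape)  # torch.Size([2, 16, 1000])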