# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Attention mechanisms for WorldModel transformer."""

import math

import einops as eo
import torch
from torch import nn
from torch.nn.attention.flex_attention import flex_attention

from .nn import rms_norm, NoCastModule


def pixel_frequencies(dim: int, max_freq: float) -> torch.Tensor:
    """Linear frequency spectrum for spatial RoPE (pixel positions).

    Matches rotary_embedding_torch RotaryEmbedding(freqs_for='pixel').

    Args:
        dim: Output dimension (freqs will be repeated to fill this)
        max_freq: Maximum frequency (should be below Nyquist)

    Returns:
        Tensor of shape [dim // 2] with linear frequencies
    """
    # Library uses max_freq/2 as the upper bound
    return torch.linspace(1.0, max_freq / 2, dim // 2) * math.pi
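

# Illustrative example (assuming head_dim = 128): OrthoRoPE._compute_freqs calls
# pixel_frequencies(head_dim // 8, max_freq), which here yields 128 // 16 = 8
# linearly spaced angular frequencies in [pi, max_freq * pi / 2]; each value is
# later repeated twice to cover head_dim // 8 channels per spatial axis.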


def lang_frequencies(dim: int) -> torch.Tensor:
    """Geometric frequency spectrum for temporal RoPE (language-style).

    Matches rotary_embedding_torch RotaryEmbedding(freqs_for='lang').

    Args:
        dim: Output dimension (freqs will be repeated to fill this)

    Returns:
        Tensor of shape [dim // 2] with geometric frequencies
    """
    # Library uses 10^(-i/2) pattern
    return 10.0 ** (-torch.arange(dim // 2).float() / 2)
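

# Illustrative values (not asserted by the original code): lang_frequencies(8)
# equals 10 ** (-torch.arange(4).float() / 2), roughly
# tensor([1.0000, 0.3162, 0.1000, 0.0316]) -- a geometric decay analogous to the
# usual language-model RoPE schedule.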


class OrthoRoPE(NoCastModule):
    """Rotary Position Embeddings for orthogonal axes: time, height, and width.

    - Time: Geometric spectrum (like language models) -- rotates 1/2 of head dim
    - Height/Width: Linear spectrum (for pixels) -- rotates 1/4 of head dim each
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

        # Compute frequencies and store cos/sin buffers
        freqs = self._compute_freqs()
        self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False)
        self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False)

    def _compute_freqs(self):
        """Compute frequency table for all positions.

        Matches the behavior of rotary_embedding_torch.RotaryEmbedding.
        The library interleaves frequencies so each freq value is used twice.
        """
        config = self.config
        H, W, T = config.height, config.width, config.n_frames
        head_dim = config.d_model // config.n_heads

        # Spatial frequencies (linear spectrum, below Nyquist)
        # Library: RotaryEmbedding(dim=head_dim//8) creates head_dim//16 freqs,
        # outputs head_dim//8 values (each freq repeated twice)
        max_freq = min(H, W) * 0.8
        spatial_freqs = pixel_frequencies(head_dim // 8, max_freq)  # [D/16]

        # Positions in [-1, 1] range
        pos_x = torch.linspace(-1 + 1 / W, 1 - 1 / W, W)  # [W]
        pos_y = torch.linspace(-1 + 1 / H, 1 - 1 / H, H)  # [H]

        # Spatial frequency embeddings with interleaving (like library)
        freqs_x = torch.outer(pos_x, spatial_freqs)  # [W, D/16]
        freqs_y = torch.outer(pos_y, spatial_freqs)  # [H, D/16]
        freqs_x = freqs_x.repeat_interleave(2, dim=-1)  # [W, D/8]
        freqs_y = freqs_y.repeat_interleave(2, dim=-1)  # [H, D/8]

        # Expand to grid and repeat for all frames
        freqs_x = freqs_x[None, :, :].expand(H, W, -1)  # [H, W, D/8]
        freqs_y = freqs_y[:, None, :].expand(H, W, -1)  # [H, W, D/8]
        freqs_x = eo.repeat(freqs_x, "h w d -> (t h w) d", t=T)  # [T*H*W, D/8]
        freqs_y = eo.repeat(freqs_y, "h w d -> (t h w) d", t=T)  # [T*H*W, D/8]

        # Temporal frequencies (geometric spectrum)
        # Library: RotaryEmbedding(dim=head_dim//4) creates head_dim//8 freqs,
        # outputs head_dim//4 values (each freq repeated twice)
        temporal_freqs = lang_frequencies(head_dim // 4)  # [D/8]
        pos_t = torch.arange(T).float()  # [T]
        freqs_t = torch.outer(pos_t, temporal_freqs)  # [T, D/8]
        freqs_t = freqs_t.repeat_interleave(2, dim=-1)  # [T, D/4]
        freqs_t = eo.repeat(freqs_t, "t d -> (t h w) d", h=H, w=W)  # [T*H*W, D/4]

        # Concatenate: [X, Y, T] -> [T*H*W, D/2]
        return torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)
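
    # Layout of the returned table (derived from the code above, head_dim = D):
    # columns [0, D/8) hold width angles, [D/8, D/4) height angles, and
    # [D/4, D/2) time angles. In forward(), each angle rotates one (even, odd)
    # channel pair, so width and height each rotate a quarter of the head dim
    # and time rotates the remaining half.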

    def get_angles(self, pos_ids):
        """Look up cos/sin angles for given position IDs."""
        t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"]  # [B,T]
        H, W = self.config.height, self.config.width
        if not torch.compiler.is_compiling():
            torch._assert(
                (y.max() < H) & (x.max() < W),
                f"pos_ids out of bounds, {y.max()}, {x.max()}",
            )
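        # Flatten (t, y, x) into a row index of the precomputed table, which is
        # laid out frame-major then row-major; e.g. with H = W = 16 (illustrative),
        # position (t=2, y=3, x=5) maps to 2 * 256 + 3 * 16 + 5 = 565.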
        flat = t * (H * W) + y * W + x  # [B,T]
        idx = flat.reshape(-1).to(torch.long)
        cos = self.cos.index_select(0, idx).view(*flat.shape, -1)
        sin = self.sin.index_select(0, idx).view(*flat.shape, -1)
        return cos[:, None], sin[:, None]  # add head dim for broadcast

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, pos_ids):
        assert self.cos.dtype == self.sin.dtype == torch.float32
        cos, sin = self.get_angles(pos_ids)
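        # Rotate consecutive channel pairs by the looked-up angles:
        # (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin). The two rotated halves
        # are concatenated rather than re-interleaved; that is a fixed channel
        # permutation, and since q and k pass through the same transform their
        # dot products match the standard interleaved RoPE layout.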
        x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
        y0 = x0 * cos - x1 * sin
        y1 = x1 * cos + x0 * sin
        return torch.cat((y0, y1), dim=-1).type_as(x)


class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads)
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0
        self.enable_gqa = self.n_heads != self.n_kv_heads

        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.rope = OrthoRoPE(config)

        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(
                self.n_heads, self.n_heads, bias=False
            )  # sparse attn gate
            nn.init.zeros_(self.gate_proj.weight)
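
    # Illustrative configuration (hypothetical values, only to show the knobs read
    # above): d_model=1024, n_heads=16, n_kv_heads=4 gives GQA with 4 query heads
    # per KV head; value_residual=True blends each layer's values toward the values
    # carried in v1 via a learnable lerp; gated_attn=True adds per-head sigmoid
    # gates whose zero-initialized projection yields neutral 0.5 gates at init.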

    def forward(self, x, pos_ids, v1, kv_cache):
        # Q, K, V proj -> QK-norm -> RoPE
        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
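        # Value residual: when v1 is None the current values are captured and passed
        # back to the caller; otherwise v is blended toward v1 with a learnable
        # weight (torch.lerp(v, v1, 0.5) == 0.5 * v + 0.5 * v1 at initialization).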
        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
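        # Gated attention: per-head sigmoid gates computed from the first n_heads
        # input features; the zero-initialized gate_proj gives 0.5 gates at init.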
        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1


class MergedQKVAttn(Attn):
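    """Attn variant whose q/k/v projections are fused into a single linear layer.

    Built from an existing Attn: the constructor copies the source module's
    weights and buffers, concatenates q_proj/k_proj/v_proj into one qkv_proj,
    and deletes the separate projections. Hypothetical usage (names assumed):
    ``layer.attn = MergedQKVAttn(layer.attn, config)``.
    """
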
    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)  # makes fresh q/k/v/out/etc
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        self.load_state_dict(
            src.state_dict(), strict=False
        )  # copies trained weights/buffers
        self.train(src.training)  # preserve train/eval mode

        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head
        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )
        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, v1, kv_cache):
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1


class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning."""

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head

        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.v_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        self.out_proj.weight.detach().zero_()
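
    # out_proj starts at zero, so the module's output is zero at initialization and
    # any residual connection in the caller is unaffected at the start of training.
    # Note that context_pad_mask is accepted below but not applied in this forward.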

    def forward(self, x, context, context_pad_mask=None):
        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        q, k = rms_norm(q), rms_norm(k)
        out = flex_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)