"""Attention mechanisms for WorldModel transformer."""

import math

import einops as eo
import torch
from torch import nn
from torch.nn.attention.flex_attention import flex_attention

from .nn import rms_norm, NoCastModule


def pixel_frequencies(dim: int, max_freq: float) -> torch.Tensor:
    """Linear frequency spectrum for spatial RoPE (pixel positions).

    Matches rotary_embedding_torch RotaryEmbedding(freqs_for='pixel').

    Args:
        dim: Output dimension (freqs will be repeated to fill this)
        max_freq: Maximum frequency (should be below Nyquist)

    Returns:
        Tensor of shape [dim // 2] with linear frequencies
    """
    return torch.linspace(1.0, max_freq / 2, dim // 2) * math.pi


def lang_frequencies(dim: int) -> torch.Tensor:
    """Geometric frequency spectrum for temporal RoPE (language-style).

    Matches rotary_embedding_torch RotaryEmbedding(freqs_for='lang').

    Args:
        dim: Output dimension (freqs will be repeated to fill this)

    Returns:
        Tensor of shape [dim // 2] with geometric frequencies
    """
    return 10.0 ** (-torch.arange(dim // 2).float() / 2)
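

# --- Illustrative sketch (not part of the module API) -------------------------
# Contrast of the two spectra above. The head-dim split mirrors
# OrthoRoPE._compute_freqs below; the concrete numbers (head_dim=64,
# max_freq=20.0) are assumptions chosen only for illustration.
def _demo_frequencies():
    head_dim = 64
    spatial = pixel_frequencies(head_dim // 8, max_freq=20.0)
    temporal = lang_frequencies(head_dim // 4)
    # Spatial bands are linearly spaced from pi to (max_freq / 2) * pi.
    assert spatial.shape == (head_dim // 16,)
    # Temporal bands decay geometrically: 1, 10**-0.5, 10**-1, ...
    assert temporal.shape == (head_dim // 8,)
    return spatial, temporal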


class OrthoRoPE(NoCastModule):
    """Rotary Position Embeddings for orthogonal axes: time, height, and width.

    - Time: Geometric spectrum (like language models) -- rotates 1/2 of head dim
    - Height/Width: Linear spectrum (for pixels) -- rotates 1/4 of head dim each
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

        freqs = self._compute_freqs()
        self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False)
        self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False)

    def _compute_freqs(self):
        """Compute the frequency table for all positions.

        Matches the behavior of rotary_embedding_torch.RotaryEmbedding.
        The library interleaves frequencies so each freq value is used twice.
        """
        config = self.config
        H, W, T = config.height, config.width, config.n_frames
        head_dim = config.d_model // config.n_heads

        # Spatial spectrum: capped below the Nyquist rate of the smaller
        # spatial extent. Each axis gets head_dim // 16 base frequencies.
        max_freq = min(H, W) * 0.8
        spatial_freqs = pixel_frequencies(head_dim // 8, max_freq)

        # Pixel positions normalized to (-1, 1), offset to cell centers.
        pos_x = torch.linspace(-1 + 1 / W, 1 - 1 / W, W)
        pos_y = torch.linspace(-1 + 1 / H, 1 - 1 / H, H)

        # Per-position angles; interleave so each frequency drives one
        # (cos, sin) rotation pair.
        freqs_x = torch.outer(pos_x, spatial_freqs)
        freqs_y = torch.outer(pos_y, spatial_freqs)
        freqs_x = freqs_x.repeat_interleave(2, dim=-1)
        freqs_y = freqs_y.repeat_interleave(2, dim=-1)

        # Broadcast each axis over the full H x W grid, then tile over time.
        freqs_x = freqs_x[None, :, :].expand(H, W, -1)
        freqs_y = freqs_y[:, None, :].expand(H, W, -1)

        freqs_x = eo.repeat(freqs_x, "h w d -> (t h w) d", t=T)
        freqs_y = eo.repeat(freqs_y, "h w d -> (t h w) d", t=T)

        # Temporal spectrum: geometric, tiled over all spatial positions.
        temporal_freqs = lang_frequencies(head_dim // 4)
        pos_t = torch.arange(T).float()
        freqs_t = torch.outer(pos_t, temporal_freqs)
        freqs_t = freqs_t.repeat_interleave(2, dim=-1)
        freqs_t = eo.repeat(freqs_t, "t d -> (t h w) d", h=H, w=W)

        # Final table: [T * H * W, head_dim // 2] angle slots, laid out x | y | t.
        return torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)

    def get_angles(self, pos_ids):
        """Look up cos/sin angles for given position IDs."""
        t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"]
        H, W = self.config.height, self.config.width
        if not torch.compiler.is_compiling():
            torch._assert(
                (y.max() < H) & (x.max() < W),
                f"pos_ids out of bounds, {y.max()}, {x.max()}",
            )
        # Flatten (t, y, x) into row indices of the precomputed table.
        flat = t * (H * W) + y * W + x
        idx = flat.reshape(-1).to(torch.long)
        cos = self.cos.index_select(0, idx).view(*flat.shape, -1)
        sin = self.sin.index_select(0, idx).view(*flat.shape, -1)
        # Singleton head axis so angles broadcast over attention heads.
        return cos[:, None], sin[:, None]

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, pos_ids):
        assert self.cos.dtype == self.sin.dtype == torch.float32
        cos, sin = self.get_angles(pos_ids)
        # Rotate adjacent (even, odd) pairs of the head dim. Concatenating
        # y0 and y1 permutes the pair layout, but q and k are permuted
        # identically, so their dot products are unchanged.
        x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
        y0 = x0 * cos - x1 * sin
        y1 = x1 * cos + x0 * sin
        return torch.cat((y0, y1), dim=-1).type_as(x)
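

# --- Illustrative sketch (not part of the module API) -------------------------
# End-to-end check of OrthoRoPE with a stand-in config carrying only the
# attributes the class reads; all sizes are arbitrary. Pair rotations are
# norm-preserving, which the final check spot-verifies.
def _demo_ortho_rope():
    from types import SimpleNamespace

    cfg = SimpleNamespace(height=4, width=6, n_frames=3, d_model=128, n_heads=2)
    rope = OrthoRoPE(cfg)
    head_dim = cfg.d_model // cfg.n_heads
    # One row of angle slots per flattened (t, y, x) position.
    assert rope.cos.shape == (cfg.n_frames * cfg.height * cfg.width, head_dim // 2)

    b, s = 2, cfg.height * cfg.width  # one frame of tokens
    pos_ids = {
        "t_pos": torch.zeros(b, s, dtype=torch.long),
        "y_pos": torch.arange(s).div(cfg.width, rounding_mode="floor").expand(b, s),
        "x_pos": torch.arange(s).remainder(cfg.width).expand(b, s),
    }
    x = torch.randn(b, cfg.n_heads, s, head_dim)
    y = rope(x, pos_ids)
    torch.testing.assert_close(y.norm(dim=-1), x.norm(dim=-1), rtol=1e-4, atol=1e-4)
    return y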


class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Value residual: later layers blend their values with the first
        # layer's values (v1), weighted by a learned lambda.
        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads)
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0

        self.enable_gqa = self.n_heads != self.n_kv_heads

        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(config.d_model, self.n_kv_heads * self.d_head, bias=False)
        self.v_proj = nn.Linear(config.d_model, self.n_kv_heads * self.d_head, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.rope = OrthoRoPE(config)

        # Per-head sigmoid output gates; zero-init means gates start at 0.5.
        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False)
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, v1, kv_cache):
        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )

        # The first layer passes v1=None and stashes its values; later layers
        # interpolate toward them.
        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        # QK-norm, then rotary embedding on the normalized projections.
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            # Gates come from the first n_heads input channels and scale each
            # head's output before the final projection.
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1
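

# --- Illustrative sketch (not part of the module API) -------------------------
# Driving Attn for a single dense forward pass. The real KV-cache type lives
# elsewhere; this no-op stand-in mimics only the upsert() signature used
# above (returning k, v and a block mask, where None means dense attention).
def _demo_attn():
    from types import SimpleNamespace

    class _NoOpKVCache:
        def upsert(self, k, v, pos_ids, layer_idx):
            return k, v, None

    cfg = SimpleNamespace(
        height=4, width=6, n_frames=3, d_model=128, n_heads=2, value_residual=True
    )
    attn = Attn(cfg, layer_idx=0)
    b, s = 2, cfg.height * cfg.width
    pos_ids = {
        "t_pos": torch.zeros(b, s, dtype=torch.long),
        "y_pos": torch.arange(s).div(cfg.width, rounding_mode="floor").expand(b, s),
        "x_pos": torch.arange(s).remainder(cfg.width).expand(b, s),
    }
    x = torch.randn(b, s, cfg.d_model)
    y, v1 = attn(x, pos_ids, v1=None, kv_cache=_NoOpKVCache())
    assert y.shape == x.shape and v1 is not None
    return y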


class MergedQKVAttn(Attn):
    """Attn with the q/k/v projections fused into a single matmul."""

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        # Copy all weights from the unfused source module before fusing.
        self.load_state_dict(src.state_dict(), strict=False)
        self.train(src.training)

        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head

        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        # Stack the three projection matrices row-wise so a single matmul
        # produces [q | k | v], then drop the separate projections.
        with torch.no_grad():
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )

        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, v1, kv_cache):
        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)

        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1
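

# --- Illustrative sketch (not part of the module API) -------------------------
# Fusing an Attn into MergedQKVAttn should leave outputs unchanged: the fused
# weight is just the three projection matrices stacked row-wise. Uses the
# same no-op KV-cache stand-in as _demo_attn.
def _demo_merged_qkv():
    from types import SimpleNamespace

    class _NoOpKVCache:
        def upsert(self, k, v, pos_ids, layer_idx):
            return k, v, None

    cfg = SimpleNamespace(height=4, width=6, n_frames=3, d_model=128, n_heads=2)
    src = Attn(cfg, layer_idx=0).eval()
    fused = MergedQKVAttn(src, cfg)

    b, s = 1, cfg.height * cfg.width
    pos_ids = {
        "t_pos": torch.zeros(b, s, dtype=torch.long),
        "y_pos": torch.arange(s).div(cfg.width, rounding_mode="floor").expand(b, s),
        "x_pos": torch.arange(s).remainder(cfg.width).expand(b, s),
    }
    x = torch.randn(b, s, cfg.d_model)
    with torch.no_grad():
        y_src, _ = src(x, pos_ids, v1=None, kv_cache=_NoOpKVCache())
        y_fused, _ = fused(x, pos_ids, v1=None, kv_cache=_NoOpKVCache())
    torch.testing.assert_close(y_src, y_fused, rtol=1e-4, atol=1e-4)
    return y_fused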


class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning."""

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        # Head count scales with the context width so d_head matches self-attn.
        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(context_dim or config.d_model, self.inner_dim, bias=False)
        self.v_proj = nn.Linear(context_dim or config.d_model, self.inner_dim, bias=False)

        # Zero-init so the module initially contributes nothing to the residual.
        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        self.out_proj.weight.detach().zero_()

    def forward(self, x, context, context_pad_mask=None):
        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        q, k = rms_norm(q), rms_norm(k)

        # Mask out padded context tokens (assumes True marks padding).
        score_mod = None
        if context_pad_mask is not None:
            def score_mod(score, b, h, q_idx, kv_idx):
                return torch.where(context_pad_mask[b, kv_idx], -float("inf"), score)

        out = flex_attention(q, k, v, score_mod=score_mod)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)
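

# --- Illustrative sketch (not part of the module API) -------------------------
# Prompt conditioning with CrossAttention: queries come from d_model tokens,
# keys/values from a wider prompt encoding. Because out_proj is zero-
# initialized, a freshly constructed module returns exactly zeros, which the
# final check verifies.
def _demo_cross_attention():
    from types import SimpleNamespace

    cfg = SimpleNamespace(d_model=128, n_heads=2)
    xattn = CrossAttention(cfg, context_dim=256)
    x = torch.randn(2, 24, cfg.d_model)        # image/video tokens
    context = torch.randn(2, 7, 256)           # prompt embedding tokens
    pad = torch.zeros(2, 7, dtype=torch.bool)  # no padding in this sketch
    out = xattn(x, context, context_pad_mask=pad)
    assert out.shape == x.shape
    torch.testing.assert_close(out, torch.zeros_like(out))
    return out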