from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

import torch
from torch import nn
import torch.nn.functional as F


class RotaryEmbedding(nn.Module):
    """Rotary position embedding using the half-split (non-interleaved) convention."""

    def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE dimension must be even")
        half_dim = head_dim // 2
        # Geometric progression of timescales from min_timescale to max_timescale.
        fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
        timescale = min_timescale * (max_timescale / min_timescale) ** fraction
        inv_freq = 1.0 / timescale
        self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        pos = position_ids.to(self.inv_freq.dtype)
        freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Broadcast the angle table across any head dimensions of x.
        while emb.dim() < x.dim():
            emb = emb.unsqueeze(-2)
        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)
        # Half-split rotation: feature i is paired with feature i + head_dim // 2.
        x1, x2 = torch.chunk(x, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # NOTE: this helper uses the interleaved (even/odd) RoPE convention and is
    # currently unused; RotaryEmbedding.forward applies the half-split
    # convention inline instead. The two conventions are not interchangeable.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)


def _get_activation(name: str) -> nn.Module:
    name = name.lower()
    if name in ("silu", "swish", "swiglu"):
        return nn.SiLU()
    if name in ("gelu", "geglu"):
        return nn.GELU()
    if name == "relu":
        return nn.ReLU()
    if name == "linear":
        return nn.Identity()
    raise ValueError(f"Unsupported activation {name}")


@dataclass
class AttentionShape:
    """Hyperparameter bundle for an attention layer."""

    dim: int
    heads: int
    kv_heads: int
    head_dim: int
    rope_min: int
    rope_max: int
    apply_rope: bool


class Attention(nn.Module):
    """Byte-for-byte port of dia_v2 Attention.forward_incremental."""

    # DiaConfig is defined elsewhere in the package; with
    # `from __future__ import annotations` it is only needed as a type hint here.
    def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dec = config.model.decoder
        self.num_query_heads = dec.gqa_query_heads
        self.num_kv_heads = dec.kv_heads
        self.head_dim = dec.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.compute_dtype = compute_dtype
        self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
        eps = config.model.normalization_layer_epsilon
        # Per-head RMSNorm on q and k, kept in float32 for numerical stability.
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.rotary = RotaryEmbedding(
            self.head_dim,
            config.model.rope_min_timescale,
            config.model.rope_max_timescale,
        )

    def forward_incremental(
        self,
        x: torch.Tensor,
        pos: Optional[torch.Tensor],
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        B, T, _ = x.shape
        if T != 1:
            raise ValueError("Attention expects sequence length 1 during decoding")
        orig_dtype = x.dtype
        q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
        k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
        v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
        q_proj = self.q_norm(q_proj)
        k_proj = self.k_norm(k_proj)
        if pos is not None:
            q_proj = self.rotary(q_proj, pos)
            k_proj = self.rotary(k_proj, pos)
        # (B, T, H, D) -> (B, H, T, D) for scaled_dot_product_attention.
        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)
        if cache_slot is not None:
            # Append the new k/v step to the cache and get back the full views.
            k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
        else:
            k_cache, v_cache = k, v
            attn_mask = None
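        # Two port-specific choices in the SDPA call below: scale=1.0 skips the
        # default 1 / sqrt(head_dim) scaling, matching the upstream dia_v2
        # convention (q and k are RMS-normalized above), and enable_gqa
        # (PyTorch >= 2.5) broadcasts the smaller set of k/v heads across the
        # query-head groups instead of materializing repeated copies.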
        attn = F.scaled_dot_product_attention(
            q,
            k_cache,
            v_cache,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        # (B, H, T, D) -> (B, T, H * D), then project back to model dim in float32.
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
        out = self.o_proj(flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot

    def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
        # Run the projection in float32, then split into heads and cast down.
        proj = layer(x.to(torch.float32))
        B, T, _ = proj.shape
        proj = proj.view(B, T, heads, self.head_dim)
        return proj.to(self.compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        positions: Optional[torch.Tensor],
        cache=None,
    ) -> Tuple[torch.Tensor, object]:
        return self.forward_incremental(x, positions, cache)


class MultiStreamEmbedding(nn.Module):
    """Port of dia_v2 MultiStreamEmbed."""

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        pad_id: int,
        *,
        output_dtype: torch.dtype,
        low_rank_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.pad_id = pad_id
        self.dtype = output_dtype
        # Optionally factor the embedding through a lower-rank space.
        base_dim = low_rank_dim if low_rank_dim is not None else dim
        self.embedding = nn.Embedding(vocab_size, base_dim)
        self.main_proj = nn.Linear(base_dim, dim, bias=False)
        self.second_proj = nn.Linear(base_dim, dim, bias=False)

    def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
        main_inputs = main_inputs.long()
        second_inputs = second_inputs.long()
        # The second stream only contributes where it is not padding.
        if self.pad_id is not None:
            second_is_pad = second_inputs == self.pad_id
        else:
            second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
        use_second = ~second_is_pad
        emb_main = self.embedding(main_inputs)
        emb_second = self.embedding(second_inputs)
        out_main = self.main_proj(emb_main.to(torch.float32))
        out_second = self.second_proj(emb_second.to(torch.float32))
        zeros = torch.zeros_like(out_second)
        y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
        target_dtype = self.dtype if self.dtype is not None else y.dtype
        return y.to(target_dtype)


class Mlp(nn.Module):
    """Port of dia_v2 MlpBlock (two-activation gated MLP)."""

    def __init__(
        self,
        dim: int,
        hidden: int,
        compute_dtype: torch.dtype,
        activations: Sequence[str],
    ) -> None:
        super().__init__()
        if len(activations) != 2:
            raise ValueError("Mlp expects two activation functions.")
        self.dtype = compute_dtype
        self.hidden = hidden
        self.branch_count = len(activations)
        # Single fused input projection; its output is split into gate and up branches.
        self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
        self.wo = nn.Linear(hidden, dim, bias=False)
        # ModuleList (rather than a plain list) so the activations are registered
        # as submodules; they carry no parameters, so numerics are unchanged.
        self.activation_fns = nn.ModuleList(
            [_get_activation(activations[0]), _get_activation(activations[1])]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        proj = self.wi(x.to(torch.float32))
        proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
        gate, up = proj.unbind(dim=-2)
        hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
        out = self.wo(hidden.to(torch.float32))
        return out.to(self.dtype)
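

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the dia_v2 port.
    # All shapes and hyperparameters below are arbitrary assumptions; Attention
    # is skipped here because it requires a DiaConfig instance.
    torch.manual_seed(0)

    rope = RotaryEmbedding(head_dim=64, min_timescale=1, max_timescale=10_000)
    x = torch.randn(2, 8, 4, 64)                  # (batch, seq, heads, head_dim)
    pos = torch.arange(8).expand(2, -1)           # (batch, seq)
    assert rope(x, pos).shape == x.shape

    embed = MultiStreamEmbedding(vocab_size=256, dim=128, pad_id=0, output_dtype=torch.float32)
    main = torch.randint(1, 256, (2, 8))
    second = torch.zeros(2, 8, dtype=torch.long)  # all pads: second stream is ignored
    assert embed(main, second).shape == (2, 8, 128)

    mlp = Mlp(dim=128, hidden=512, compute_dtype=torch.float32, activations=("silu", "linear"))
    assert mlp(torch.randn(2, 8, 128)).shape == (2, 8, 128)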