from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE dimension must be even")
        half_dim = head_dim // 2
        fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
        timescale = min_timescale * (max_timescale / min_timescale) ** fraction
        inv_freq = 1.0 / timescale
        self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        pos = position_ids.to(self.inv_freq.dtype)
        freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Broadcast over the head dimension(s): x is (..., seq, heads, head_dim).
        while emb.dim() < x.dim():
            emb = emb.unsqueeze(-2)
        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)
        # Half-split rotation: the first and second halves of head_dim are paired.
        x1, x2 = torch.chunk(x, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)
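
# Minimal usage sketch (illustrative only, not part of the port): `forward` expects
# x shaped (..., seq, heads, head_dim) and integer positions broadcastable over the
# leading dims, e.g. (batch, seq). The sizes below are arbitrary.
#
#   rope = RotaryEmbedding(head_dim=64, min_timescale=1, max_timescale=10000)
#   q = torch.randn(2, 16, 8, 64)                            # (batch, seq, heads, head_dim)
#   positions = torch.arange(16).unsqueeze(0).expand(2, -1)  # (batch, seq)
#   q_rot = rope(q, positions)                               # same shape as q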

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Interleaved-pair variant of the rotation (pairs adjacent dims instead of
    # splitting head_dim into halves as RotaryEmbedding.forward does).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)

def _get_activation(name: str) -> nn.Module:
    name = name.lower()
    if name in ("silu", "swish", "swiglu"):
        return nn.SiLU()
    if name in ("gelu", "geglu"):
        return nn.GELU()
    if name == "relu":
        return nn.ReLU()
    if name == "linear":
        return nn.Identity()
    raise ValueError(f"Unsupported activation {name}")

@dataclass
class AttentionShape:
    dim: int
    heads: int
    kv_heads: int
    head_dim: int
    rope_min: int
    rope_max: int
    apply_rope: bool

class Attention(nn.Module):
    """Byte-for-byte port of dia_v2 Attention.forward_incremental."""

    def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dec = config.model.decoder
        self.num_query_heads = dec.gqa_query_heads
        self.num_kv_heads = dec.kv_heads
        self.head_dim = dec.gqa_head_dim
        # Number of query heads sharing each KV head (grouped-query attention).
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.compute_dtype = compute_dtype
        self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.rotary = RotaryEmbedding(
            self.head_dim,
            config.model.rope_min_timescale,
            config.model.rope_max_timescale,
        )

    def forward_incremental(
        self,
        x: torch.Tensor,
        pos: Optional[torch.Tensor],
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        B, T, _ = x.shape
        if T != 1:
            raise ValueError("Attention expects sequence length 1 during decoding")
        orig_dtype = x.dtype
        q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
        k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
        v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
        q_proj = self.q_norm(q_proj)
        k_proj = self.k_norm(k_proj)
        if pos is not None:
            q_proj = self.rotary(q_proj, pos)
            k_proj = self.rotary(k_proj, pos)
        # (B, T, heads, head_dim) -> (B, heads, T, head_dim) for SDPA.
        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)
        if cache_slot is not None:
            k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
        else:
            k_cache, v_cache = k, v
            attn_mask = None
        attn = F.scaled_dot_product_attention(
            q,
            k_cache,
            v_cache,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
        out = self.o_proj(flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot

    def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
        # Run the projection in float32, reshape to (B, T, heads, head_dim),
        # then cast back to the compute dtype.
        proj = layer(x.to(torch.float32))
        B, T, _ = proj.shape
        proj = proj.view(B, T, heads, self.head_dim)
        return proj.to(self.compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        positions: Optional[torch.Tensor],
        cache=None,
    ) -> Tuple[torch.Tensor, object]:
        return self.forward_incremental(x, positions, cache)
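
# Decode-step usage sketch (illustrative only; the DiaConfig attributes below are
# inferred from the fields read in __init__, the sizes are arbitrary, and a recent
# PyTorch with nn.RMSNorm and SDPA's enable_gqa is assumed):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(model=SimpleNamespace(
#       decoder=SimpleNamespace(gqa_query_heads=8, kv_heads=2, gqa_head_dim=64),
#       normalization_layer_epsilon=1e-5,
#       rope_min_timescale=1,
#       rope_max_timescale=10000,
#   ))
#   attn = Attention(cfg, dim=512, compute_dtype=torch.float32)
#   x = torch.randn(1, 1, 512)          # a single decode step (T == 1)
#   pos = torch.tensor([[5]])           # current position id
#   out, _ = attn(x, pos, cache=None)   # out: (1, 1, 512)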

class MultiStreamEmbedding(nn.Module):
    """Port of dia_v2 MultiStreamEmbed."""

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        pad_id: int,
        *,
        output_dtype: torch.dtype,
        low_rank_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.pad_id = pad_id
        self.dtype = output_dtype
        base_dim = low_rank_dim if low_rank_dim is not None else dim
        self.embedding = nn.Embedding(vocab_size, base_dim)
        self.main_proj = nn.Linear(base_dim, dim, bias=False)
        self.second_proj = nn.Linear(base_dim, dim, bias=False)

    def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
        main_inputs = main_inputs.long()
        second_inputs = second_inputs.long()
        if self.pad_id is not None:
            second_is_pad = second_inputs == self.pad_id
        else:
            second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
        use_second = ~second_is_pad
        emb_main = self.embedding(main_inputs)
        emb_second = self.embedding(second_inputs)
        out_main = self.main_proj(emb_main.to(torch.float32))
        out_second = self.second_proj(emb_second.to(torch.float32))
        zeros = torch.zeros_like(out_second)
        # Add the second-stream contribution only where it is not padding.
        y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
        target_dtype = self.dtype if self.dtype is not None else y.dtype
        return y.to(target_dtype)
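
# Usage sketch (illustrative; vocab size, dim, and token ids are arbitrary): positions
# whose second-stream token equals pad_id contribute only the main-stream embedding.
#
#   embed = MultiStreamEmbedding(vocab_size=1024, dim=512, pad_id=0, output_dtype=torch.float32)
#   main = torch.tensor([[3, 7, 9]])
#   second = torch.tensor([[5, 0, 2]])   # middle token is padding
#   y = embed(main, second)              # (1, 3, 512)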

class Mlp(nn.Module):
    """Port of dia_v2 MlpBlock (two-activation gated MLP)."""

    def __init__(
        self,
        dim: int,
        hidden: int,
        compute_dtype: torch.dtype,
        activations: Sequence[str],
    ) -> None:
        super().__init__()
        if len(activations) != 2:
            raise ValueError("Mlp expects two activation functions.")
        self.dtype = compute_dtype
        self.hidden = hidden
        self.branch_count = len(activations)
        # Single fused projection producing both branches of the gated MLP.
        self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
        self.wo = nn.Linear(hidden, dim, bias=False)
        self.activation_fns = nn.ModuleList(
            [_get_activation(activations[0]), _get_activation(activations[1])]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        proj = self.wi(x.to(torch.float32))
        proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
        gate, up = proj.unbind(dim=-2)
        hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
        out = self.wo(hidden.to(torch.float32))
        return out.to(self.dtype)
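
# Usage sketch (illustrative; sizes and the activation pair are arbitrary): with
# ("silu", "linear") the block computes SiLU(gate) * up before wo, i.e. a
# SwiGLU-style gated MLP.
#
#   mlp = Mlp(dim=512, hidden=2048, compute_dtype=torch.float32, activations=("silu", "linear"))
#   y = mlp(torch.randn(2, 16, 512))     # (2, 16, 512)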