Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,702 Bytes
1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
class RotaryEmbedding(nn.Module):
    """Rotary position embedding using the "split halves" layout.

    The inverse frequencies are derived once from a geometric progression of
    timescales running from ``min_timescale`` towards ``max_timescale`` and
    kept in a non-persistent float32 buffer named ``inv_freq`` (so it is not
    part of the ``state_dict``).

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
        super().__init__()
        is_even = head_dim % 2 == 0
        if not is_even:
            raise ValueError("RoPE dimension must be even")
        pair_count = head_dim // 2
        # Exponent in [0, 1): 0, 2/head_dim, 4/head_dim, ...
        exponent = (2.0 * torch.arange(0, pair_count)) / head_dim
        span = max_timescale / min_timescale
        timescale = min_timescale * span ** exponent
        reciprocal = 1.0 / timescale
        self.register_buffer(
            "inv_freq",
            reciprocal.to(torch.float32),
            persistent=False,
        )

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` by the angles implied by ``position_ids``.

        The angle table has shape ``position_ids.shape + (head_dim,)``; singleton
        axes are inserted just before its last axis until it has as many
        dimensions as ``x`` so that it broadcasts (e.g. over a heads axis).
        Sine/cosine are evaluated in the buffer's dtype and then cast to
        ``x.dtype``. The result has the same dtype as ``x``.
        """
        positions = position_ids.to(self.inv_freq.dtype)
        # Outer product: every position multiplied by every inverse frequency.
        angles = torch.einsum("...i,j->...ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        for _ in range(max(x.dim() - table.dim(), 0)):
            table = table.unsqueeze(-2)
        cos_part = table.cos().to(x.dtype)
        sin_part = table.sin().to(x.dtype)
        # Pair element i of the first half with element i of the second half.
        lower, upper = torch.chunk(x, 2, dim=-1)
        swapped = torch.cat((-upper, lower), dim=-1)
        return (x * cos_part) + (swapped * sin_part)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent (even, odd) pairs along the last axis.

    Each pair ``(a, b)`` taken from positions ``(2i, 2i + 1)`` becomes
    ``(-b, a)``; the result has the same shape as ``x``.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Stacking on a new trailing axis and flattening re-interleaves the pairs.
    pairs = torch.stack((-odds, evens), dim=-1)
    return pairs.reshape_as(x)
def _get_activation(name: str) -> nn.Module:
    """Build a fresh activation module for a case-insensitive ``name``.

    Accepted names: ``silu``/``swish``/``swiglu`` -> ``nn.SiLU``,
    ``gelu``/``geglu`` -> ``nn.GELU``, ``relu`` -> ``nn.ReLU``,
    ``linear`` -> ``nn.Identity``.

    Raises:
        ValueError: if the (lower-cased) name is not one of the above.
    """
    key = name.lower()
    factories = {
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "swiglu": nn.SiLU,
        "gelu": nn.GELU,
        "geglu": nn.GELU,
        "relu": nn.ReLU,
        "linear": nn.Identity,
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unsupported activation {key}")
    return factory()
@dataclass
class AttentionShape:
    """Plain record bundling attention sizes and rotary-embedding settings.

    NOTE(review): nothing in the visible part of this file constructs or reads
    this class, so the per-field notes below are inferred from the field names
    only -- confirm against the callers before relying on them.
    """

    dim: int  # presumably the model (input/output) width -- TODO confirm
    heads: int  # presumably the number of query heads -- TODO confirm
    kv_heads: int  # presumably the number of key/value heads -- TODO confirm
    head_dim: int  # presumably the per-head feature size -- TODO confirm
    rope_min: int  # presumably the minimum RoPE timescale -- TODO confirm
    rope_max: int  # presumably the maximum RoPE timescale -- TODO confirm
    apply_rope: bool  # presumably toggles rotary embedding -- TODO confirm
class Attention(nn.Module):
    """Byte-for-byte port of dia_v2 Attention.forward_incremental.

    Single-step attention with separate query and key/value head counts.
    The linear projections run on float32 inputs; the per-head tensors are
    then cast to ``compute_dtype`` before normalisation, rotary embedding and
    the attention call.
    """

    def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
        """Read head counts, head size, norm epsilon and RoPE timescales from ``config.model``."""
        super().__init__()
        decoder_cfg = config.model.decoder
        self.num_query_heads = decoder_cfg.gqa_query_heads
        self.num_kv_heads = decoder_cfg.kv_heads
        self.head_dim = decoder_cfg.gqa_head_dim
        # max(..., 1) keeps the division defined even if kv_heads is 0.
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.compute_dtype = compute_dtype

        query_width = self.num_query_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(dim, query_width, bias=False)
        self.k_proj = nn.Linear(dim, kv_width, bias=False)
        self.v_proj = nn.Linear(dim, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, dim, bias=False)

        norm_eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=norm_eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=norm_eps, dtype=torch.float32)

        self.rotary = RotaryEmbedding(
            self.head_dim,
            config.model.rope_min_timescale,
            config.model.rope_max_timescale,
        )

    def forward_incremental(
        self,
        x: torch.Tensor,
        pos: Optional[torch.Tensor],
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        """Attend for exactly one new time step.

        Args:
            x: 3-D input whose second axis (sequence length) must be 1.
            pos: positions handed to the rotary embedding; when ``None`` the
                rotary embedding is skipped for both queries and keys.
            cache_slot: optional object exposing ``write_and_view(k, v)`` that
                returns ``(k_cache, v_cache, attn_mask)``. When ``None`` the
                current step's keys/values are used directly with no mask.

        Returns:
            ``(output, cache_slot)`` where ``output`` is cast back to the
            dtype of ``x`` and ``cache_slot`` is the object that was passed in.

        Raises:
            ValueError: if the sequence length of ``x`` is not 1.
        """
        batch, steps, _ = x.shape
        if steps != 1:
            raise ValueError("Attention expects sequence length 1 during decoding")
        input_dtype = x.dtype

        queries = self._project_heads(self.q_proj, x, self.num_query_heads)
        keys = self._project_heads(self.k_proj, x, self.num_kv_heads)
        values = self._project_heads(self.v_proj, x, self.num_kv_heads)

        # Per-head RMS normalisation is applied to queries and keys only.
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        if pos is not None:
            queries = self.rotary(queries, pos)
            keys = self.rotary(keys, pos)

        # (batch, step, heads, head_dim) -> (batch, heads, step, head_dim)
        q = queries.transpose(1, 2)
        k = keys.transpose(1, 2)
        v = values.transpose(1, 2)

        if cache_slot is None:
            k_all, v_all, mask = k, v, None
        else:
            k_all, v_all, mask = cache_slot.write_and_view(k, v)

        # scale=1.0 overrides the attention call's default score scaling.
        attended = F.scaled_dot_product_attention(
            q,
            k_all,
            v_all,
            attn_mask=mask,
            scale=1.0,
            enable_gqa=self.num_gqa_groups > 1,
        )

        merged = attended.transpose(1, 2).contiguous()
        merged = merged.reshape(batch, steps, self.num_query_heads * self.head_dim)
        projected = self.o_proj(merged.to(torch.float32))
        return projected.to(input_dtype), cache_slot

    def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
        """Apply ``layer`` to float32 ``x`` and split the last axis into ``heads`` x ``head_dim``.

        The result is cast to ``compute_dtype``.
        """
        projected = layer(x.to(torch.float32))
        batch, steps, _ = projected.shape
        per_head = projected.view(batch, steps, heads, self.head_dim)
        return per_head.to(self.compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        positions: Optional[torch.Tensor],
        cache=None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Delegate to ``forward_incremental`` and return its result unchanged."""
        result = self.forward_incremental(x, positions, cache)
        return result
class MultiStreamEmbedding(nn.Module):
    """Port of dia_v2 MultiStreamEmbed.

    Two token streams share one embedding table; each stream has its own
    bias-free linear projection to ``dim`` and the two projections are summed.
    Positions where the second stream holds ``pad_id`` contribute nothing
    from the second stream.
    """

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        pad_id: int,
        *,
        output_dtype: torch.dtype,
        low_rank_dim: Optional[int] = None,
    ) -> None:
        """Create the shared table and both projections.

        The table width is ``low_rank_dim`` when given, otherwise ``dim``.
        """
        super().__init__()
        self.pad_id = pad_id
        self.dtype = output_dtype
        table_width = dim if low_rank_dim is None else low_rank_dim
        self.embedding = nn.Embedding(vocab_size, table_width)
        self.main_proj = nn.Linear(table_width, dim, bias=False)
        self.second_proj = nn.Linear(table_width, dim, bias=False)

    def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
        """Embed both streams and return their (masked) sum.

        Inputs are cast to int64 before lookup and the projections are fed
        float32. The result is cast to ``self.dtype`` unless that is ``None``,
        in which case its dtype is left as computed.
        """
        main_ids = main_inputs.long()
        second_ids = second_inputs.long()

        # True wherever the second stream carries a real (non-pad) token.
        if self.pad_id is None:
            keep_second = torch.ones_like(second_ids, dtype=torch.bool)
        else:
            keep_second = second_ids != self.pad_id

        main_vecs = self.embedding(main_ids)
        second_vecs = self.embedding(second_ids)
        main_term = self.main_proj(main_vecs.to(torch.float32))
        second_term = self.second_proj(second_vecs.to(torch.float32))

        masked_second = torch.where(
            keep_second.unsqueeze(-1),
            second_term,
            torch.zeros_like(second_term),
        )
        combined = main_term + masked_second

        out_dtype = combined.dtype if self.dtype is None else self.dtype
        return combined.to(out_dtype)
class Mlp(nn.Module):
    """Port of dia_v2 MlpBlock (two-activation gated MLP).

    A single bias-free linear layer ``wi`` produces two branches of width
    ``hidden``; each branch goes through its own activation, the results are
    multiplied element-wise, and ``wo`` projects the product back to ``dim``.
    Both linear layers are fed float32; intermediate and final results are
    cast to ``compute_dtype``.
    """

    def __init__(
        self,
        dim: int,
        hidden: int,
        compute_dtype: torch.dtype,
        activations: Sequence[str],
    ) -> None:
        """Build the projections and the two activation modules.

        Args:
            dim: input and output feature size.
            hidden: width of each of the two branches.
            compute_dtype: dtype of the gated product and of the output.
            activations: exactly two activation names understood by
                ``_get_activation`` (first for the gate branch, second for
                the up branch).

        Raises:
            ValueError: if ``activations`` does not contain exactly two names,
                or if a name is not supported.
        """
        super().__init__()
        if len(activations) != 2:
            raise ValueError("Mlp expects two activation functions.")
        self.dtype = compute_dtype
        self.hidden = hidden
        self.branch_count = len(activations)
        self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
        self.wo = nn.Linear(hidden, dim, bias=False)
        # A ModuleList (rather than a plain list) registers the activations as
        # submodules so they show up in .modules()/.apply()/repr. They carry no
        # parameters or buffers, so state_dict keys are unchanged, and integer
        # indexing keeps working as before.
        self.activation_fns = nn.ModuleList(
            _get_activation(name) for name in activations
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP along the last axis of ``x``.

        Returns a tensor with the leading shape of ``x``, last axis ``dim``,
        and dtype ``compute_dtype``.
        """
        proj = self.wi(x.to(torch.float32))
        # Split the fused projection into (branch, hidden): branch 0 is the
        # gate, branch 1 is the up projection.
        proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
        gate, up = proj.unbind(dim=-2)
        gate_act, up_act = self.activation_fns[0], self.activation_fns[1]
        hidden = gate_act(gate) * up_act(up)
        out = self.wo(hidden.to(torch.float32))
        return out.to(self.dtype)
|