# Dia2-2B: dia2/core/depformer.py
from __future__ import annotations
from typing import Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from ..config import DiaConfig
from .cache import KVCache
from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
from .precision import Precision
class ScheduleAttention(nn.Module):
"""Depformer attention that mirrors dia_v2 ScheduleAttention."""
def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
super().__init__()
dep_cfg = config.model.depformer
runtime = config.runtime
self.schedule = runtime.weights_schedule
self.num_query_heads = dep_cfg.gqa_query_heads
self.num_kv_heads = dep_cfg.kv_heads
self.head_dim = dep_cfg.gqa_head_dim
self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
self.apply_rope = dep_cfg.apply_rope
self.used_ids = sorted(set(self.schedule))
self.compute_dtype = compute_dtype
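        # One fused QKV projection per distinct weight id in the schedule;
        # stages mapped to the same id share these parameters.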
self.in_proj = nn.ModuleDict(
{
str(i): nn.Linear(
dep_cfg.n_embd,
3 * self.num_query_heads * self.head_dim,
bias=False,
)
for i in self.used_ids
}
)
self.out_proj = nn.ModuleDict(
{
str(i): nn.Linear(
self.num_query_heads * self.head_dim,
dep_cfg.n_embd,
bias=False,
)
for i in self.used_ids
}
)
eps = config.model.normalization_layer_epsilon
self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
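        # Queries and keys get per-head RMSNorm; attention then runs with
        # scale=1.0 (no extra 1 / sqrt(head_dim) factor) in forward_incremental.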
if self.apply_rope:
self.rotary = RotaryEmbedding(
self.head_dim,
config.model.rope_min_timescale,
config.model.rope_max_timescale,
)
stage_count = max(len(self.schedule), 1)
self.register_buffer(
"stage_positions",
torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
persistent=False,
)
else:
self.rotary = None
self.register_buffer(
"stage_positions",
torch.zeros(0, 1, dtype=torch.long),
persistent=False,
)
def forward_incremental(
self,
x_t: torch.Tensor,
stage_index: int,
cache_slot,
) -> Tuple[torch.Tensor, object]:
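        """Single-token attention for one depformer stage.

        x_t must be (batch, 1, n_embd). stage_index selects the shared
        projection weights via the schedule; cache_slot accumulates the
        K/V written by earlier stages of the same decode step.
        """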
bsz, seq, _ = x_t.shape
if seq != 1:
raise ValueError("ScheduleAttention expects seq len 1 during decoding")
orig_dtype = x_t.dtype
module_index = self.schedule[stage_index]
proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)
q_proj = self.q_norm(proj[:, :, 0])
k_proj = self.k_norm(proj[:, :, 1])
v_proj = proj[:, :, 2]
if self.apply_rope:
pos_ids = self.stage_positions[stage_index : stage_index + 1]
if pos_ids.device != x_t.device:
pos_ids = pos_ids.to(x_t.device)
q_proj = self.rotary(q_proj, pos_ids)
k_proj = self.rotary(k_proj, pos_ids)
q = q_proj.transpose(1, 2)
k = k_proj.transpose(1, 2)
v = v_proj.transpose(1, 2)
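        # write_and_view appends this stage's K/V to the slot and returns the
        # accumulated keys/values plus a mask over the valid stage positions.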
if cache_slot is not None:
k, v, attn_mask = cache_slot.write_and_view(k, v)
else:
attn_mask = None
attn = F.scaled_dot_product_attention(
q,
k,
v,
scale=1.0,
attn_mask=attn_mask,
enable_gqa=self.num_gqa_groups > 1,
)
attn = attn.transpose(1, 2).contiguous()
flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
out = self.out_proj[str(module_index)](flat.to(torch.float32))
return out.to(orig_dtype), cache_slot
class DepformerLayer(nn.Module):
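    """Pre-norm transformer block: schedule attention and an MLP, each
    wrapped in an RMSNorm / residual pair (see decode_step)."""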
def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
super().__init__()
dep_cfg = config.model.depformer
eps = config.model.normalization_layer_epsilon
self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
self.self_attention = ScheduleAttention(config, compute_dtype)
self.mlp = Mlp(
dep_cfg.n_embd,
dep_cfg.n_hidden,
compute_dtype,
tuple(config.model.depformer.mlp_activations),
)
def decode_step(
self,
x_t: torch.Tensor,
stage_index: int,
cache_slot,
) -> Tuple[torch.Tensor, object]:
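        """One pre-norm attention + MLP step on a single-token hidden state."""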
residual = x_t
x_norm = self.pre_norm(x_t)
sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
x = residual + sa_out
residual2 = x
x_norm2 = self.post_norm(x)
mlp_out = self.mlp(x_norm2)
return residual2 + mlp_out, cache_slot
class Depformer(nn.Module):
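    """Depth transformer that decodes the per-step audio channels stage by
    stage, conditioning each stage on the previous channel's token, the main
    decoder's output, and (at stage 0) the text streams."""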
def __init__(self, config: DiaConfig, precision: Precision):
super().__init__()
self.config = config
self.precision = precision
dep_cfg = config.model.depformer
data_cfg = config.data
runtime = config.runtime
self.num_audio_channels = max(0, data_cfg.channels - 2)
self.num_depth = max(self.num_audio_channels - 1, 0)
self.weights_schedule = runtime.weights_schedule
self.audio_embeds = nn.ModuleList(
[nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
)
if dep_cfg.text_embedding:
self.text_embed = MultiStreamEmbedding(
data_cfg.text_vocab_size,
dep_cfg.n_embd,
pad_id=data_cfg.text_pad_token_id,
output_dtype=precision.compute,
)
else:
self.text_embed = None
used_ids = sorted(set(self.weights_schedule))
self.depformer_in = nn.ModuleDict(
{
str(i): nn.Linear(
config.model.decoder.n_embd,
dep_cfg.n_embd,
bias=False,
)
for i in used_ids
}
)
self.layers = nn.ModuleList([DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)])
self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
self.logits_dtype = precision.logits
self.logits = nn.ModuleList(
[
nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
for _ in range(self.num_depth)
]
)
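        # pad and bos ids are cut from the logits in _forward_stage; slicing at
        # the smaller of the two ids removes both.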
self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)
def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
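        """Allocate one KV-cache slot per layer, sized for max_steps stage writes."""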
heads = self.layers[0].self_attention.num_kv_heads
head_dim = self.layers[0].self_attention.head_dim
return KVCache.allocate(
num_layers=len(self.layers),
batch_size=batch_size,
heads=heads,
max_steps=max_steps,
head_dim=head_dim,
device=device,
dtype=self.precision.compute,
)
def forward_step(
self,
prev_audio: torch.Tensor,
transformer_out: torch.Tensor,
stage_index: int,
cache: KVCache,
main_text: Optional[torch.Tensor],
second_text: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, KVCache]:
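        """Validate stage_index and cache, then decode logits for one stage."""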
self._validate_inputs(stage_index, cache)
return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)
def _forward_stage(
self,
stage_index: int,
prev_audio: torch.Tensor,
transformer_out: torch.Tensor,
cache: KVCache,
main_text: Optional[torch.Tensor],
second_text: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, KVCache]:
prev_audio = prev_audio.long()
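        # Embeddings and logit heads are indexed per stage; input and attention
        # projections are shared across stages via the weights schedule.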
weight_idx = self.weights_schedule[stage_index]
token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
if stage_index == 0 and self.text_embed is not None:
if main_text is None or second_text is None:
raise ValueError("stage 0 requires text tokens")
token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])
dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
dep_in = dep_in.to(self.precision.compute)
dep_in = dep_in + token_emb.to(dep_in.dtype)
x = dep_in
for idx, layer in enumerate(self.layers):
slot = cache.get_slot(idx)
x, _ = layer.decode_step(x, stage_index, slot)
hidden = self.norm(x)
logits = self.logits[stage_index](hidden.to(torch.float32))
logits = logits.to(self.logits_dtype)
logits = logits.unsqueeze(1)
logits = logits[..., : self.audio_vocab_limit]
return logits, cache
def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
if stage_index < 0 or stage_index >= self.num_depth:
raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
if cache is None:
raise ValueError("depformer cache must be initialized")
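

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): a minimal greedy decode
# of one step's audio channels. Assumes a loaded DiaConfig `config`, a
# Precision `precision`, a token for the first channel, and the main decoder's
# hidden state for the current step; all of that glue is hypothetical, only
# the forward_step signature below comes from this file.
#
#   depformer = Depformer(config, precision).to(device).eval()
#   cache = depformer.init_cache(batch_size=1, device=device,
#                                max_steps=depformer.num_depth)
#   prev = first_channel_token        # (batch,) long tensor for channel 0
#   channels = [prev]
#   for stage in range(depformer.num_depth):
#       logits, cache = depformer.forward_step(
#           prev, transformer_out, stage, cache,
#           main_text=main_text, second_text=second_text,
#       )
#       prev = logits.argmax(dim=-1).reshape(-1)   # greedy; sampling also works
#       channels.append(prev)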