# Dia2-2B/core/transformer.py

from __future__ import annotations

from typing import Tuple

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .precision import Precision
from .layers import (
    MultiStreamEmbedding,
    Mlp,
    Attention,
)


class TransformerDecoder(nn.Module):
    """Inference-time port of dia_v2.model.Transformer.

    Decodes autoregressively one step at a time: channels 0 and 1 of the
    token tensor are text streams embedded by a shared MultiStreamEmbedding,
    the remaining channels are audio codebooks, and each step yields logits
    from the action head and the codebook-0 head.
    """
def __init__(self, config: DiaConfig, precision: Precision):
super().__init__()
self.config = config
self.precision = precision
data_cfg = config.data
dec_cfg = config.model.decoder
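        # Channels 0 and 1 carry the two text streams; the remaining
        # C - 2 channels are audio codebooks, one embedding table each.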
self.audio_embeds = nn.ModuleList(
[
nn.Embedding(
data_cfg.audio_vocab_size,
dec_cfg.n_embd,
)
for _ in range(max(0, data_cfg.channels - 2))
]
)
self.text_embed = MultiStreamEmbedding(
data_cfg.text_vocab_size,
dec_cfg.n_embd,
pad_id=data_cfg.text_pad_token_id,
output_dtype=self.precision.compute,
low_rank_dim=dec_cfg.low_rank_dim,
)
        self.layers = nn.ModuleList(
            [DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)]
        )
        self.norm = nn.RMSNorm(
            dec_cfg.n_embd,
            eps=config.model.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        # Per-step output heads for action logits and codebook-0 logits.
        self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
        self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        """Allocate an empty per-layer KV cache sized for ``max_steps`` decode steps."""
heads = self.layers[0].attn.num_kv_heads
head_dim = self.layers[0].attn.head_dim
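        # All layers share the same attention geometry, so layer 0 is representative.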
return KVCache.allocate(
num_layers=len(self.layers),
batch_size=batch_size,
heads=heads,
max_steps=max_steps,
head_dim=head_dim,
device=device,
dtype=self.precision.compute,
)

    def forward_step(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cache: KVCache,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
        """Run one decode step.

        ``tokens`` has shape (B, C, 1): channels 0 and 1 are text streams,
        channels 2+ are audio codebooks. Returns the normalized hidden
        state, action logits, codebook-0 logits, and the updated cache.
        """
        if cache is None:
            raise ValueError("Transformer cache must be initialized")
        x = self._embed(tokens)
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, positions, slot)
        hidden_norm = self.norm(x)
        # Run the heads in float32 for stability, then cast to the logits dtype.
        action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        return hidden_norm, action_logits, cb0_logits, cache

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Sum the text-stream and audio-codebook embeddings for one step."""
        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("_embed expects sequence length 1")
        num_audio_channels = max(0, C - 2)
        hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            hidden = hidden + self.audio_embeds[idx](tokens[:, idx + 2, :])
        return hidden.to(self.precision.compute)


class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add."""

    def __init__(self, config: DiaConfig, precision: Precision):
super().__init__()
dec = config.model.decoder
eps = config.model.normalization_layer_epsilon
self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
self.attn = Attention(config, dec.n_embd, precision.compute)
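        # Despite the name, post_norm is applied before the MLP; both sub-blocks are pre-norm.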
self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
self.mlp = Mlp(
dec.n_embd,
dec.n_hidden,
precision.compute,
tuple(config.model.linear.mlp_activations),
)

    def decode_step(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        """Advance this layer by one token, reading from and writing to ``cache_slot``."""
        # Attention sub-block with residual connection.
        residual = x
        x_norm = self.pre_norm(x)
        attn_out, _ = self.attn(x_norm, pos, cache_slot)
        x = residual + attn_out
        # MLP sub-block with residual connection.
        mlp_out = self.mlp(self.post_norm(x))
        return x + mlp_out, cache_slot
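

# Usage sketch (illustrative only, not part of the module). A minimal
# single-step decode loop, assuming an already-constructed DiaConfig named
# `config` and a Precision named `precision`; the batch size, step count,
# and positions shape below are assumptions, and positions is simply passed
# through to Attention as in forward_step above.
#
#     device = torch.device("cuda")
#     model = TransformerDecoder(config, precision).to(device).eval()
#     cache = model.init_cache(batch_size=1, device=device, max_steps=1024)
#     # Shape (B, C, 1): channels 0/1 are text streams, channels 2+ audio codebooks.
#     tokens = torch.zeros(1, config.data.channels, 1, dtype=torch.long, device=device)
#     for step in range(16):
#         positions = torch.full((1, 1), step, dtype=torch.long, device=device)
#         with torch.no_grad():
#             hidden, action_logits, cb0_logits, cache = model.forward_step(
#                 tokens, positions, cache
#             )
#         # ...sample the next tokens from action_logits / cb0_logits here...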