from __future__ import annotations

from typing import Tuple

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .precision import Precision
from .layers import (
    MultiStreamEmbedding,
    Mlp,
    Attention,
)


class TransformerDecoder(nn.Module):
    """Inference-time port of dia_v2.model.Transformer."""

    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        data_cfg = config.data
        dec_cfg = config.model.decoder

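        # One embedding table per audio codebook channel; channels 0 and 1 are
        # the two text streams handled by MultiStreamEmbedding below.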
        self.audio_embeds = nn.ModuleList(
            [
                nn.Embedding(
                    data_cfg.audio_vocab_size,
                    dec_cfg.n_embd,
                )
                for _ in range(max(0, data_cfg.channels - 2))
            ]
        )
        self.text_embed = MultiStreamEmbedding(
            data_cfg.text_vocab_size,
            dec_cfg.n_embd,
            pad_id=data_cfg.text_pad_token_id,
            output_dtype=self.precision.compute,
            low_rank_dim=dec_cfg.low_rank_dim,
        )
        self.layers = nn.ModuleList([DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)])
        self.norm = nn.RMSNorm(dec_cfg.n_embd, eps=config.model.normalization_layer_epsilon, dtype=torch.float32)

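        # Two output heads: one over the action vocabulary and one over
        # codebook 0 of the audio vocabulary.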
        self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
        self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        heads = self.layers[0].attn.num_kv_heads
        head_dim = self.layers[0].attn.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cache: KVCache,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
        if cache is None:
            raise ValueError("Transformer cache must be initialized")

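        # tokens: [B, C, 1], where channels 0 and 1 feed the text embedding
        # and channels 2.. are audio codebooks.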
        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("forward_step expects sequence length 1")
        num_audio_channels = max(0, C - 2)

        hidden_t = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
            hidden_t.add_(audio_emb)
        hidden_t = hidden_t.to(self.precision.compute)

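        # Run the decoder stack; each layer uses its own slot in the KV cache.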
        x = hidden_t
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, positions, slot)

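        # Final RMSNorm, then project to logits in float32 and cast to the
        # configured logits dtype.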
        hidden_norm = self.norm(x)
        action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        return hidden_norm, action_logits, cb0_logits, cache

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
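        """Embed one decode step of tokens; mirrors the embedding stage of forward_step."""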
        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("_embed expects sequence length 1")
        num_audio_channels = max(0, C - 2)
        hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            hidden = hidden + self.audio_embeds[idx](tokens[:, idx + 2, :])
        return hidden.to(self.precision.compute)


class DecoderLayer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        dec = config.model.decoder
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.attn = Attention(config, dec.n_embd, precision.compute)
        self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.mlp = Mlp(
            dec.n_embd,
            dec.n_hidden,
            precision.compute,
            tuple(config.model.linear.mlp_activations),
        )

    def decode_step(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
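        # Pre-norm residual block: self-attention then MLP, each preceded by RMSNorm.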
        residual = x
        x_norm = self.pre_norm(x)
        attn_out, _ = self.attn(x_norm, pos, cache_slot)
        x = residual + attn_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot
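

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). Assuming a
# loaded DiaConfig/Precision pair, one autoregressive decode step could look
# like the commented code below; the [B, C, 1] token layout and the shape of
# `positions` are inferred from forward_step above, not guaranteed by this file.
#
#   decoder = TransformerDecoder(config, precision).eval()
#   cache = decoder.init_cache(batch_size=1, device=torch.device("cpu"), max_steps=1024)
#   tokens = torch.zeros(1, config.data.channels, 1, dtype=torch.long)  # [B, C, T=1]
#   positions = torch.zeros(1, 1, dtype=torch.long)                     # current step index
#   hidden, action_logits, cb0_logits, cache = decoder.forward_step(tokens, positions, cache)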