File size: 2,296 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .depformer import Depformer
from .precision import Precision
from .transformer import TransformerDecoder


@dataclass
class DecodeState:
    """Mutable pair of KV caches threaded through one incremental decode.

    Both fields are replaced (not mutated) by ``Dia2Model``: ``transformer``
    is swapped on every ``step_text`` call and ``depformer`` on every
    ``step_audio_stage`` call.
    """

    # KV cache for the main transformer stage.
    transformer: KVCache
    # KV cache for the depformer (depth/codebook) stage.
    depformer: KVCache


class Dia2Model(nn.Module):
    """Two-stage Dia decoder: a main transformer plus a depformer.

    The ``TransformerDecoder`` produces per-step hidden states (text stage);
    the ``Depformer`` consumes them to emit audio-codebook logits (depth
    stage). Their KV caches are carried between steps in a ``DecodeState``.
    """

    def __init__(self, config: DiaConfig, precision: Precision):
        """Build both sub-decoders and normalize norm-layer dtypes.

        Args:
            config: Model hyperparameters shared by both stages.
            precision: Dtype policy; its ``compute`` dtype is applied to
                all RMSNorm parameters (see ``_cast_norms_to_compute``).
        """
        super().__init__()
        self.config = config
        self.precision = precision
        self.transformer = TransformerDecoder(config, precision)
        self.depformer = Depformer(config, precision)
        self._cast_norms_to_compute()

    def init_state(self, batch_size: int, device: torch.device, max_steps: int) -> DecodeState:
        """Allocate fresh KV caches for a decode of up to ``max_steps`` steps.

        Note the depformer cache is sized by ``self.depformer.num_depth``,
        not ``max_steps`` — it only spans the depth stages within one step.
        """
        transformer_cache = self.transformer.init_cache(batch_size, device, max_steps)
        depformer_cache = self.depformer.init_cache(batch_size, device, self.depformer.num_depth)
        return DecodeState(transformer_cache, depformer_cache)

    def step_text(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        state: DecodeState,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one main-transformer step, updating ``state.transformer`` in place.

        Returns:
            The ``(hidden, action, cb0)`` triple produced by
            ``TransformerDecoder.forward_step``; the refreshed KV cache is
            stored back onto ``state`` rather than returned.
        """
        hidden, action, cb0, cache = self.transformer.forward_step(tokens, positions, state.transformer)
        state.transformer = cache
        return hidden, action, cb0

    def step_audio_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_hidden: torch.Tensor,
        state: DecodeState,
        # FIX: these were annotated ``Optional[torch.Tensor]`` but ``Optional``
        # is never imported; only the lazy annotations from
        # ``from __future__ import annotations`` masked the NameError (e.g.
        # ``typing.get_type_hints`` would still fail). PEP 604 ``| None``
        # matches the file's existing ``tuple[...]`` style and needs no import.
        main_text: torch.Tensor | None,
        second_text: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run one depformer stage, updating ``state.depformer`` in place.

        Args:
            stage_index: Which depth stage (codebook level) to advance.
            prev_audio: Audio tokens from the previous stage/step.
            transformer_hidden: Hidden states from ``step_text``.
            state: Decode state whose ``depformer`` cache is read and replaced.
            main_text: Optional text conditioning passed through to the
                depformer (semantics defined by ``Depformer.forward_step``).
            second_text: Optional secondary text conditioning, likewise.

        Returns:
            Logits for the requested stage.
        """
        logits, new_cache = self.depformer.forward_step(
            prev_audio,
            transformer_hidden,
            stage_index,
            state.depformer,
            main_text,
            second_text,
        )
        state.depformer = new_cache
        return logits

    def _cast_norms_to_compute(self) -> None:
        """Cast RMSNorm weights/biases to the compute dtype to avoid bf16 warnings."""
        def _convert(module: nn.Module) -> None:
            if isinstance(module, nn.RMSNorm):
                module.to(self.precision.compute)

        self.apply(_convert)