File size: 4,574 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import warnings

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from .state_machine import StateMachine, TokenIds


@dataclass
class RuntimeContext:
    """Bundle of everything the generation loop needs, built once by build_runtime."""

    config: DiaConfig  # parsed model/data configuration
    model: Dia2Model  # weights loaded and moved to `device`
    precision: Precision  # resolved from the user's dtype preference + device
    tokenizer: PreTrainedTokenizerBase  # text tokenizer (slow tokenizer, trust_remote_code)
    mimi: MimiCodec  # audio codec used to decode generated audio tokens
    device: torch.device
    machine: StateMachine  # text/audio interleaving state machine
    transformer_step: callable  # bound model.transformer.forward_step
    depformer_step: callable  # bound model.depformer.forward_step
    constants: TokenIds  # special token ids (pad/bos/speaker/audio markers)
    audio_delays: list[int]  # per-codebook delay pattern from the data config
    audio_delay_tensor: torch.Tensor  # same delays as a long tensor on `device` (empty if no delays)
    frame_rate: float  # codec frame rate in Hz (falls back to 75.0 if mimi doesn't expose one)


def _enable_tf32(new_api: object, legacy_api: object) -> None:
    """Enable TF32 kernels on one backend, preferring the new fp32_precision API.

    ``new_api`` is the module exposing ``fp32_precision`` (e.g.
    ``torch.backends.cuda.matmul``); ``legacy_api`` is the module exposing the
    old ``allow_tf32`` flag (for cudnn these differ: ``cudnn.conv`` vs ``cudnn``).
    When the new API exists we set both, suppressing the deprecation warning
    the legacy flag now emits; otherwise we fall back to the legacy flag alone.
    """
    if hasattr(new_api, "fp32_precision"):
        new_api.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            legacy_api.allow_tf32 = True
    else:  # pragma: no cover - compatibility with older PyTorch
        legacy_api.allow_tf32 = True


def _speaker_token_id(tokenizer: PreTrainedTokenizerBase, token: str, fallback: int) -> int:
    """Return the id of a speaker marker token, or ``fallback`` if absent from the vocab."""
    if token in tokenizer.get_vocab():
        return tokenizer.convert_tokens_to_ids(token)
    return fallback


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load config, weights, tokenizer and codec, and assemble a RuntimeContext.

    Args:
        config_path: Path to the model config file (passed to ``load_config``).
        weights_path: Path to a safetensors checkpoint.
        tokenizer_id: Explicit tokenizer id; falls back to ``config.assets.tokenizer``
            then ``repo_id``.
        repo_id: Hub repo id used as a last-resort tokenizer reference.
        mimi_id: Explicit Mimi codec id; falls back to ``config.assets.mimi``
            then ``DEFAULT_MIMI_MODEL_ID``.
        device: Torch device string (e.g. ``"cuda"``, ``"cpu"``).
        dtype_pref: Dtype preference resolved via ``resolve_precision``.

    Returns:
        ``(runtime, tokenizer_ref, mimi_ref)`` — the assembled context plus the
        references actually used for the tokenizer and codec.

    Raises:
        ValueError: If no tokenizer reference can be resolved.
    """
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        # TF32 is a free speedup for matmul/conv on Ampere+ at this precision level.
        _enable_tf32(torch.backends.cuda.matmul, torch.backends.cuda.matmul)
        _enable_tf32(torch.backends.cudnn.conv, torch.backends.cudnn)

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    state = load_file(str(weights_path))
    model.load_state_dict(state)
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    if tokenizer_ref is None:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        # Some tokenizers report bos_token_id=None; default to 1 in that case.
        bos=getattr(tokenizer, "bos_token_id", 1) or 1,
        zero=data_cfg.text_zero_token_id,
        # Speaker markers fall back to the new-word token when not in the vocab.
        spk1=_speaker_token_id(tokenizer, "[S1]", data_cfg.text_new_word_token_id),
        spk2=_speaker_token_id(tokenizer, "[S2]", data_cfg.text_new_word_token_id),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )
    audio_delays = list(data_cfg.delay_pattern)
    if audio_delays:
        audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long)
    else:
        audio_delay_tensor = torch.empty(0, dtype=torch.long, device=device_obj)
    # NOTE(review): 75.0 Hz matches Mimi's documented frame rate — confirm if the codec changes.
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref


# Public API of this module.
__all__ = ["RuntimeContext", "build_runtime"]