File size: 4,863 Bytes
04d9952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from .state_machine import StateMachine, TokenIds


@dataclass
class RuntimeContext:
    """Everything needed to run one generation session.

    Produced by :func:`build_runtime`; consumers read these attributes,
    they never mutate them.
    """

    config: DiaConfig                    # parsed model/data configuration
    model: Dia2Model                     # weights loaded, moved to `device`
    precision: Precision                 # resolved compute precision
    tokenizer: PreTrainedTokenizerBase   # text tokenizer (slow tokenizer variant)
    mimi: MimiCodec                      # audio codec for encode/decode
    device: torch.device
    machine: StateMachine                # text/audio interleaving state machine
    # Fixed: annotate step callables with typing.Callable, not the builtin
    # `callable` predicate (which is not a type).
    transformer_step: Callable[..., object]   # bound model.transformer.forward_step
    depformer_step: Callable[..., object]     # bound model.depformer.forward_step
    constants: TokenIds                  # special token ids used by the machine
    audio_delays: list[int]              # per-codebook delay pattern
    audio_delay_tensor: torch.Tensor     # same delays as a long tensor on `device`
    frame_rate: float                    # codec frames per second (default 75.0)


def _configure_tf32(device: torch.device) -> None:
    """Enable TF32 tensor-core kernels for CUDA matmul and cuDNN conv.

    Newer PyTorch builds expose an ``fp32_precision`` knob; on those we set
    the new knob first and then also set the legacy ``allow_tf32`` flag while
    suppressing the resulting deprecation warning, so both APIs agree. Older
    builds only have the legacy flag. No-op on non-CUDA devices.
    """
    if device.type != "cuda":
        return

    cuda_matmul = torch.backends.cuda.matmul
    if hasattr(cuda_matmul, "fp32_precision"):
        cuda_matmul.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            torch.backends.cuda.matmul.allow_tf32 = True
    else:  # pragma: no cover - compatibility with older PyTorch
        torch.backends.cuda.matmul.allow_tf32 = True

    # The cuDNN ``conv`` sub-namespace only exists on newer PyTorch builds.
    cudnn_conv = getattr(torch.backends.cudnn, "conv", None)
    if cudnn_conv is not None and hasattr(cudnn_conv, "fp32_precision"):
        cudnn_conv.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            torch.backends.cudnn.allow_tf32 = True
    else:
        # Older PyTorch, or conv namespace without the new knob.
        torch.backends.cudnn.allow_tf32 = True


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load model, tokenizer and audio codec, and assemble a RuntimeContext.

    Args:
        config_path: Path to the Dia config file.
        weights_path: Path to a safetensors checkpoint.
        tokenizer_id: Explicit tokenizer id; falls back to the config asset,
            then to ``repo_id``.
        repo_id: Hub repo id used as the last-resort tokenizer reference.
        mimi_id: Explicit Mimi codec id; falls back to the config asset, then
            to ``DEFAULT_MIMI_MODEL_ID``.
        device: Torch device string (e.g. ``"cuda"``, ``"cpu"``).
        dtype_pref: Preferred dtype name, resolved via ``resolve_precision``.

    Returns:
        ``(runtime, tokenizer_ref, mimi_ref)`` — the assembled context plus
        the identifiers actually used for the tokenizer and codec.

    Raises:
        ValueError: If no tokenizer reference can be resolved.
    """
    device_obj = torch.device(device)
    _configure_tf32(device_obj)

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    model.load_state_dict(load_file(str(weights_path)))
    model = model.to(device_obj)

    # Resolution order: explicit argument > config asset > repo id.
    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    if tokenizer_ref is None:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    vocab = tokenizer.get_vocab()  # hoisted: avoid building the vocab dict twice
    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        # Some tokenizers report bos_token_id=None or 0; fall back to 1.
        bos=getattr(tokenizer, "bos_token_id", 1) or 1,
        zero=data_cfg.text_zero_token_id,
        # Speaker tags may be absent from the vocab; degrade to the new-word id.
        spk1=tokenizer.convert_tokens_to_ids("[S1]") if "[S1]" in vocab else data_cfg.text_new_word_token_id,
        spk2=tokenizer.convert_tokens_to_ids("[S2]") if "[S2]" in vocab else data_cfg.text_new_word_token_id,
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )
    audio_delays = list(data_cfg.delay_pattern)
    # torch.tensor handles the empty list too (yields an empty long tensor),
    # so no special-casing is needed.
    audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long)
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref


# Explicit public API; all other module-level names are internal.
__all__ = [
    "RuntimeContext",
    "build_runtime",
]