NariLabs committed on
Commit
04d9952
·
verified ·
1 Parent(s): aa16b75

Upload context.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. context.py +138 -0
context.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from .state_machine import StateMachine, TokenIds
17
+
18
+
19
@dataclass
class RuntimeContext:
    """Bundle of everything ``build_runtime`` assembles for a session.

    The field order is part of the interface (positional construction), so
    it is kept exactly as originally declared.

    Attributes (as populated by ``build_runtime``):
        config: Configuration object returned by ``load_config``.
        model: ``Dia2Model`` with weights loaded and moved to ``device``.
        precision: Result of ``resolve_precision`` for the requested dtype.
        tokenizer: Tokenizer loaded through ``AutoTokenizer.from_pretrained``.
        mimi: Codec loaded through ``MimiCodec.from_pretrained``.
        device: ``torch.device`` built from the requested device string.
        machine: ``StateMachine`` constructed from ``constants``.
        transformer_step: Bound ``model.transformer.forward_step``.
        depformer_step: Bound ``model.depformer.forward_step``.
        constants: ``TokenIds`` derived from the config and the tokenizer.
        audio_delays: ``config.data.delay_pattern`` copied into a list.
        audio_delay_tensor: ``audio_delays`` as a ``torch.long`` tensor
            placed on ``device``.
        frame_rate: ``mimi.frame_rate`` when the codec exposes it, otherwise
            the fallback chosen by ``build_runtime``.
    """

    config: DiaConfig
    model: Dia2Model
    precision: Precision
    tokenizer: PreTrainedTokenizerBase
    mimi: MimiCodec
    device: torch.device
    machine: StateMachine
    # ``callable`` is a builtin function, not a type; ``Callable[..., Any]``
    # is the annotation type checkers understand. The signature of the step
    # functions is not visible from this module, hence the open ``...``.
    transformer_step: Callable[..., Any]
    depformer_step: Callable[..., Any]
    constants: TokenIds
    audio_delays: list[int]
    audio_delay_tensor: torch.Tensor
    frame_rate: float
34
+
35
+
36
def _enable_tf32() -> None:
    """Switch CUDA matmul and cuDNN convolutions to TF32.

    Newer PyTorch builds expose an ``fp32_precision`` attribute; older ones
    only have the ``allow_tf32`` flags. Both are set where available, and the
    "Please use the new API settings" warning triggered by touching the
    legacy flag is silenced in that case.
    """
    cuda_matmul = torch.backends.cuda.matmul
    if hasattr(cuda_matmul, "fp32_precision"):
        cuda_matmul.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            cuda_matmul.allow_tf32 = True
    else:  # pragma: no cover - compatibility with older PyTorch
        cuda_matmul.allow_tf32 = True

    # The ``conv`` namespace is missing on older PyTorch versions, so probe
    # for it before looking for ``fp32_precision``.
    cudnn_conv = getattr(torch.backends.cudnn, "conv", None)
    if cudnn_conv is not None and hasattr(cudnn_conv, "fp32_precision"):
        cudnn_conv.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            torch.backends.cudnn.allow_tf32 = True
    else:
        torch.backends.cudnn.allow_tf32 = True


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load the model, tokenizer and codec and wrap them in a RuntimeContext.

    All arguments are keyword-only.

    Args:
        config_path: Path handed to ``load_config``.
        weights_path: Path of the safetensors file holding the model weights.
        tokenizer_id: Explicit tokenizer reference; takes priority over
            ``config.assets.tokenizer`` and ``repo_id``.
        repo_id: Last-resort tokenizer reference.
        mimi_id: Explicit codec reference; takes priority over
            ``config.assets.mimi`` and ``DEFAULT_MIMI_MODEL_ID``.
        device: Device string accepted by ``torch.device``.
        dtype_pref: Precision preference passed to ``resolve_precision``.

    Returns:
        ``(runtime, tokenizer_ref, mimi_ref)`` where the last two are the
        references that were actually used to load the tokenizer and codec.

    Raises:
        ValueError: If no tokenizer reference can be determined.
    """
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        _enable_tf32()

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    # Pass the loaded state dict straight through so it is not kept alive in
    # a local for the remainder of this function.
    model.load_state_dict(load_file(str(weights_path)))
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    if tokenizer_ref is None:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    # NOTE(security): trust_remote_code=True allows the referenced tokenizer
    # repository to run its own Python code at load time; only point this at
    # tokenizer references you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data

    # Fall back to 1 only when the tokenizer has no BOS id at all; a BOS id
    # of 0 is valid and must not be replaced (``x or 1`` would do that).
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if bos_id is None:
        bos_id = 1

    # Build the vocabulary once instead of once per speaker tag.
    vocab = tokenizer.get_vocab()

    def _speaker_id(tag: str) -> int:
        """Return the id of *tag*, or the new-word id when it is not in the vocab."""
        if tag in vocab:
            return tokenizer.convert_tokens_to_ids(tag)
        return data_cfg.text_new_word_token_id

    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        bos=bos_id,
        zero=data_cfg.text_zero_token_id,
        spk1=_speaker_id("[S1]"),
        spk2=_speaker_id("[S2]"),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )

    audio_delays = list(data_cfg.delay_pattern)
    # torch.tensor([]) with an explicit dtype already yields an empty long
    # tensor, so no special case is needed for an empty delay pattern.
    audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long)
    # 75.0 is only used when the codec object does not expose ``frame_rate``.
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref
133
+
134
+
135
# Names re-exported by ``from <module> import *``.
__all__ = ["RuntimeContext", "build_runtime"]