NariLabs committed on
Commit
4bc3602
·
verified ·
1 Parent(s): 8ab346a

Delete context.py

Browse files
Files changed (1) hide show
  1. context.py +0 -138
context.py DELETED
@@ -1,138 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- from typing import Optional
6
- import warnings
7
-
8
- import torch
9
- from safetensors.torch import load_file
10
- from transformers import AutoTokenizer, PreTrainedTokenizerBase
11
-
12
- from ..config import DiaConfig, load_config
13
- from ..core.model import Dia2Model
14
- from ..core.precision import Precision, resolve_precision
15
- from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
16
- from .state_machine import StateMachine, TokenIds
17
-
18
-
19
@dataclass
class RuntimeContext:
    """Bundle of everything ``build_runtime`` assembles for one loaded model.

    The field order is part of the interface (the dataclass-generated
    ``__init__`` accepts positional arguments), so it must not be changed.
    """

    # Parsed configuration, as returned by ``load_config``.
    config: DiaConfig
    # Network built from ``config`` with weights loaded and moved to ``device``.
    model: Dia2Model
    # Result of ``resolve_precision`` for the requested dtype and device.
    precision: Precision
    # Text tokenizer loaded through ``AutoTokenizer.from_pretrained``.
    tokenizer: PreTrainedTokenizerBase
    # Audio codec loaded through ``MimiCodec.from_pretrained``.
    mimi: MimiCodec
    # Device the model, codec and delay tensor were placed on.
    device: torch.device
    # State machine built from ``constants`` and the data configuration.
    machine: StateMachine
    # Bound ``model.transformer.forward_step``.
    transformer_step: callable
    # Bound ``model.depformer.forward_step``.
    depformer_step: callable
    # Special token ids collected from the data configuration and tokenizer.
    constants: TokenIds
    # ``config.data.delay_pattern`` copied into a plain list.
    audio_delays: list[int]
    # The same delays as a ``torch.long`` tensor on ``device``.
    audio_delay_tensor: torch.Tensor
    # ``mimi.frame_rate`` when the codec exposes it, otherwise 75.0.
    frame_rate: float
34
-
35
-
36
def _enable_tf32() -> None:
    """Switch on TF32 for CUDA matmul and cuDNN on old and new PyTorch APIs.

    Where a backend object exposes ``fp32_precision`` it is set to ``"tf32"``
    and the legacy ``allow_tf32`` flag is set as well, with the
    "Please use the new API settings" warning filtered out. Otherwise only
    the legacy ``allow_tf32`` flag is set.
    """
    cuda_matmul = torch.backends.cuda.matmul
    if hasattr(cuda_matmul, "fp32_precision"):
        cuda_matmul.fp32_precision = "tf32"
        # The legacy flag is still written; presumably newer PyTorch warns
        # about that, hence the scoped filter -- TODO confirm.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            torch.backends.cuda.matmul.allow_tf32 = True
    else:  # pragma: no cover - compatibility with older PyTorch
        torch.backends.cuda.matmul.allow_tf32 = True

    # cuDNN: the ``conv`` sub-object (and its ``fp32_precision`` attribute)
    # may be missing, so probe for both before using the new API.
    cudnn = torch.backends.cudnn
    if hasattr(cudnn, "conv") and hasattr(cudnn.conv, "fp32_precision"):
        cudnn.conv.fp32_precision = "tf32"
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            torch.backends.cudnn.allow_tf32 = True
    else:
        # Older PyTorch: only the legacy flag is available.
        torch.backends.cudnn.allow_tf32 = True


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load model, tokenizer and codec and wire them into a ``RuntimeContext``.

    Args:
        config_path: Path handed to ``load_config``.
        weights_path: Safetensors file read with ``load_file``.
        tokenizer_id: Preferred tokenizer reference. When falsy,
            ``config.assets.tokenizer`` and then ``repo_id`` are tried.
        repo_id: Last-resort tokenizer reference.
        mimi_id: Preferred codec reference. When falsy, ``config.assets.mimi``
            and then ``DEFAULT_MIMI_MODEL_ID`` are used.
        device: Device string passed to ``torch.device``.
        dtype_pref: Precision preference passed to ``resolve_precision``.

    Returns:
        ``(runtime, tokenizer_ref, mimi_ref)``: the assembled context plus the
        tokenizer and codec references that were actually used.

    Raises:
        ValueError: If no non-empty tokenizer reference can be determined.
    """
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        _enable_tf32()

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    model.load_state_dict(load_file(str(weights_path)))
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    # Falsy check rather than ``is None``: an empty string would otherwise
    # slip through and fail inside ``from_pretrained`` with an unrelated error.
    if not tokenizer_ref:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    # NOTE(security): trust_remote_code=True allows the referenced tokenizer
    # repository to execute its own Python code; only use trusted references.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    # Fetch the vocabulary once instead of once per speaker tag.
    vocab = tokenizer.get_vocab()

    def speaker_id(tag: str) -> int:
        """Return the id of ``tag``, or the new-word id if it is not in the vocab."""
        if tag in vocab:
            return tokenizer.convert_tokens_to_ids(tag)
        return data_cfg.text_new_word_token_id

    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        # Falls back to 1 when the tokenizer has no (or a falsy) BOS id.
        bos=getattr(tokenizer, "bos_token_id", 1) or 1,
        zero=data_cfg.text_zero_token_id,
        spk1=speaker_id("[S1]"),
        spk2=speaker_id("[S2]"),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )

    audio_delays = list(data_cfg.delay_pattern)
    # An empty list with an explicit dtype already yields an empty long
    # tensor, so no special case is needed for a missing delay pattern.
    audio_delay_tensor = torch.tensor(audio_delays, dtype=torch.long, device=device_obj)
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref
133
-
134
-
135
# Names exported by ``from <module> import *``.
__all__ = ["RuntimeContext", "build_runtime"]