Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,863 Bytes
1315cad a087083 1315cad a087083 1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from __future__ import annotations

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from .state_machine import StateMachine, TokenIds
@dataclass
class RuntimeContext:
    """Bundle of the objects assembled by ``build_runtime``.

    The class is a plain mutable dataclass: it performs no validation and
    simply carries the loaded model, its helpers and a few derived values.
    Field order is part of the interface (positional construction), so it
    must not be rearranged.
    """

    # Configuration object returned by ``load_config``.
    config: DiaConfig
    # Model instance with weights loaded, moved to ``device``.
    model: Dia2Model
    # Result of ``resolve_precision`` for the requested dtype and device.
    precision: Precision
    # Text tokenizer loaded through ``AutoTokenizer.from_pretrained``.
    tokenizer: PreTrainedTokenizerBase
    # Audio codec loaded through ``MimiCodec.from_pretrained``.
    mimi: MimiCodec
    # Device the model and ``audio_delay_tensor`` live on.
    device: torch.device
    # State machine constructed from ``constants``.
    machine: StateMachine
    # Bound ``model.transformer.forward_step``. The call signature is not
    # visible from this module, hence the open ``Callable[..., Any]``.
    transformer_step: Callable[..., Any]
    # Bound ``model.depformer.forward_step`` (same caveat as above).
    depformer_step: Callable[..., Any]
    # Special token ids gathered from the config and the tokenizer.
    constants: TokenIds
    # ``config.data.delay_pattern`` copied into a list.
    audio_delays: list[int]
    # ``audio_delays`` as a ``torch.long`` tensor on ``device``.
    audio_delay_tensor: torch.Tensor
    # ``mimi.frame_rate`` when the codec exposes it, otherwise 75.0.
    frame_rate: float
def _set_tf32_flags(precision_holder, legacy_holder) -> None:
    """Enable TF32 on one backend through the new and/or legacy switches.

    Args:
        precision_holder: Object that may expose ``fp32_precision`` (the
            newer setting). ``None`` is accepted and treated as "absent".
        legacy_holder: Object whose ``allow_tf32`` flag is always set.
    """
    if hasattr(precision_holder, "fp32_precision"):
        precision_holder.fp32_precision = "tf32"
        # The legacy flag is still written after the new setting; only the
        # warning whose text starts with this message is silenced.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="Please use the new API settings",
            )
            legacy_holder.allow_tf32 = True
    else:  # pragma: no cover - compatibility with older PyTorch
        legacy_holder.allow_tf32 = True


def _enable_tf32() -> None:
    """Switch on TF32 for CUDA matmul and for cuDNN (process-wide flags)."""
    cuda_matmul = torch.backends.cuda.matmul
    _set_tf32_flags(cuda_matmul, cuda_matmul)

    # For cuDNN the new setting lives on ``cudnn.conv`` (absent on older
    # PyTorch), while the legacy flag lives on ``cudnn`` itself.
    cudnn = torch.backends.cudnn
    _set_tf32_flags(getattr(cudnn, "conv", None), cudnn)


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load the model, tokenizer and codec and wrap them in a RuntimeContext.

    When ``device`` is a CUDA device this also enables TF32 through the
    global ``torch.backends`` flags, which affects the whole process.

    Args:
        config_path: Path passed to ``load_config``.
        weights_path: Safetensors file read with ``load_file``.
        tokenizer_id: Explicit tokenizer reference; takes priority over the
            config's ``assets.tokenizer`` and over ``repo_id``.
        repo_id: Last-resort tokenizer reference.
        mimi_id: Explicit codec reference; falls back to the config's
            ``assets.mimi`` and then to ``DEFAULT_MIMI_MODEL_ID``.
        device: Device string handed to ``torch.device``.
        dtype_pref: Precision preference handed to ``resolve_precision``.

    Returns:
        A ``(runtime, tokenizer_ref, mimi_ref)`` tuple, where the two
        references are the ones that were actually used for loading.

    Raises:
        ValueError: If no non-empty tokenizer reference can be determined.
    """
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        _enable_tf32()

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    state = load_file(str(weights_path))
    model.load_state_dict(state)
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    # ``not`` rather than ``is None``: an ``or`` chain of empty strings ends
    # in "" and would otherwise reach ``from_pretrained`` with no id at all.
    if not tokenizer_ref:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    # NOTE(review): trust_remote_code=True lets the tokenizer repository run
    # its own Python code at load time; only point this at trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data

    # Fetch the vocabulary once instead of once per speaker tag.
    vocab = tokenizer.get_vocab()

    def speaker_token_id(tag: str):
        """Return the id of ``tag``, or the new-word id if it is not in the vocab."""
        if tag in vocab:
            return tokenizer.convert_tokens_to_ids(tag)
        return data_cfg.text_new_word_token_id

    # Fall back to 1 only when the tokenizer has no BOS id; a plain
    # ``or 1`` would also overwrite a legitimate id of 0.
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if bos_id is None:
        bos_id = 1

    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        bos=bos_id,
        zero=data_cfg.text_zero_token_id,
        spk1=speaker_token_id("[S1]"),
        spk2=speaker_token_id("[S2]"),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )

    audio_delays = list(data_cfg.delay_pattern)
    # With an explicit dtype an empty list already yields an empty 1-D long
    # tensor, so no separate empty-case branch is needed.
    audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long)
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref
# Names re-exported by ``from <this module> import *``.
__all__ = ["RuntimeContext", "build_runtime"]
|