Masaaki Kawata commited on
Commit ·
aaffb76
1
Parent(s): 6209837
Remove Faster Irodori TTS runtime and related files
Browse files- Dockerfile +0 -1
- app.py +6 -4
- faster_irodori_tts/__init__.py +0 -17
- faster_irodori_tts/rf_graph.py +0 -511
- faster_irodori_tts/runtime.py +0 -290
Dockerfile
CHANGED
|
@@ -34,7 +34,6 @@ RUN python -m pip install --upgrade pip setuptools wheel \
|
|
| 34 |
|
| 35 |
COPY app.py .
|
| 36 |
COPY faster_qwen3_tts ./faster_qwen3_tts
|
| 37 |
-
COPY faster_irodori_tts ./faster_irodori_tts
|
| 38 |
COPY qwen_tts ./qwen_tts
|
| 39 |
COPY irodori_tts ./irodori_tts
|
| 40 |
|
|
|
|
| 34 |
|
| 35 |
COPY app.py .
|
| 36 |
COPY faster_qwen3_tts ./faster_qwen3_tts
|
|
|
|
| 37 |
COPY qwen_tts ./qwen_tts
|
| 38 |
COPY irodori_tts ./irodori_tts
|
| 39 |
|
app.py
CHANGED
|
@@ -22,15 +22,14 @@ from huggingface_hub import hf_hub_download
|
|
| 22 |
#from huggingface_hub import snapshot_download
|
| 23 |
#from qwen_tts import Qwen3TTSModel
|
| 24 |
from faster_qwen3_tts import FasterQwen3TTS
|
| 25 |
-
|
| 26 |
-
from faster_irodori_tts import FasterIrodoriTTSRuntime, RuntimeKey, SamplingRequest
|
| 27 |
|
| 28 |
|
| 29 |
load_dotenv(verbose=False)
|
| 30 |
|
| 31 |
#TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
|
| 32 |
TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
|
| 33 |
-
IRODORI_TTS_RUNTIME: Optional[
|
| 34 |
WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
|
| 35 |
REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
|
| 36 |
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
|
|
@@ -235,7 +234,7 @@ def generate_voice_clone(model: str | None, input_text: str, language: str | Non
|
|
| 235 |
if IRODORI_TTS_RUNTIME is None:
|
| 236 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 237 |
precision = 'bf16' if device == 'cuda' else 'fp32'
|
| 238 |
-
IRODORI_TTS_RUNTIME =
|
| 239 |
checkpoint=hf_hub_download(repo_id='Aratako/Irodori-TTS-500M-v2', filename='model.safetensors'),
|
| 240 |
model_device=device,
|
| 241 |
codec_repo='Aratako/Semantic-DACVAE-Japanese-32dim',
|
|
@@ -245,6 +244,9 @@ def generate_voice_clone(model: str | None, input_text: str, language: str | Non
|
|
| 245 |
enable_watermark=False,
|
| 246 |
))
|
| 247 |
|
|
|
|
|
|
|
|
|
|
| 248 |
result = IRODORI_TTS_RUNTIME.synthesize(SamplingRequest(
|
| 249 |
text=input_text,
|
| 250 |
ref_wav=reference_audio,
|
|
|
|
| 22 |
#from huggingface_hub import snapshot_download
|
| 23 |
#from qwen_tts import Qwen3TTSModel
|
| 24 |
from faster_qwen3_tts import FasterQwen3TTS
|
| 25 |
+
from irodori_tts.inference_runtime import InferenceRuntime, RuntimeKey, SamplingRequest
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
load_dotenv(verbose=False)
|
| 29 |
|
| 30 |
#TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
|
| 31 |
TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
|
| 32 |
+
IRODORI_TTS_RUNTIME: Optional[InferenceRuntime] = None
|
| 33 |
WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
|
| 34 |
REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
|
| 35 |
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
|
|
|
|
| 234 |
if IRODORI_TTS_RUNTIME is None:
|
| 235 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 236 |
precision = 'bf16' if device == 'cuda' else 'fp32'
|
| 237 |
+
IRODORI_TTS_RUNTIME = InferenceRuntime.from_key(RuntimeKey(
|
| 238 |
checkpoint=hf_hub_download(repo_id='Aratako/Irodori-TTS-500M-v2', filename='model.safetensors'),
|
| 239 |
model_device=device,
|
| 240 |
codec_repo='Aratako/Semantic-DACVAE-Japanese-32dim',
|
|
|
|
| 244 |
enable_watermark=False,
|
| 245 |
))
|
| 246 |
|
| 247 |
+
if sample_rate != 48000:
|
| 248 |
+
reference_audio = (_resample(reference_audio[0], sample_rate, 48000), 48000)
|
| 249 |
+
|
| 250 |
result = IRODORI_TTS_RUNTIME.synthesize(SamplingRequest(
|
| 251 |
text=input_text,
|
| 252 |
ref_wav=reference_audio,
|
faster_irodori_tts/__init__.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
"""CUDA Graph accelerated runtime helpers for Irodori-TTS."""
|
| 2 |
-
|
| 3 |
-
from .runtime import (
|
| 4 |
-
FasterInferenceRuntime,
|
| 5 |
-
FasterIrodoriTTSRuntime,
|
| 6 |
-
RuntimeKey,
|
| 7 |
-
SamplingRequest,
|
| 8 |
-
SamplingResult,
|
| 9 |
-
)
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"FasterInferenceRuntime",
|
| 13 |
-
"FasterIrodoriTTSRuntime",
|
| 14 |
-
"RuntimeKey",
|
| 15 |
-
"SamplingRequest",
|
| 16 |
-
"SamplingResult",
|
| 17 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
faster_irodori_tts/rf_graph.py
DELETED
|
@@ -1,511 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections import OrderedDict
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
from irodori_tts.rf import sample_euler_rf_cfg
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass(frozen=True)
|
| 12 |
-
class RFGraphSignature:
|
| 13 |
-
batch_size: int
|
| 14 |
-
sequence_length: int
|
| 15 |
-
latent_dim: int
|
| 16 |
-
text_len: int
|
| 17 |
-
speaker_len: int
|
| 18 |
-
num_steps: int
|
| 19 |
-
cfg_scale_text: float
|
| 20 |
-
cfg_scale_speaker: float
|
| 21 |
-
cfg_min_t: float
|
| 22 |
-
cfg_max_t: float
|
| 23 |
-
dtype: str
|
| 24 |
-
device: str
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
@dataclass
|
| 28 |
-
class RFGraphSampleResult:
|
| 29 |
-
latent: torch.Tensor
|
| 30 |
-
graph_used: bool
|
| 31 |
-
fallback_reason: str | None = None
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _device_key(device: torch.device) -> str:
|
| 35 |
-
index = 0 if device.index is None else int(device.index)
|
| 36 |
-
return f"{device.type}:{index}"
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _pad_reference_to_bucket(
|
| 40 |
-
ref_latent: torch.Tensor,
|
| 41 |
-
ref_mask: torch.Tensor,
|
| 42 |
-
*,
|
| 43 |
-
speaker_patch_size: int,
|
| 44 |
-
bucket_multiple: int,
|
| 45 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 46 |
-
if bucket_multiple <= 1:
|
| 47 |
-
return ref_latent, ref_mask
|
| 48 |
-
|
| 49 |
-
patch = max(1, int(speaker_patch_size))
|
| 50 |
-
current = int(ref_latent.shape[1])
|
| 51 |
-
after_patch = max(1, (current + patch - 1) // patch)
|
| 52 |
-
bucketed_after_patch = (
|
| 53 |
-
(after_patch + int(bucket_multiple) - 1) // int(bucket_multiple)
|
| 54 |
-
) * int(bucket_multiple)
|
| 55 |
-
target = bucketed_after_patch * patch
|
| 56 |
-
if target <= current:
|
| 57 |
-
return ref_latent, ref_mask
|
| 58 |
-
|
| 59 |
-
pad_len = target - current
|
| 60 |
-
latent_pad = torch.zeros(
|
| 61 |
-
(ref_latent.shape[0], pad_len, ref_latent.shape[2]),
|
| 62 |
-
device=ref_latent.device,
|
| 63 |
-
dtype=ref_latent.dtype,
|
| 64 |
-
)
|
| 65 |
-
mask_pad = torch.zeros(
|
| 66 |
-
(ref_mask.shape[0], pad_len),
|
| 67 |
-
device=ref_mask.device,
|
| 68 |
-
dtype=ref_mask.dtype,
|
| 69 |
-
)
|
| 70 |
-
return torch.cat([ref_latent, latent_pad], dim=1), torch.cat([ref_mask, mask_pad], dim=1)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _copy_context_kv(
|
| 74 |
-
dst: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
| 75 |
-
src: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
| 76 |
-
) -> None:
|
| 77 |
-
if len(dst) != len(src):
|
| 78 |
-
raise ValueError(f"Context KV layer count mismatch: graph={len(dst)} input={len(src)}")
|
| 79 |
-
for dst_layer, src_layer in zip(dst, src):
|
| 80 |
-
for dst_tensor, src_tensor in zip(dst_layer, src_layer):
|
| 81 |
-
if tuple(dst_tensor.shape) != tuple(src_tensor.shape):
|
| 82 |
-
raise ValueError(
|
| 83 |
-
"Context KV shape mismatch: "
|
| 84 |
-
f"graph={tuple(dst_tensor.shape)} input={tuple(src_tensor.shape)}"
|
| 85 |
-
)
|
| 86 |
-
dst_tensor.copy_(src_tensor)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class IrodoriRFGraph:
|
| 90 |
-
"""Captured RF Euler sampler for one fixed Irodori-TTS shape/configuration."""
|
| 91 |
-
|
| 92 |
-
def __init__(
|
| 93 |
-
self,
|
| 94 |
-
model,
|
| 95 |
-
signature: RFGraphSignature,
|
| 96 |
-
*,
|
| 97 |
-
num_warmup: int = 2,
|
| 98 |
-
) -> None:
|
| 99 |
-
self.model = model
|
| 100 |
-
self.signature = signature
|
| 101 |
-
self.device = model.device
|
| 102 |
-
self.device_index = 0 if self.device.index is None else int(self.device.index)
|
| 103 |
-
self.dtype = model.dtype
|
| 104 |
-
self.num_warmup = int(num_warmup)
|
| 105 |
-
self.cfg_batch_mult = 3
|
| 106 |
-
|
| 107 |
-
bsz = signature.batch_size
|
| 108 |
-
seq_len = signature.sequence_length
|
| 109 |
-
latent_dim = signature.latent_dim
|
| 110 |
-
cfg_bsz = bsz * self.cfg_batch_mult
|
| 111 |
-
text_dim = model.cfg.text_dim
|
| 112 |
-
speaker_dim = model.cfg.speaker_dim
|
| 113 |
-
|
| 114 |
-
self.x_buf = torch.zeros((bsz, seq_len, latent_dim), device=self.device, dtype=self.dtype)
|
| 115 |
-
self.x_cfg_buf = torch.zeros(
|
| 116 |
-
(cfg_bsz, seq_len, latent_dim), device=self.device, dtype=self.dtype
|
| 117 |
-
)
|
| 118 |
-
self.v_buf = torch.zeros_like(self.x_buf)
|
| 119 |
-
self.latent_mask = torch.ones((bsz, seq_len), device=self.device, dtype=torch.bool)
|
| 120 |
-
self.latent_mask_cfg = torch.ones((cfg_bsz, seq_len), device=self.device, dtype=torch.bool)
|
| 121 |
-
|
| 122 |
-
self.text_state_cond = torch.zeros(
|
| 123 |
-
(bsz, signature.text_len, text_dim), device=self.device, dtype=self.dtype
|
| 124 |
-
)
|
| 125 |
-
self.text_mask_cond = torch.zeros(
|
| 126 |
-
(bsz, signature.text_len), device=self.device, dtype=torch.bool
|
| 127 |
-
)
|
| 128 |
-
self.speaker_state_cond = torch.zeros(
|
| 129 |
-
(bsz, signature.speaker_len, speaker_dim), device=self.device, dtype=self.dtype
|
| 130 |
-
)
|
| 131 |
-
self.speaker_mask_cond = torch.zeros(
|
| 132 |
-
(bsz, signature.speaker_len), device=self.device, dtype=torch.bool
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
self.text_state_cfg = torch.zeros(
|
| 136 |
-
(cfg_bsz, signature.text_len, text_dim), device=self.device, dtype=self.dtype
|
| 137 |
-
)
|
| 138 |
-
self.text_mask_cfg = torch.zeros(
|
| 139 |
-
(cfg_bsz, signature.text_len), device=self.device, dtype=torch.bool
|
| 140 |
-
)
|
| 141 |
-
self.speaker_state_cfg = torch.zeros(
|
| 142 |
-
(cfg_bsz, signature.speaker_len, speaker_dim), device=self.device, dtype=self.dtype
|
| 143 |
-
)
|
| 144 |
-
self.speaker_mask_cfg = torch.zeros(
|
| 145 |
-
(cfg_bsz, signature.speaker_len), device=self.device, dtype=torch.bool
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
self.context_kv_cond = self._make_context_kv_buffers(bsz)
|
| 149 |
-
self.context_kv_cfg = self._make_context_kv_buffers(cfg_bsz)
|
| 150 |
-
|
| 151 |
-
init_scale = 0.999
|
| 152 |
-
t_schedule = torch.linspace(
|
| 153 |
-
1.0,
|
| 154 |
-
0.0,
|
| 155 |
-
signature.num_steps + 1,
|
| 156 |
-
device=self.device,
|
| 157 |
-
dtype=torch.float32,
|
| 158 |
-
) * init_scale
|
| 159 |
-
self.t_cond = [torch.full((bsz,), t_schedule[i], device=self.device, dtype=self.dtype)
|
| 160 |
-
for i in range(signature.num_steps)]
|
| 161 |
-
self.t_cfg = [self.t_cond[i].repeat(self.cfg_batch_mult)
|
| 162 |
-
for i in range(signature.num_steps)]
|
| 163 |
-
self.deltas = [
|
| 164 |
-
float((t_schedule[i + 1] - t_schedule[i]).detach().cpu())
|
| 165 |
-
for i in range(signature.num_steps)
|
| 166 |
-
]
|
| 167 |
-
self.use_cfg = [
|
| 168 |
-
bool(signature.cfg_min_t <= float(t_schedule[i].detach().cpu()) <= signature.cfg_max_t)
|
| 169 |
-
for i in range(signature.num_steps)
|
| 170 |
-
]
|
| 171 |
-
|
| 172 |
-
self.graph: torch.cuda.CUDAGraph | None = None
|
| 173 |
-
self.captured = False
|
| 174 |
-
|
| 175 |
-
def _make_context_kv_buffers(
|
| 176 |
-
self,
|
| 177 |
-
batch_size: int,
|
| 178 |
-
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
|
| 179 |
-
buffers: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
|
| 180 |
-
for block in self.model.blocks:
|
| 181 |
-
attn = block.attention
|
| 182 |
-
k_text = torch.zeros(
|
| 183 |
-
(batch_size, self.signature.text_len, attn.heads, attn.head_dim),
|
| 184 |
-
device=self.device,
|
| 185 |
-
dtype=self.dtype,
|
| 186 |
-
)
|
| 187 |
-
v_text = torch.zeros_like(k_text)
|
| 188 |
-
k_speaker = torch.zeros(
|
| 189 |
-
(batch_size, self.signature.speaker_len, attn.heads, attn.head_dim),
|
| 190 |
-
device=self.device,
|
| 191 |
-
dtype=self.dtype,
|
| 192 |
-
)
|
| 193 |
-
v_speaker = torch.zeros_like(k_speaker)
|
| 194 |
-
buffers.append((k_text, v_text, k_speaker, v_speaker))
|
| 195 |
-
return buffers
|
| 196 |
-
|
| 197 |
-
def _copy_cfg_x(self) -> None:
|
| 198 |
-
bsz = self.signature.batch_size
|
| 199 |
-
self.x_cfg_buf[:bsz].copy_(self.x_buf)
|
| 200 |
-
self.x_cfg_buf[bsz : 2 * bsz].copy_(self.x_buf)
|
| 201 |
-
self.x_cfg_buf[2 * bsz : 3 * bsz].copy_(self.x_buf)
|
| 202 |
-
|
| 203 |
-
def _full_loop(self) -> None:
|
| 204 |
-
bsz = self.signature.batch_size
|
| 205 |
-
scale_text = float(self.signature.cfg_scale_text)
|
| 206 |
-
scale_speaker = float(self.signature.cfg_scale_speaker)
|
| 207 |
-
cond_weight = 1.0 + scale_text + scale_speaker
|
| 208 |
-
|
| 209 |
-
for i in range(self.signature.num_steps):
|
| 210 |
-
if self.use_cfg[i]:
|
| 211 |
-
self._copy_cfg_x()
|
| 212 |
-
v_out = self.model.forward_with_encoded_conditions(
|
| 213 |
-
x_t=self.x_cfg_buf,
|
| 214 |
-
t=self.t_cfg[i],
|
| 215 |
-
text_state=self.text_state_cfg,
|
| 216 |
-
text_mask=self.text_mask_cfg,
|
| 217 |
-
speaker_state=self.speaker_state_cfg,
|
| 218 |
-
speaker_mask=self.speaker_mask_cfg,
|
| 219 |
-
latent_mask=self.latent_mask_cfg,
|
| 220 |
-
context_kv_cache=self.context_kv_cfg,
|
| 221 |
-
)
|
| 222 |
-
v_cond = v_out[:bsz]
|
| 223 |
-
v_uncond_text = v_out[bsz : 2 * bsz]
|
| 224 |
-
v_uncond_speaker = v_out[2 * bsz : 3 * bsz]
|
| 225 |
-
self.v_buf.copy_(v_cond)
|
| 226 |
-
self.v_buf.mul_(cond_weight)
|
| 227 |
-
self.v_buf.add_(v_uncond_text, alpha=-scale_text)
|
| 228 |
-
self.v_buf.add_(v_uncond_speaker, alpha=-scale_speaker)
|
| 229 |
-
else:
|
| 230 |
-
v_out = self.model.forward_with_encoded_conditions(
|
| 231 |
-
x_t=self.x_buf,
|
| 232 |
-
t=self.t_cond[i],
|
| 233 |
-
text_state=self.text_state_cond,
|
| 234 |
-
text_mask=self.text_mask_cond,
|
| 235 |
-
speaker_state=self.speaker_state_cond,
|
| 236 |
-
speaker_mask=self.speaker_mask_cond,
|
| 237 |
-
latent_mask=self.latent_mask,
|
| 238 |
-
context_kv_cache=self.context_kv_cond,
|
| 239 |
-
)
|
| 240 |
-
self.v_buf.copy_(v_out)
|
| 241 |
-
|
| 242 |
-
self.x_buf.add_(self.v_buf, alpha=self.deltas[i])
|
| 243 |
-
|
| 244 |
-
@torch.inference_mode()
|
| 245 |
-
def capture(self) -> None:
|
| 246 |
-
if self.captured:
|
| 247 |
-
return
|
| 248 |
-
|
| 249 |
-
# Populate module-side RoPE caches and allocator pools before capture.
|
| 250 |
-
for _ in range(max(1, self.num_warmup)):
|
| 251 |
-
self._full_loop()
|
| 252 |
-
torch.cuda.synchronize(self.device)
|
| 253 |
-
|
| 254 |
-
with torch.cuda.device(self.device_index):
|
| 255 |
-
stream = torch.cuda.Stream()
|
| 256 |
-
stream.wait_stream(torch.cuda.current_stream())
|
| 257 |
-
with torch.cuda.stream(stream):
|
| 258 |
-
for _ in range(max(1, self.num_warmup)):
|
| 259 |
-
self._full_loop()
|
| 260 |
-
torch.cuda.synchronize(self.device)
|
| 261 |
-
|
| 262 |
-
self.graph = torch.cuda.CUDAGraph()
|
| 263 |
-
with torch.cuda.graph(self.graph):
|
| 264 |
-
self._full_loop()
|
| 265 |
-
|
| 266 |
-
torch.cuda.current_stream().wait_stream(stream)
|
| 267 |
-
torch.cuda.synchronize(self.device)
|
| 268 |
-
self.captured = True
|
| 269 |
-
|
| 270 |
-
@torch.inference_mode()
|
| 271 |
-
def run(
|
| 272 |
-
self,
|
| 273 |
-
*,
|
| 274 |
-
x_t: torch.Tensor,
|
| 275 |
-
text_state_cond: torch.Tensor,
|
| 276 |
-
text_mask_cond: torch.Tensor,
|
| 277 |
-
speaker_state_cond: torch.Tensor,
|
| 278 |
-
speaker_mask_cond: torch.Tensor,
|
| 279 |
-
text_state_cfg: torch.Tensor,
|
| 280 |
-
text_mask_cfg: torch.Tensor,
|
| 281 |
-
speaker_state_cfg: torch.Tensor,
|
| 282 |
-
speaker_mask_cfg: torch.Tensor,
|
| 283 |
-
context_kv_cond: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
| 284 |
-
context_kv_cfg: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
| 285 |
-
) -> torch.Tensor:
|
| 286 |
-
if not self.captured or self.graph is None:
|
| 287 |
-
self.capture()
|
| 288 |
-
|
| 289 |
-
self.x_buf.copy_(x_t)
|
| 290 |
-
self.text_state_cond.copy_(text_state_cond)
|
| 291 |
-
self.text_mask_cond.copy_(text_mask_cond)
|
| 292 |
-
self.speaker_state_cond.copy_(speaker_state_cond)
|
| 293 |
-
self.speaker_mask_cond.copy_(speaker_mask_cond)
|
| 294 |
-
self.text_state_cfg.copy_(text_state_cfg)
|
| 295 |
-
self.text_mask_cfg.copy_(text_mask_cfg)
|
| 296 |
-
self.speaker_state_cfg.copy_(speaker_state_cfg)
|
| 297 |
-
self.speaker_mask_cfg.copy_(speaker_mask_cfg)
|
| 298 |
-
_copy_context_kv(self.context_kv_cond, context_kv_cond)
|
| 299 |
-
_copy_context_kv(self.context_kv_cfg, context_kv_cfg)
|
| 300 |
-
|
| 301 |
-
self.graph.replay()
|
| 302 |
-
return self.x_buf.clone()
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
class FasterIrodoriRFSampler:
|
| 306 |
-
"""Graph cache and safe fallback wrapper for Irodori RF sampling."""
|
| 307 |
-
|
| 308 |
-
def __init__(
|
| 309 |
-
self,
|
| 310 |
-
*,
|
| 311 |
-
max_graphs: int = 2,
|
| 312 |
-
speaker_bucket_multiple: int = 64,
|
| 313 |
-
num_warmup: int = 2,
|
| 314 |
-
) -> None:
|
| 315 |
-
self.max_graphs = max(1, int(max_graphs))
|
| 316 |
-
self.speaker_bucket_multiple = max(1, int(speaker_bucket_multiple))
|
| 317 |
-
self.num_warmup = max(1, int(num_warmup))
|
| 318 |
-
self._graphs: OrderedDict[RFGraphSignature, IrodoriRFGraph] = OrderedDict()
|
| 319 |
-
|
| 320 |
-
def _unsupported_reason(
|
| 321 |
-
self,
|
| 322 |
-
*,
|
| 323 |
-
model,
|
| 324 |
-
cfg_guidance_mode: str,
|
| 325 |
-
cfg_scale_text: float,
|
| 326 |
-
cfg_scale_speaker: float,
|
| 327 |
-
rescale_k: float | None,
|
| 328 |
-
rescale_sigma: float | None,
|
| 329 |
-
use_context_kv_cache: bool,
|
| 330 |
-
speaker_kv_scale: float | None,
|
| 331 |
-
) -> str | None:
|
| 332 |
-
if model.device.type != "cuda" or not torch.cuda.is_available():
|
| 333 |
-
return "CUDA Graph requires a CUDA device"
|
| 334 |
-
if str(cfg_guidance_mode).strip().lower() != "independent":
|
| 335 |
-
return "only cfg_guidance_mode='independent' is currently graphed"
|
| 336 |
-
if cfg_scale_text <= 0 or cfg_scale_speaker <= 0:
|
| 337 |
-
return "graph path currently expects both text and speaker CFG scales to be > 0"
|
| 338 |
-
if rescale_k is not None or rescale_sigma is not None:
|
| 339 |
-
return "rescale_k/rescale_sigma path is not graph-enabled"
|
| 340 |
-
if not use_context_kv_cache:
|
| 341 |
-
return "context_kv_cache=False is not graph-enabled"
|
| 342 |
-
if speaker_kv_scale is not None:
|
| 343 |
-
return "speaker_kv_scale path is not graph-enabled"
|
| 344 |
-
return None
|
| 345 |
-
|
| 346 |
-
def _get_graph(self, model, signature: RFGraphSignature) -> IrodoriRFGraph:
|
| 347 |
-
graph = self._graphs.get(signature)
|
| 348 |
-
if graph is not None:
|
| 349 |
-
self._graphs.move_to_end(signature)
|
| 350 |
-
return graph
|
| 351 |
-
|
| 352 |
-
graph = IrodoriRFGraph(model, signature, num_warmup=self.num_warmup)
|
| 353 |
-
graph.capture()
|
| 354 |
-
self._graphs[signature] = graph
|
| 355 |
-
self._graphs.move_to_end(signature)
|
| 356 |
-
while len(self._graphs) > self.max_graphs:
|
| 357 |
-
self._graphs.popitem(last=False)
|
| 358 |
-
return graph
|
| 359 |
-
|
| 360 |
-
@torch.inference_mode()
|
| 361 |
-
def sample(
|
| 362 |
-
self,
|
| 363 |
-
*,
|
| 364 |
-
model,
|
| 365 |
-
text_input_ids: torch.Tensor,
|
| 366 |
-
text_mask: torch.Tensor,
|
| 367 |
-
ref_latent: torch.Tensor,
|
| 368 |
-
ref_mask: torch.Tensor,
|
| 369 |
-
sequence_length: int,
|
| 370 |
-
num_steps: int = 40,
|
| 371 |
-
cfg_scale_text: float = 3.0,
|
| 372 |
-
cfg_scale_speaker: float = 5.0,
|
| 373 |
-
cfg_guidance_mode: str = "independent",
|
| 374 |
-
cfg_min_t: float = 0.5,
|
| 375 |
-
cfg_max_t: float = 1.0,
|
| 376 |
-
seed: int = 0,
|
| 377 |
-
truncation_factor: float | None = None,
|
| 378 |
-
rescale_k: float | None = None,
|
| 379 |
-
rescale_sigma: float | None = None,
|
| 380 |
-
use_context_kv_cache: bool = True,
|
| 381 |
-
speaker_kv_scale: float | None = None,
|
| 382 |
-
speaker_kv_max_layers: int | None = None,
|
| 383 |
-
speaker_kv_min_t: float | None = None,
|
| 384 |
-
) -> RFGraphSampleResult:
|
| 385 |
-
def fallback(reason: str) -> RFGraphSampleResult:
|
| 386 |
-
return RFGraphSampleResult(
|
| 387 |
-
latent=sample_euler_rf_cfg(
|
| 388 |
-
model=model,
|
| 389 |
-
text_input_ids=text_input_ids,
|
| 390 |
-
text_mask=text_mask,
|
| 391 |
-
ref_latent=ref_latent,
|
| 392 |
-
ref_mask=ref_mask,
|
| 393 |
-
sequence_length=sequence_length,
|
| 394 |
-
num_steps=num_steps,
|
| 395 |
-
cfg_scale_text=cfg_scale_text,
|
| 396 |
-
cfg_scale_speaker=cfg_scale_speaker,
|
| 397 |
-
cfg_guidance_mode=cfg_guidance_mode,
|
| 398 |
-
cfg_min_t=cfg_min_t,
|
| 399 |
-
cfg_max_t=cfg_max_t,
|
| 400 |
-
seed=seed,
|
| 401 |
-
truncation_factor=truncation_factor,
|
| 402 |
-
rescale_k=rescale_k,
|
| 403 |
-
rescale_sigma=rescale_sigma,
|
| 404 |
-
use_context_kv_cache=use_context_kv_cache,
|
| 405 |
-
speaker_kv_scale=speaker_kv_scale,
|
| 406 |
-
speaker_kv_max_layers=speaker_kv_max_layers,
|
| 407 |
-
speaker_kv_min_t=speaker_kv_min_t,
|
| 408 |
-
),
|
| 409 |
-
graph_used=False,
|
| 410 |
-
fallback_reason=reason,
|
| 411 |
-
)
|
| 412 |
-
|
| 413 |
-
reason = self._unsupported_reason(
|
| 414 |
-
model=model,
|
| 415 |
-
cfg_guidance_mode=cfg_guidance_mode,
|
| 416 |
-
cfg_scale_text=float(cfg_scale_text),
|
| 417 |
-
cfg_scale_speaker=float(cfg_scale_speaker),
|
| 418 |
-
rescale_k=rescale_k,
|
| 419 |
-
rescale_sigma=rescale_sigma,
|
| 420 |
-
use_context_kv_cache=bool(use_context_kv_cache),
|
| 421 |
-
speaker_kv_scale=speaker_kv_scale,
|
| 422 |
-
)
|
| 423 |
-
if reason is not None:
|
| 424 |
-
return fallback(reason)
|
| 425 |
-
|
| 426 |
-
device = model.device
|
| 427 |
-
dtype = model.dtype
|
| 428 |
-
batch_size = int(text_input_ids.shape[0])
|
| 429 |
-
latent_dim = model.cfg.patched_latent_dim
|
| 430 |
-
|
| 431 |
-
ref_latent, ref_mask = _pad_reference_to_bucket(
|
| 432 |
-
ref_latent,
|
| 433 |
-
ref_mask,
|
| 434 |
-
speaker_patch_size=model.cfg.speaker_patch_size,
|
| 435 |
-
bucket_multiple=self.speaker_bucket_multiple,
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
rng = torch.Generator(device=device).manual_seed(int(seed))
|
| 439 |
-
x_t = torch.randn(
|
| 440 |
-
(batch_size, int(sequence_length), latent_dim),
|
| 441 |
-
device=device,
|
| 442 |
-
dtype=dtype,
|
| 443 |
-
generator=rng,
|
| 444 |
-
)
|
| 445 |
-
if truncation_factor is not None:
|
| 446 |
-
x_t = x_t * float(truncation_factor)
|
| 447 |
-
|
| 448 |
-
text_state_cond, text_mask_cond, speaker_state_cond, speaker_mask_cond = (
|
| 449 |
-
model.encode_conditions(
|
| 450 |
-
text_input_ids=text_input_ids,
|
| 451 |
-
text_mask=text_mask,
|
| 452 |
-
ref_latent=ref_latent,
|
| 453 |
-
ref_mask=ref_mask,
|
| 454 |
-
)
|
| 455 |
-
)
|
| 456 |
-
text_state_uncond = torch.zeros_like(text_state_cond)
|
| 457 |
-
text_mask_uncond = torch.zeros_like(text_mask_cond)
|
| 458 |
-
speaker_state_uncond = torch.zeros_like(speaker_state_cond)
|
| 459 |
-
speaker_mask_uncond = torch.zeros_like(speaker_mask_cond)
|
| 460 |
-
|
| 461 |
-
text_state_cfg = torch.cat([text_state_cond, text_state_uncond, text_state_cond], dim=0)
|
| 462 |
-
text_mask_cfg = torch.cat([text_mask_cond, text_mask_uncond, text_mask_cond], dim=0)
|
| 463 |
-
speaker_state_cfg = torch.cat(
|
| 464 |
-
[speaker_state_cond, speaker_state_cond, speaker_state_uncond], dim=0
|
| 465 |
-
)
|
| 466 |
-
speaker_mask_cfg = torch.cat(
|
| 467 |
-
[speaker_mask_cond, speaker_mask_cond, speaker_mask_uncond], dim=0
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
context_kv_cond = model.build_context_kv_cache(
|
| 471 |
-
text_state=text_state_cond,
|
| 472 |
-
speaker_state=speaker_state_cond,
|
| 473 |
-
)
|
| 474 |
-
context_kv_cfg = model.build_context_kv_cache(
|
| 475 |
-
text_state=text_state_cfg,
|
| 476 |
-
speaker_state=speaker_state_cfg,
|
| 477 |
-
)
|
| 478 |
-
|
| 479 |
-
signature = RFGraphSignature(
|
| 480 |
-
batch_size=batch_size,
|
| 481 |
-
sequence_length=int(sequence_length),
|
| 482 |
-
latent_dim=int(latent_dim),
|
| 483 |
-
text_len=int(text_state_cond.shape[1]),
|
| 484 |
-
speaker_len=int(speaker_state_cond.shape[1]),
|
| 485 |
-
num_steps=int(num_steps),
|
| 486 |
-
cfg_scale_text=float(cfg_scale_text),
|
| 487 |
-
cfg_scale_speaker=float(cfg_scale_speaker),
|
| 488 |
-
cfg_min_t=float(cfg_min_t),
|
| 489 |
-
cfg_max_t=float(cfg_max_t),
|
| 490 |
-
dtype=str(dtype),
|
| 491 |
-
device=_device_key(device),
|
| 492 |
-
)
|
| 493 |
-
try:
|
| 494 |
-
graph = self._get_graph(model, signature)
|
| 495 |
-
latent = graph.run(
|
| 496 |
-
x_t=x_t,
|
| 497 |
-
text_state_cond=text_state_cond,
|
| 498 |
-
text_mask_cond=text_mask_cond,
|
| 499 |
-
speaker_state_cond=speaker_state_cond,
|
| 500 |
-
speaker_mask_cond=speaker_mask_cond,
|
| 501 |
-
text_state_cfg=text_state_cfg,
|
| 502 |
-
text_mask_cfg=text_mask_cfg,
|
| 503 |
-
speaker_state_cfg=speaker_state_cfg,
|
| 504 |
-
speaker_mask_cfg=speaker_mask_cfg,
|
| 505 |
-
context_kv_cond=context_kv_cond,
|
| 506 |
-
context_kv_cfg=context_kv_cfg,
|
| 507 |
-
)
|
| 508 |
-
except Exception as exc:
|
| 509 |
-
self._graphs.pop(signature, None)
|
| 510 |
-
return fallback(f"CUDA Graph capture/replay failed: {exc}")
|
| 511 |
-
return RFGraphSampleResult(latent=latent, graph_used=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
faster_irodori_tts/runtime.py
DELETED
|
@@ -1,290 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
import secrets
|
| 5 |
-
from collections.abc import Callable
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
|
| 9 |
-
from irodori_tts.codec import unpatchify_latent
|
| 10 |
-
from irodori_tts.inference_runtime import (
|
| 11 |
-
InferenceRuntime,
|
| 12 |
-
RuntimeKey,
|
| 13 |
-
SamplingRequest,
|
| 14 |
-
SamplingResult,
|
| 15 |
-
_measure_end,
|
| 16 |
-
_measure_start,
|
| 17 |
-
find_flattening_point,
|
| 18 |
-
resolve_cfg_scales,
|
| 19 |
-
)
|
| 20 |
-
from irodori_tts.text_normalization import normalize_text
|
| 21 |
-
|
| 22 |
-
from .rf_graph import FasterIrodoriRFSampler
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class FasterIrodoriTTSRuntime(InferenceRuntime):
|
| 26 |
-
"""Irodori runtime that uses CUDA Graphs for supported RF sampling requests."""
|
| 27 |
-
|
| 28 |
-
def __init__(self, **kwargs) -> None:
|
| 29 |
-
super().__init__(**kwargs)
|
| 30 |
-
self.rf_sampler = FasterIrodoriRFSampler()
|
| 31 |
-
|
| 32 |
-
def synthesize(
|
| 33 |
-
self,
|
| 34 |
-
req: SamplingRequest,
|
| 35 |
-
*,
|
| 36 |
-
log_fn: Callable[[str], None] | None = None,
|
| 37 |
-
) -> SamplingResult:
|
| 38 |
-
def _log(msg: str) -> None:
|
| 39 |
-
if log_fn is not None:
|
| 40 |
-
log_fn(msg)
|
| 41 |
-
|
| 42 |
-
messages: list[str] = []
|
| 43 |
-
_log(
|
| 44 |
-
(
|
| 45 |
-
"[faster_runtime] start synthesize "
|
| 46 |
-
"model_device={} model_precision={} codec_device={} codec_precision={} "
|
| 47 |
-
"watermark={} mode={} seconds={} steps={} seed={} candidates={} decode_mode={}"
|
| 48 |
-
).format(
|
| 49 |
-
self.key.model_device,
|
| 50 |
-
self.key.model_precision,
|
| 51 |
-
self.key.codec_device,
|
| 52 |
-
self.key.codec_precision,
|
| 53 |
-
self.codec.enable_watermark,
|
| 54 |
-
req.cfg_guidance_mode,
|
| 55 |
-
req.seconds,
|
| 56 |
-
req.num_steps,
|
| 57 |
-
"random" if req.seed is None else int(req.seed),
|
| 58 |
-
req.num_candidates,
|
| 59 |
-
req.decode_mode,
|
| 60 |
-
)
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
if req.seconds <= 0:
|
| 64 |
-
raise ValueError(f"seconds must be > 0, got {req.seconds}")
|
| 65 |
-
num_candidates = int(req.num_candidates)
|
| 66 |
-
if num_candidates <= 0:
|
| 67 |
-
raise ValueError(f"num_candidates must be > 0, got {num_candidates}")
|
| 68 |
-
decode_mode = str(req.decode_mode).strip().lower()
|
| 69 |
-
if decode_mode not in {"sequential", "batch"}:
|
| 70 |
-
raise ValueError(
|
| 71 |
-
f"Unsupported decode_mode={req.decode_mode!r}. Expected one of: sequential, batch."
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
raw_text = str(req.text)
|
| 75 |
-
normalized_text = normalize_text(raw_text).strip()
|
| 76 |
-
if normalized_text == "":
|
| 77 |
-
raise ValueError("text became empty after normalization.")
|
| 78 |
-
|
| 79 |
-
text_max_len = (
|
| 80 |
-
self.default_text_max_len if req.max_text_len is None else int(req.max_text_len)
|
| 81 |
-
)
|
| 82 |
-
if text_max_len <= 0:
|
| 83 |
-
raise ValueError(f"max_text_len must be > 0, got {text_max_len}")
|
| 84 |
-
|
| 85 |
-
truncation_factor = None if req.truncation_factor is None else float(req.truncation_factor)
|
| 86 |
-
rescale_k = None if req.rescale_k is None else float(req.rescale_k)
|
| 87 |
-
rescale_sigma = None if req.rescale_sigma is None else float(req.rescale_sigma)
|
| 88 |
-
if truncation_factor is not None and truncation_factor <= 0:
|
| 89 |
-
raise ValueError(f"truncation_factor must be > 0, got {truncation_factor}")
|
| 90 |
-
if (rescale_k is None) != (rescale_sigma is None):
|
| 91 |
-
raise ValueError("rescale_k and rescale_sigma must be set together.")
|
| 92 |
-
if rescale_k is not None and rescale_k <= 0:
|
| 93 |
-
raise ValueError(f"rescale_k must be > 0, got {rescale_k}")
|
| 94 |
-
if rescale_sigma is not None and rescale_sigma <= 0:
|
| 95 |
-
raise ValueError(f"rescale_sigma must be > 0, got {rescale_sigma}")
|
| 96 |
-
|
| 97 |
-
speaker_kv_scale = None if req.speaker_kv_scale is None else float(req.speaker_kv_scale)
|
| 98 |
-
speaker_kv_min_t = None
|
| 99 |
-
speaker_kv_max_layers = (
|
| 100 |
-
None if req.speaker_kv_max_layers is None else int(req.speaker_kv_max_layers)
|
| 101 |
-
)
|
| 102 |
-
if speaker_kv_scale is not None:
|
| 103 |
-
if speaker_kv_scale <= 0:
|
| 104 |
-
raise ValueError(f"speaker_kv_scale must be > 0, got {speaker_kv_scale}")
|
| 105 |
-
speaker_kv_min_t = 0.9 if req.speaker_kv_min_t is None else float(req.speaker_kv_min_t)
|
| 106 |
-
if not (0.0 <= speaker_kv_min_t <= 1.0):
|
| 107 |
-
raise ValueError(f"speaker_kv_min_t must be in [0, 1], got {speaker_kv_min_t}")
|
| 108 |
-
if speaker_kv_max_layers is not None and speaker_kv_max_layers < 0:
|
| 109 |
-
raise ValueError(
|
| 110 |
-
f"speaker_kv_max_layers must be >= 0 when specified, got {speaker_kv_max_layers}"
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
cfg_mode = str(req.cfg_guidance_mode).strip().lower()
|
| 114 |
-
if cfg_mode not in {"independent", "joint", "alternating"}:
|
| 115 |
-
raise ValueError(
|
| 116 |
-
f"Unsupported cfg_guidance_mode={req.cfg_guidance_mode!r}. "
|
| 117 |
-
"Expected one of: independent, joint, alternating."
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
cfg_scale_text, cfg_scale_speaker, scale_messages = resolve_cfg_scales(
|
| 121 |
-
cfg_guidance_mode=cfg_mode,
|
| 122 |
-
cfg_scale_text=req.cfg_scale_text,
|
| 123 |
-
cfg_scale_speaker=req.cfg_scale_speaker,
|
| 124 |
-
cfg_scale=req.cfg_scale,
|
| 125 |
-
)
|
| 126 |
-
messages.extend(scale_messages)
|
| 127 |
-
for msg in scale_messages:
|
| 128 |
-
_log(msg)
|
| 129 |
-
|
| 130 |
-
stage_timings: list[tuple[str, float]] = []
|
| 131 |
-
if req.seed is None:
|
| 132 |
-
used_seed = int(secrets.randbits(63))
|
| 133 |
-
msg = f"info: seed not specified; using random seed {used_seed}."
|
| 134 |
-
messages.append(msg)
|
| 135 |
-
_log(msg)
|
| 136 |
-
else:
|
| 137 |
-
used_seed = int(req.seed)
|
| 138 |
-
_log(f"[faster_runtime] using seed: {used_seed}")
|
| 139 |
-
post_load_t0 = _measure_start(self.model_device, self.codec_device)
|
| 140 |
-
|
| 141 |
-
with self._infer_lock, torch.inference_mode():
|
| 142 |
-
t0 = _measure_start(self.model_device)
|
| 143 |
-
text_ids, text_mask = self.tokenizer.batch_encode(
|
| 144 |
-
[normalized_text] * num_candidates,
|
| 145 |
-
max_length=text_max_len,
|
| 146 |
-
)
|
| 147 |
-
stage_sec = _measure_end(self.model_device, t0)
|
| 148 |
-
stage_timings.append(("tokenize_text", stage_sec))
|
| 149 |
-
_log(f"[faster_runtime] tokenize_text: {stage_sec * 1000.0:.1f} ms")
|
| 150 |
-
text_ids = text_ids.to(self.model_device)
|
| 151 |
-
text_mask = text_mask.to(self.model_device)
|
| 152 |
-
|
| 153 |
-
target_samples = int(float(req.seconds) * self.codec.sample_rate)
|
| 154 |
-
latent_steps = math.ceil(target_samples / int(self.codec.model.hop_length))
|
| 155 |
-
patched_steps = math.ceil(latent_steps / self.model_cfg.latent_patch_size)
|
| 156 |
-
|
| 157 |
-
if isinstance(self.train_cfg, dict):
|
| 158 |
-
fixed_steps = self.train_cfg.get("fixed_target_latent_steps")
|
| 159 |
-
if isinstance(fixed_steps, int) and fixed_steps > 0 and latent_steps > fixed_steps:
|
| 160 |
-
msg = (
|
| 161 |
-
f"warning: requested latent length ({latent_steps}) exceeds fixed_target_latent_steps ({fixed_steps}) "
|
| 162 |
-
"used in training. Long-tail stability may degrade."
|
| 163 |
-
)
|
| 164 |
-
messages.append(msg)
|
| 165 |
-
_log(msg)
|
| 166 |
-
|
| 167 |
-
t0 = _measure_start(self.model_device, self.codec_device)
|
| 168 |
-
msg_count_before_ref = len(messages)
|
| 169 |
-
ref_latent, ref_mask = self._load_reference_latent(
|
| 170 |
-
req=req,
|
| 171 |
-
batch_size=num_candidates,
|
| 172 |
-
messages=messages,
|
| 173 |
-
)
|
| 174 |
-
stage_sec = _measure_end(self.model_device, t0, self.codec_device)
|
| 175 |
-
stage_timings.append(("prepare_reference", stage_sec))
|
| 176 |
-
for msg in messages[msg_count_before_ref:]:
|
| 177 |
-
_log(msg)
|
| 178 |
-
_log(f"[faster_runtime] prepare_reference: {stage_sec * 1000.0:.1f} ms")
|
| 179 |
-
|
| 180 |
-
t0 = _measure_start(self.model_device)
|
| 181 |
-
sample_result = self.rf_sampler.sample(
|
| 182 |
-
model=self.model,
|
| 183 |
-
text_input_ids=text_ids,
|
| 184 |
-
text_mask=text_mask,
|
| 185 |
-
ref_latent=ref_latent,
|
| 186 |
-
ref_mask=ref_mask,
|
| 187 |
-
sequence_length=patched_steps,
|
| 188 |
-
num_steps=int(req.num_steps),
|
| 189 |
-
cfg_scale_text=cfg_scale_text,
|
| 190 |
-
cfg_scale_speaker=cfg_scale_speaker,
|
| 191 |
-
cfg_guidance_mode=cfg_mode,
|
| 192 |
-
cfg_min_t=float(req.cfg_min_t),
|
| 193 |
-
cfg_max_t=float(req.cfg_max_t),
|
| 194 |
-
seed=used_seed,
|
| 195 |
-
truncation_factor=truncation_factor,
|
| 196 |
-
rescale_k=rescale_k,
|
| 197 |
-
rescale_sigma=rescale_sigma,
|
| 198 |
-
use_context_kv_cache=bool(req.context_kv_cache),
|
| 199 |
-
speaker_kv_scale=speaker_kv_scale,
|
| 200 |
-
speaker_kv_max_layers=speaker_kv_max_layers,
|
| 201 |
-
speaker_kv_min_t=speaker_kv_min_t,
|
| 202 |
-
)
|
| 203 |
-
z_patched = sample_result.latent
|
| 204 |
-
stage_sec = _measure_end(self.model_device, t0)
|
| 205 |
-
stage_timings.append(("sample_rf", stage_sec))
|
| 206 |
-
if sample_result.graph_used:
|
| 207 |
-
_log(f"[faster_runtime] sample_rf (cuda_graph): {stage_sec * 1000.0:.1f} ms")
|
| 208 |
-
else:
|
| 209 |
-
msg = f"info: RF CUDA Graph fallback: {sample_result.fallback_reason}"
|
| 210 |
-
messages.append(msg)
|
| 211 |
-
_log(msg)
|
| 212 |
-
_log(f"[faster_runtime] sample_rf (fallback): {stage_sec * 1000.0:.1f} ms")
|
| 213 |
-
|
| 214 |
-
t0 = _measure_start(self.model_device)
|
| 215 |
-
z = unpatchify_latent(
|
| 216 |
-
z_patched,
|
| 217 |
-
patch_size=self.model_cfg.latent_patch_size,
|
| 218 |
-
latent_dim=self.model_cfg.latent_dim,
|
| 219 |
-
)
|
| 220 |
-
stage_sec = _measure_end(self.model_device, t0)
|
| 221 |
-
stage_timings.append(("unpatchify_latent", stage_sec))
|
| 222 |
-
_log(f"[faster_runtime] unpatchify_latent: {stage_sec * 1000.0:.1f} ms")
|
| 223 |
-
z = z[:, :latent_steps]
|
| 224 |
-
|
| 225 |
-
t0 = _measure_start(self.model_device, self.codec_device)
|
| 226 |
-
trimmed_audios: list[torch.Tensor] = []
|
| 227 |
-
if decode_mode == "batch":
|
| 228 |
-
audio_batch = self.codec.decode_latent(z).cpu()
|
| 229 |
-
for i in range(num_candidates):
|
| 230 |
-
audio_i = audio_batch[i]
|
| 231 |
-
max_samples = target_samples
|
| 232 |
-
if bool(req.trim_tail):
|
| 233 |
-
flattening_point = find_flattening_point(
|
| 234 |
-
z[i],
|
| 235 |
-
window_size=max(1, int(req.tail_window_size)),
|
| 236 |
-
std_threshold=float(req.tail_std_threshold),
|
| 237 |
-
mean_threshold=float(req.tail_mean_threshold),
|
| 238 |
-
)
|
| 239 |
-
flattening_samples = int(
|
| 240 |
-
flattening_point * int(self.codec.model.hop_length)
|
| 241 |
-
)
|
| 242 |
-
if flattening_samples > 0:
|
| 243 |
-
max_samples = min(max_samples, flattening_samples)
|
| 244 |
-
trimmed_audios.append(audio_i[:, :max_samples])
|
| 245 |
-
else:
|
| 246 |
-
for i in range(num_candidates):
|
| 247 |
-
audio_i = self.codec.decode_latent(z[i : i + 1]).cpu()[0]
|
| 248 |
-
max_samples = target_samples
|
| 249 |
-
if bool(req.trim_tail):
|
| 250 |
-
flattening_point = find_flattening_point(
|
| 251 |
-
z[i],
|
| 252 |
-
window_size=max(1, int(req.tail_window_size)),
|
| 253 |
-
std_threshold=float(req.tail_std_threshold),
|
| 254 |
-
mean_threshold=float(req.tail_mean_threshold),
|
| 255 |
-
)
|
| 256 |
-
flattening_samples = int(
|
| 257 |
-
flattening_point * int(self.codec.model.hop_length)
|
| 258 |
-
)
|
| 259 |
-
if flattening_samples > 0:
|
| 260 |
-
max_samples = min(max_samples, flattening_samples)
|
| 261 |
-
trimmed_audios.append(audio_i[:, :max_samples])
|
| 262 |
-
stage_sec = _measure_end(self.model_device, t0, self.codec_device)
|
| 263 |
-
stage_timings.append(("decode_latent", stage_sec))
|
| 264 |
-
_log(f"[faster_runtime] decode_latent ({decode_mode}): {stage_sec * 1000.0:.1f} ms")
|
| 265 |
-
|
| 266 |
-
total_to_decode = _measure_end(self.model_device, post_load_t0, self.codec_device)
|
| 267 |
-
_log(f"[faster_runtime] total_to_decode: {total_to_decode:.3f} s")
|
| 268 |
-
|
| 269 |
-
_log("[faster_runtime] done synthesize")
|
| 270 |
-
return SamplingResult(
|
| 271 |
-
audio=trimmed_audios[0],
|
| 272 |
-
audios=trimmed_audios,
|
| 273 |
-
sample_rate=int(self.codec.sample_rate),
|
| 274 |
-
stage_timings=stage_timings,
|
| 275 |
-
total_to_decode=total_to_decode,
|
| 276 |
-
used_seed=used_seed,
|
| 277 |
-
messages=messages,
|
| 278 |
-
)
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
# Backward-friendly alias for callers that prefer an InferenceRuntime-like name.
|
| 282 |
-
FasterInferenceRuntime = FasterIrodoriTTSRuntime
|
| 283 |
-
|
| 284 |
-
__all__ = [
|
| 285 |
-
"FasterIrodoriTTSRuntime",
|
| 286 |
-
"FasterInferenceRuntime",
|
| 287 |
-
"RuntimeKey",
|
| 288 |
-
"SamplingRequest",
|
| 289 |
-
"SamplingResult",
|
| 290 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|