dinhthuan commited on
Commit
befd0e1
·
verified ·
1 Parent(s): 5bf8c9e

Upload vizipvoice.py

Browse files
Files changed (1) hide show
  1. vizipvoice.py +556 -0
vizipvoice.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ import tempfile
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Optional, Union
8
+
9
+ import safetensors.torch
10
+ import torch
11
+ import torchaudio
12
+ from huggingface_hub import hf_hub_download, list_repo_files
13
+ from lhotse.utils import fix_random_seed
14
+
15
+ from zipvoice.bin.infer_zipvoice import generate_sentence, get_vocoder
16
+ from zipvoice.models.zipvoice import ZipVoice
17
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
18
+ from zipvoice.utils.checkpoint import load_checkpoint
19
+ from zipvoice.utils.feature import VocosFbank
20
+
21
+ DEFAULT_REPO_ID = "contextboxai/ViZipvoice"
22
+ DEFAULT_CHECKPOINT_NAME = "latest"
23
+ CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)\.pt$")
24
+ SENTENCE_SPLIT_PATTERN = re.compile(r"[^.??。]+(?:[.??。]+|$)", re.S)
25
+ PUNCTUATION_NO_SPACE_BEFORE = r",.;:!?…%"
26
+ OPENING_QUOTES_AND_BRACKETS = r"\(\[\{«“‘"
27
+ CLOSING_QUOTES_AND_BRACKETS = r"\)\]\}»”’"
28
+
29
+
30
+ def _resolve_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
31
+ if device is not None:
32
+ return torch.device(device)
33
+ if torch.cuda.is_available():
34
+ return torch.device("cuda", 0)
35
+ if torch.backends.mps.is_available():
36
+ return torch.device("mps")
37
+ return torch.device("cpu")
38
+
39
+
40
+ def _download_model_files(
41
+ repo_id: str,
42
+ revision: Optional[str],
43
+ checkpoint_name: str,
44
+ ) -> tuple[Path, Path, Path]:
45
+ checkpoint_name = _resolve_hf_checkpoint_name(
46
+ repo_id=repo_id,
47
+ revision=revision,
48
+ checkpoint_name=checkpoint_name,
49
+ )
50
+ checkpoint_path = Path(
51
+ hf_hub_download(
52
+ repo_id=repo_id,
53
+ filename=checkpoint_name,
54
+ revision=revision,
55
+ )
56
+ )
57
+ model_config_path = _download_config_file(
58
+ repo_id=repo_id,
59
+ revision=revision,
60
+ )
61
+ token_file = Path(
62
+ hf_hub_download(
63
+ repo_id=repo_id,
64
+ filename="tokens.txt",
65
+ revision=revision,
66
+ )
67
+ )
68
+ return checkpoint_path, model_config_path, token_file
69
+
70
+
71
+ def _download_config_file(repo_id: str, revision: Optional[str]) -> Path:
72
+ last_error = None
73
+ for filename in ("config.json", "model.json"):
74
+ try:
75
+ return Path(
76
+ hf_hub_download(
77
+ repo_id=repo_id,
78
+ filename=filename,
79
+ revision=revision,
80
+ )
81
+ )
82
+ except Exception as exc:
83
+ last_error = exc
84
+ raise FileNotFoundError("No config.json or model.json file found.") from last_error
85
+
86
+
87
+ def _checkpoint_step(filename: str) -> int:
88
+ match = CHECKPOINT_RE.match(Path(filename).name)
89
+ return int(match.group(1)) if match else -1
90
+
91
+
92
+ def _select_latest_checkpoint(filenames: list[str]) -> str:
93
+ checkpoints = [
94
+ filename for filename in filenames if _checkpoint_step(filename) >= 0
95
+ ]
96
+ if checkpoints:
97
+ return max(checkpoints, key=lambda filename: _checkpoint_step(filename))
98
+ raise FileNotFoundError("No checkpoint-<step>.pt file found.")
99
+
100
+
101
+ def _resolve_hf_checkpoint_name(
102
+ repo_id: str,
103
+ revision: Optional[str],
104
+ checkpoint_name: str,
105
+ ) -> str:
106
+ if checkpoint_name != "latest":
107
+ return checkpoint_name
108
+ filenames = list_repo_files(repo_id=repo_id, revision=revision)
109
+ return _select_latest_checkpoint(filenames)
110
+
111
+
112
+ def _resolve_local_checkpoint_path(
113
+ model_dir: Path,
114
+ checkpoint_name: str,
115
+ ) -> Path:
116
+ if checkpoint_name != "latest":
117
+ return model_dir / checkpoint_name
118
+ filenames = [path.name for path in model_dir.iterdir() if path.is_file()]
119
+ return model_dir / _select_latest_checkpoint(filenames)
120
+
121
+
122
+ def _resolve_local_config_path(model_dir: Path) -> Path:
123
+ for filename in ("config.json", "model.json"):
124
+ config_path = model_dir / filename
125
+ if config_path.is_file():
126
+ return config_path
127
+ raise FileNotFoundError(f"No config.json or model.json file found in {model_dir}")
128
+
129
+
130
+ def cleanup_vietnamese_spacing(text: str) -> str:
131
+ text = re.sub(r"\s+", " ", text.strip())
132
+ text = re.sub(
133
+ rf"\s+([{re.escape(PUNCTUATION_NO_SPACE_BEFORE)}])",
134
+ r"\1",
135
+ text,
136
+ )
137
+ text = re.sub(
138
+ rf"\s+([{CLOSING_QUOTES_AND_BRACKETS}])",
139
+ r"\1",
140
+ text,
141
+ )
142
+ text = re.sub(
143
+ rf"([{OPENING_QUOTES_AND_BRACKETS}])\s+",
144
+ r"\1",
145
+ text,
146
+ )
147
+ text = re.sub(
148
+ rf"([{re.escape(PUNCTUATION_NO_SPACE_BEFORE)}])"
149
+ rf"([^\s{CLOSING_QUOTES_AND_BRACKETS}])",
150
+ r"\1 \2",
151
+ text,
152
+ )
153
+ return text.strip()
154
+
155
+
156
+ def normalize_vietnamese_text(text: str, enabled: bool = True) -> str:
157
+ if not enabled:
158
+ return text.strip()
159
+
160
+ try:
161
+ from soe_vinorm import normalize_text
162
+ except ImportError as exc:
163
+ raise RuntimeError(
164
+ "Vietnamese normalization requires soe-vinorm. "
165
+ "Install it with `pip install soe-vinorm`."
166
+ ) from exc
167
+
168
+ return cleanup_vietnamese_spacing(normalize_text(text))
169
+
170
+
171
+ def split_text_into_sentences(text: str) -> list[str]:
172
+ text = text.strip()
173
+ if not text:
174
+ return []
175
+
176
+ sentences = [
177
+ match.group(0).strip()
178
+ for match in SENTENCE_SPLIT_PATTERN.finditer(text)
179
+ if match.group(0).strip()
180
+ ]
181
+ return sentences or [text]
182
+
183
+
184
+ def count_sentence_words(sentence: str) -> int:
185
+ return len(re.findall(r"\w+", sentence, flags=re.UNICODE))
186
+
187
+
188
+ def get_sentence_inference_params(
189
+ sentence: str,
190
+ base_num_step: int,
191
+ base_speed: float,
192
+ ) -> tuple[int, float, int]:
193
+ word_count = count_sentence_words(sentence)
194
+ if word_count == 1:
195
+ return max(base_num_step, 24), 0.6, word_count
196
+ if 2 <= word_count <= 4:
197
+ return base_num_step, 0.8, word_count
198
+ return base_num_step, base_speed, word_count
199
+
200
+
201
+ def match_audio_channels(
202
+ first: torch.Tensor,
203
+ second: torch.Tensor,
204
+ ) -> tuple[torch.Tensor, torch.Tensor]:
205
+ if first.shape[0] == second.shape[0]:
206
+ return first, second
207
+ if first.shape[0] == 1:
208
+ return first.expand(second.shape[0], -1), second
209
+ if second.shape[0] == 1:
210
+ return first, second.expand(first.shape[0], -1)
211
+
212
+ channels = min(first.shape[0], second.shape[0])
213
+ return first[:channels], second[:channels]
214
+
215
+
216
+ def append_with_crossfade(
217
+ first: torch.Tensor,
218
+ second: torch.Tensor,
219
+ crossfade_samples: int,
220
+ ) -> torch.Tensor:
221
+ first, second = match_audio_channels(first, second)
222
+ fade_len = min(crossfade_samples, first.shape[1], second.shape[1])
223
+ if fade_len <= 0:
224
+ return torch.cat([first, second], dim=1)
225
+
226
+ fade_out = torch.linspace(
227
+ 1.0,
228
+ 0.0,
229
+ fade_len,
230
+ dtype=first.dtype,
231
+ device=first.device,
232
+ ).unsqueeze(0)
233
+ fade_in = torch.linspace(
234
+ 0.0,
235
+ 1.0,
236
+ fade_len,
237
+ dtype=second.dtype,
238
+ device=second.device,
239
+ ).unsqueeze(0)
240
+ overlap = first[:, -fade_len:] * fade_out + second[:, :fade_len] * fade_in
241
+ return torch.cat([first[:, :-fade_len], overlap, second[:, fade_len:]], dim=1)
242
+
243
+
244
+ def apply_fade(audio: torch.Tensor, fade_in_samples: int, fade_out_samples: int) -> torch.Tensor:
245
+ if audio.numel() == 0:
246
+ return audio
247
+
248
+ audio = audio.clone()
249
+ if fade_in_samples > 0:
250
+ fade_len = min(fade_in_samples, audio.shape[1])
251
+ fade = torch.linspace(
252
+ 0.0,
253
+ 1.0,
254
+ fade_len,
255
+ dtype=audio.dtype,
256
+ device=audio.device,
257
+ ).unsqueeze(0)
258
+ audio[:, :fade_len] *= fade
259
+
260
+ if fade_out_samples > 0:
261
+ fade_len = min(fade_out_samples, audio.shape[1])
262
+ fade = torch.linspace(
263
+ 1.0,
264
+ 0.0,
265
+ fade_len,
266
+ dtype=audio.dtype,
267
+ device=audio.device,
268
+ ).unsqueeze(0)
269
+ audio[:, -fade_len:] *= fade
270
+
271
+ return audio
272
+
273
+
274
+ def postprocess_audio_segments(
275
+ segment_paths: list[Path],
276
+ output_path: Path,
277
+ sampling_rate: int,
278
+ crossfade_ms: int,
279
+ silence_ms: int,
280
+ fade_in_ms: int,
281
+ fade_out_ms: int,
282
+ ) -> None:
283
+ if not segment_paths:
284
+ raise RuntimeError("No generated audio segments to postprocess.")
285
+
286
+ crossfade_samples = int(sampling_rate * max(crossfade_ms, 0) / 1000)
287
+ silence_samples = int(sampling_rate * max(silence_ms, 0) / 1000)
288
+ fade_in_samples = int(sampling_rate * max(fade_in_ms, 0) / 1000)
289
+ fade_out_samples = int(sampling_rate * max(fade_out_ms, 0) / 1000)
290
+
291
+ combined = None
292
+ for index, segment_path in enumerate(segment_paths):
293
+ audio, sr = torchaudio.load(str(segment_path))
294
+ if sr != sampling_rate:
295
+ audio = torchaudio.functional.resample(audio, sr, sampling_rate)
296
+
297
+ if index < len(segment_paths) - 1 and silence_samples > 0:
298
+ silence = torch.zeros(
299
+ audio.shape[0],
300
+ silence_samples,
301
+ dtype=audio.dtype,
302
+ device=audio.device,
303
+ )
304
+ audio = torch.cat([audio, silence], dim=1)
305
+
306
+ if combined is None:
307
+ combined = audio
308
+ else:
309
+ combined = append_with_crossfade(combined, audio, crossfade_samples)
310
+
311
+ combined = apply_fade(
312
+ combined,
313
+ fade_in_samples=fade_in_samples,
314
+ fade_out_samples=fade_out_samples,
315
+ )
316
+ combined = combined.clamp(min=-1.0, max=1.0).cpu()
317
+ torchaudio.save(str(output_path), combined, sampling_rate)
318
+
319
+
320
+ def wav_seconds(path: Union[str, Path]) -> float:
321
+ try:
322
+ import soundfile as sf
323
+
324
+ info = sf.info(str(path))
325
+ return float(info.frames) / float(info.samplerate)
326
+ except Exception:
327
+ audio, sr = torchaudio.load(str(path))
328
+ return float(audio.shape[-1]) / float(sr)
329
+
330
+
331
+ class ViZipVoiceTTS:
332
+ """Small wrapper for Vietnamese ZipVoice inference.
333
+
334
+ The wrapper downloads model files from Hugging Face by default, builds the
335
+ ZipVoice model with the Vietnamese character tokenizer, and exposes a
336
+ single synthesize method.
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ repo_id: str = DEFAULT_REPO_ID,
342
+ revision: Optional[str] = None,
343
+ model_dir: Optional[Union[str, Path]] = None,
344
+ checkpoint_name: str = DEFAULT_CHECKPOINT_NAME,
345
+ vocoder_path: Optional[Union[str, Path]] = None,
346
+ device: Optional[Union[str, torch.device]] = None,
347
+ use_fp16: bool = True,
348
+ num_threads: int = 1,
349
+ ) -> None:
350
+ try:
351
+ torch.set_num_threads(num_threads)
352
+ torch.set_num_interop_threads(num_threads)
353
+ except RuntimeError:
354
+ logging.debug("PyTorch thread settings were already initialized.")
355
+
356
+ self.repo_id = repo_id
357
+ self.revision = revision
358
+ self.device = _resolve_device(device)
359
+ self.use_fp16 = bool(use_fp16 and self.device.type == "cuda")
360
+
361
+ if model_dir is None:
362
+ checkpoint_path, model_config_path, token_file = _download_model_files(
363
+ repo_id=repo_id,
364
+ revision=revision,
365
+ checkpoint_name=checkpoint_name,
366
+ )
367
+ else:
368
+ model_dir = Path(model_dir)
369
+ checkpoint_path = _resolve_local_checkpoint_path(
370
+ model_dir=model_dir,
371
+ checkpoint_name=checkpoint_name,
372
+ )
373
+ model_config_path = _resolve_local_config_path(model_dir)
374
+ token_file = model_dir / "tokens.txt"
375
+
376
+ self.checkpoint_path = Path(checkpoint_path)
377
+ self.model_config_path = Path(model_config_path)
378
+ self.token_file = Path(token_file)
379
+ self._validate_model_files()
380
+
381
+ with self.model_config_path.open("r", encoding="utf-8") as f:
382
+ self.model_config = json.load(f)
383
+
384
+ self.tokenizer = SimpleTokenizer(token_file=str(self.token_file))
385
+ self.model = ZipVoice(
386
+ **self.model_config["model"],
387
+ vocab_size=self.tokenizer.vocab_size,
388
+ pad_id=self.tokenizer.pad_id,
389
+ )
390
+ self._load_checkpoint()
391
+ self.model.to(self.device)
392
+ self.model.eval()
393
+
394
+ self.feature_extractor = VocosFbank()
395
+ self.vocoder = get_vocoder(str(vocoder_path) if vocoder_path else None)
396
+ self.vocoder.to(self.device)
397
+ self.vocoder.eval()
398
+ self.sampling_rate = int(self.model_config["feature"]["sampling_rate"])
399
+
400
+ logging.info(
401
+ "Loaded ViZipVoice from %s on %s | fp16 autocast: %s",
402
+ self.checkpoint_path,
403
+ self.device,
404
+ self.use_fp16,
405
+ )
406
+
407
+ def _validate_model_files(self) -> None:
408
+ missing = [
409
+ path
410
+ for path in [self.checkpoint_path, self.model_config_path, self.token_file]
411
+ if not path.is_file()
412
+ ]
413
+ if missing:
414
+ missing_text = ", ".join(str(path) for path in missing)
415
+ raise FileNotFoundError(f"Missing ViZipVoice model file(s): {missing_text}")
416
+
417
+ def _load_checkpoint(self) -> None:
418
+ suffix = self.checkpoint_path.suffix.lower()
419
+ if suffix == ".safetensors":
420
+ safetensors.torch.load_model(self.model, str(self.checkpoint_path))
421
+ elif suffix == ".pt":
422
+ load_checkpoint(
423
+ filename=self.checkpoint_path,
424
+ model=self.model,
425
+ strict=True,
426
+ )
427
+ else:
428
+ raise ValueError(f"Unsupported checkpoint format: {self.checkpoint_path}")
429
+
430
+ @torch.inference_mode()
431
+ def synthesize(
432
+ self,
433
+ prompt_wav: Union[str, Path],
434
+ prompt_text: str,
435
+ text: str,
436
+ output_path: Union[str, Path] = "output.wav",
437
+ num_step: int = 16,
438
+ guidance_scale: float = 1.0,
439
+ speed: float = 1.0,
440
+ t_shift: float = 0.5,
441
+ target_rms: float = 0.1,
442
+ feat_scale: float = 0.1,
443
+ max_duration: float = 100,
444
+ remove_long_sil: bool = False,
445
+ seed: Optional[int] = 666,
446
+ normalize_vietnamese: bool = True,
447
+ split_sentences: bool = True,
448
+ crossfade_ms: int = 80,
449
+ silence_ms: int = 180,
450
+ fade_in_ms: int = 20,
451
+ fade_out_ms: int = 80,
452
+ ) -> dict:
453
+ if seed is not None and seed >= 0:
454
+ fix_random_seed(int(seed))
455
+
456
+ prompt_text = normalize_vietnamese_text(
457
+ prompt_text,
458
+ enabled=normalize_vietnamese,
459
+ )
460
+ text = normalize_vietnamese_text(
461
+ text,
462
+ enabled=normalize_vietnamese,
463
+ )
464
+ target_sentences = split_text_into_sentences(text) if split_sentences else [text]
465
+ if not target_sentences:
466
+ raise ValueError("No valid text to synthesize.")
467
+
468
+ output_path = Path(output_path)
469
+ output_path.parent.mkdir(parents=True, exist_ok=True)
470
+
471
+ segment_paths = []
472
+ segment_metrics = []
473
+ segment_settings = []
474
+ start_time = time.time()
475
+ with tempfile.TemporaryDirectory(
476
+ prefix=f"{output_path.stem}_segments_",
477
+ dir=str(output_path.parent),
478
+ ) as segment_dir_name:
479
+ segment_dir = Path(segment_dir_name)
480
+ with torch.autocast(
481
+ device_type="cuda",
482
+ dtype=torch.float16,
483
+ enabled=self.use_fp16,
484
+ ):
485
+ for index, sentence in enumerate(target_sentences, start=1):
486
+ sentence_num_step, sentence_speed, word_count = (
487
+ get_sentence_inference_params(
488
+ sentence=sentence,
489
+ base_num_step=int(num_step),
490
+ base_speed=float(speed),
491
+ )
492
+ )
493
+ segment_path = segment_dir / f"segment_{index:03d}.wav"
494
+ metrics = generate_sentence(
495
+ save_path=str(segment_path),
496
+ prompt_text=prompt_text,
497
+ prompt_wav=str(prompt_wav),
498
+ text=sentence,
499
+ model=self.model,
500
+ vocoder=self.vocoder,
501
+ tokenizer=self.tokenizer,
502
+ feature_extractor=self.feature_extractor,
503
+ device=self.device,
504
+ num_step=sentence_num_step,
505
+ guidance_scale=float(guidance_scale),
506
+ speed=sentence_speed,
507
+ t_shift=float(t_shift),
508
+ target_rms=float(target_rms),
509
+ feat_scale=float(feat_scale),
510
+ sampling_rate=self.sampling_rate,
511
+ max_duration=float(max_duration),
512
+ remove_long_sil=bool(remove_long_sil),
513
+ )
514
+ segment_paths.append(segment_path)
515
+ segment_metrics.append(metrics)
516
+ segment_settings.append(
517
+ {
518
+ "index": index,
519
+ "word_count": word_count,
520
+ "speed": sentence_speed,
521
+ "num_step": sentence_num_step,
522
+ "text": sentence,
523
+ }
524
+ )
525
+
526
+ postprocess_audio_segments(
527
+ segment_paths=segment_paths,
528
+ output_path=output_path,
529
+ sampling_rate=self.sampling_rate,
530
+ crossfade_ms=int(crossfade_ms),
531
+ silence_ms=int(silence_ms),
532
+ fade_in_ms=int(fade_in_ms),
533
+ fade_out_ms=int(fade_out_ms),
534
+ )
535
+
536
+ elapsed = time.time() - start_time
537
+ audio_seconds = wav_seconds(output_path)
538
+ t_no_vocoder = sum(item.get("t_no_vocoder", 0.0) for item in segment_metrics)
539
+ t_vocoder = sum(item.get("t_vocoder", 0.0) for item in segment_metrics)
540
+ rtf = elapsed / audio_seconds if audio_seconds else 0.0
541
+ return {
542
+ "t": elapsed,
543
+ "t_no_vocoder": t_no_vocoder,
544
+ "t_vocoder": t_vocoder,
545
+ "wav_seconds": audio_seconds,
546
+ "rtf": rtf,
547
+ "rtf_no_vocoder": t_no_vocoder / audio_seconds if audio_seconds else 0.0,
548
+ "rtf_vocoder": t_vocoder / audio_seconds if audio_seconds else 0.0,
549
+ "segments": len(segment_paths),
550
+ "segment_settings": segment_settings,
551
+ "segment_metrics": segment_metrics,
552
+ "crossfade_ms": int(crossfade_ms),
553
+ "silence_ms": int(silence_ms),
554
+ "fade_in_ms": int(fade_in_ms),
555
+ "fade_out_ms": int(fade_out_ms),
556
+ }