dinhthuan commited on
Commit
a130ca2
·
1 Parent(s): e7666b9

fix: remove click artifacts using VAD and root cause token trimming

Browse files
Files changed (1) hide show
  1. viterbox/tts.py +168 -14
viterbox/tts.py CHANGED
@@ -36,6 +36,32 @@ REPO_ID = "dolly-vn/viterbox"
36
  WAVS_DIR = Path("wavs")
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def get_random_voice() -> Optional[Path]:
40
  """Get a random voice file from wavs folder"""
41
  if WAVS_DIR.exists():
@@ -46,6 +72,46 @@ def get_random_voice() -> Optional[Path]:
46
  return None
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def normalize_text(text: str, language: str = "vi") -> str:
50
  """Normalize Vietnamese text (numbers, abbreviations, etc.)"""
51
  if language == "vi" and HAS_VINORM and _normalizer is not None:
@@ -82,11 +148,89 @@ def _split_text_to_sentences(text: str) -> List[str]:
82
 
83
 
84
  def trim_silence(audio: np.ndarray, sr: int, top_db: int = 30) -> np.ndarray:
85
- """Trim silence from audio."""
86
  trimmed, _ = librosa.effects.trim(audio, top_db=top_db)
87
  return trimmed
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def apply_fade_out(audio: np.ndarray, sr: int, fade_duration: float = 0.01) -> np.ndarray:
91
  """
92
  Apply smooth fade-out to prevent click artifacts at the end of audio.
@@ -431,7 +575,9 @@ class Viterbox:
431
  top_p: float,
432
  repetition_penalty: float,
433
  ) -> np.ndarray:
434
- """Generate speech for a single sentence."""
 
 
435
  # Tokenize text with language prefix
436
  text_tokens = self.tokenizer.text_to_tokens(text, language_id=language).to(self.device)
437
 
@@ -448,12 +594,12 @@ class Viterbox:
448
  use_autocast = self.device in ['cuda', 'mps']
449
  device_type = 'cuda' if self.device == 'cuda' else 'mps'
450
 
451
- with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=(self.device==use_autocast)):
452
  # Generate speech tokens with T3
453
  speech_tokens = self.t3.inference(
454
  t3_cond=self.conds.t3,
455
  text_tokens=text_tokens,
456
- max_new_tokens=600,
457
  temperature=temperature,
458
  cfg_weight=cfg_weight,
459
  repetition_penalty=repetition_penalty,
@@ -463,14 +609,20 @@ class Viterbox:
463
  # Extract only the conditional batch and filter invalid tokens
464
  speech_tokens = speech_tokens[0]
465
  speech_tokens = drop_invalid_tokens(speech_tokens)
 
 
 
 
 
 
466
  speech_tokens = speech_tokens.to(self.device)
467
 
468
- # Generate waveform with S3Gen
469
- wav, _ = self.s3gen.inference(
470
- speech_tokens=speech_tokens,
471
- ref_dict=self.conds.s3,
472
- )
473
-
474
  return wav[0].cpu().numpy()
475
 
476
  def generate(
@@ -481,8 +633,9 @@ class Viterbox:
481
  exaggeration: float = 0.5,
482
  cfg_weight: float = 0.5,
483
  temperature: float = 0.8,
484
- top_p: float = 0.9,
485
- repetition_penalty: float = 1.2,
 
486
  split_sentences: bool = True,
487
  crossfade_ms: int = 50,
488
  sentence_pause_ms: int = 500,
@@ -544,8 +697,9 @@ class Viterbox:
544
  repetition_penalty=repetition_penalty,
545
  )
546
 
547
- # Trim silence from each segment (use less aggressive threshold)
548
- audio_np = trim_silence(audio_np, self.sr, top_db=20)
 
549
 
550
  # Apply fade-out to prevent click at end of each segment
551
  audio_np = apply_fade_out(audio_np, self.sr, fade_duration=0.01) # 10ms fade-out
 
36
  WAVS_DIR = Path("wavs")
37
 
38
 
39
# Global VAD model (lazy-loaded singleton)
_VAD_MODEL = None
_VAD_UTILS = None


def get_vad_model():
    """Load Silero VAD model (singleton).

    Returns a ``(model, utils)`` pair on success, or ``(None, None)``
    when the torch-hub download/load fails. The loaded model is cached
    in module globals so subsequent calls are free.
    """
    global _VAD_MODEL, _VAD_UTILS
    if _VAD_MODEL is not None:
        return _VAD_MODEL, _VAD_UTILS
    try:
        # Load from torch hub - will be cached on disk by torch
        _VAD_MODEL, _VAD_UTILS = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            trust_repo=True,
            verbose=False,
        )
    except Exception as e:
        print(f"⚠️ Could not load Silero VAD: {e}")
        return None, None
    return _VAD_MODEL, _VAD_UTILS
63
+
64
+
65
  def get_random_voice() -> Optional[Path]:
66
  """Get a random voice file from wavs folder"""
67
  if WAVS_DIR.exists():
 
72
  return None
73
 
74
 
75
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    - Capitalises the first letter.
    - Collapses whitespace runs to single spaces.
    - Maps uncommon/LLM punctuation to common ASCII equivalents.
    - Ensures the text ends with sentence-final punctuation
      (important for stable T3 stopping behavior).
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."

    # Capitalise first letter (text is guaranteed non-empty here)
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Remove multiple space chars
    text = " ".join(text.split())

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("...", ", "),
        ("…", ", "),
        (":", ","),
        (" - ", ", "),
        (";", ", "),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        # FIX: map typographic (curly) quotes to ASCII quotes. The previous
        # pairs replaced straight quotes with themselves — a no-op, most
        # likely a copy/normalization artifact that dropped the curly chars.
        ("\u201c", '"'),  # left double quotation mark
        ("\u201d", '"'),  # right double quotation mark
        ("\u2018", "'"),  # left single quotation mark
        ("\u2019", "'"),  # right single quotation mark
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."

    return text
113
+
114
+
115
  def normalize_text(text: str, language: str = "vi") -> str:
116
  """Normalize Vietnamese text (numbers, abbreviations, etc.)"""
117
  if language == "vi" and HAS_VINORM and _normalizer is not None:
 
148
 
149
 
150
def trim_silence(audio: np.ndarray, sr: int, top_db: int = 30) -> np.ndarray:
    """Legacy trim silence (energy based).

    NOTE: ``sr`` is accepted for interface symmetry with vad_trim but is
    not used by librosa.effects.trim.
    """
    return librosa.effects.trim(audio, top_db=top_db)[0]
154
 
155
 
156
def vad_trim(audio: np.ndarray, sr: int, margin_s: float = 0.01) -> np.ndarray:
    """
    Trim trailing non-speech from audio using Silero VAD.

    Only the tail is cut: everything after the end of the last detected
    speech chunk (plus a small margin) is dropped; the start of the audio
    is left untouched. Falls back to the energy-based trim_silence() when
    the VAD model is unavailable, no speech is detected, or detection fails.

    Args:
        audio: Audio array (numpy)
        sr: Sample rate
        margin_s: Margin to keep after speech ends (seconds)
    """
    if len(audio) == 0:
        return audio

    model, utils = get_vad_model()
    if model is None:
        return trim_silence(audio, sr, top_db=20)

    # First element of the Silero utils tuple is get_speech_timestamps.
    get_speech_timestamps = utils[0]

    try:
        # Silero VAD works best at 16 kHz; resample a detection-only copy
        # when the audio is at a different rate (the returned audio is
        # always sliced from the original, never the resampled copy).
        vad_sr = 16000
        if sr != vad_sr:
            wav_16k = librosa.resample(audio, orig_sr=sr, target_sr=vad_sr)
            wav_tensor = torch.tensor(wav_16k, dtype=torch.float32)
        else:
            # VAD input must be float32
            wav_tensor = torch.tensor(audio, dtype=torch.float32)

        timestamps = get_speech_timestamps(
            wav_tensor,
            model,
            sampling_rate=vad_sr,
            threshold=0.35,  # Relax threshold as we fixed the root cause
            min_speech_duration_ms=250,
            min_silence_duration_ms=100
        )

        if not timestamps:
            # VAD sometimes misses breathy endings — fall back to a mild
            # energy-based trim rather than returning an empty clip.
            return trim_silence(audio, sr, top_db=25)

        # End of the last speech chunk (in 16 kHz samples), converted back
        # to the original sample rate, plus the requested margin.
        last_end_sample_16k = timestamps[-1]['end']
        last_end_sample = int(last_end_sample_16k * (sr / vad_sr))
        cut_point = min(last_end_sample + int(margin_s * sr), len(audio))

        return audio[:cut_point]

    except Exception as e:
        print(f"⚠️ VAD Error: {e}")
        return trim_silence(audio, sr, top_db=20)
232
+
233
+
234
  def apply_fade_out(audio: np.ndarray, sr: int, fade_duration: float = 0.01) -> np.ndarray:
235
  """
236
  Apply smooth fade-out to prevent click artifacts at the end of audio.
 
575
  top_p: float,
576
  repetition_penalty: float,
577
  ) -> np.ndarray:
578
+ # Normalize and ensure text ends with punctuation (crucial for T3)
579
+ text = punc_norm(text)
580
+
581
  # Tokenize text with language prefix
582
  text_tokens = self.tokenizer.text_to_tokens(text, language_id=language).to(self.device)
583
 
 
594
  use_autocast = self.device in ['cuda', 'mps']
595
  device_type = 'cuda' if self.device == 'cuda' else 'mps'
596
 
597
+ with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=torch.float16, enabled=(self.device==use_autocast)):
598
  # Generate speech tokens with T3
599
  speech_tokens = self.t3.inference(
600
  t3_cond=self.conds.t3,
601
  text_tokens=text_tokens,
602
+ max_new_tokens=1000,
603
  temperature=temperature,
604
  cfg_weight=cfg_weight,
605
  repetition_penalty=repetition_penalty,
 
609
  # Extract only the conditional batch and filter invalid tokens
610
  speech_tokens = speech_tokens[0]
611
  speech_tokens = drop_invalid_tokens(speech_tokens)
612
+
613
+ # FIX (Root Cause): Remove the last token which often contains noise/transients
614
+ # causing click artifacts in S3 generation.
615
+ if len(speech_tokens) > 1:
616
+ speech_tokens = speech_tokens[:-1]
617
+
618
  speech_tokens = speech_tokens.to(self.device)
619
 
620
+ # Generate waveform with S3Gen
621
+ wav, _ = self.s3gen.inference(
622
+ speech_tokens=speech_tokens,
623
+ ref_dict=self.conds.s3,
624
+ )
625
+
626
  return wav[0].cpu().numpy()
627
 
628
  def generate(
 
633
  exaggeration: float = 0.5,
634
  cfg_weight: float = 0.5,
635
  temperature: float = 0.8,
636
+
637
+ top_p: float = 1.0,
638
+ repetition_penalty: float = 2.0,
639
  split_sentences: bool = True,
640
  crossfade_ms: int = 50,
641
  sentence_pause_ms: int = 500,
 
697
  repetition_penalty=repetition_penalty,
698
  )
699
 
700
+ # Trim silence using VAD (more precise endpointing)
701
+ # Keep margin reasonable (50ms) as we prevent clicks at generation level now
702
+ audio_np = vad_trim(audio_np, self.sr, margin_s=0.05)
703
 
704
  # Apply fade-out to prevent click at end of each segment
705
  audio_np = apply_fade_out(audio_np, self.sr, fade_duration=0.01) # 10ms fade-out