mehdi999 commited on
Commit
6b8706f
·
1 Parent(s): 3d734f0

added watch

Browse files
Files changed (1) hide show
  1. app.py +89 -52
app.py CHANGED
@@ -34,14 +34,18 @@ _sampling_rate = 24000
34
 
35
 
36
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
37
- s = (s or "").strip().lower()
38
  try:
39
  import re
40
  from num2words import num2words
41
-
42
- def repl(m): return num2words(int(m.group()), lang=lang_hint)
 
 
 
43
  s = re.sub(r"\d+", repl, s)
44
  except Exception:
 
45
  pass
46
  return s
47
 
@@ -50,22 +54,49 @@ def _load_model(device: str = "cuda"):
50
  global _pardi, _sampling_rate
51
  if _pardi is None:
52
  _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
 
53
  _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
54
- print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
55
  return _pardi
56
 
57
 
58
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
59
- arr = arr.astype(np.float32)
60
  if arr.ndim == 2:
61
  arr = arr.mean(axis=1)
62
- return arr
63
 
64
 
65
- @spaces.GPU(duration=200) # 200s pour les autres users (peut être augmenté si besoin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def synthesize(
67
  text: str,
68
  debug: bool,
 
69
  ref_audio,
70
  ref_text: str,
71
  steps: int,
@@ -91,17 +122,18 @@ def synthesize(
91
 
92
  def maybe_timeout_checkpoint(stage: str):
93
  dur = time.perf_counter() - t0
94
- print(f"[debug] stage={stage} t={dur:.2f}s")
95
  if dur > MAX_WALLTIME_S:
96
  raise TimeoutError(f"Watchdog: dépassement {dur:.1f}s avant kill ZeroGPU (étape: {stage})")
97
 
98
  with redirect_stdout(logbuf), redirect_stderr(logbuf):
99
  try:
 
100
  progress(0.02, desc="Init")
101
  device = "cuda" if torch.cuda.is_available() else "cpu"
102
  torch.manual_seed(int(seed))
103
 
104
- # Pour des traces CUDA synchrones (erreurs au bon endroit)
105
  os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
106
 
107
  maybe_timeout_checkpoint("load_model")
@@ -114,26 +146,10 @@ def synthesize(
114
  progress(0.12, desc="Préparation du texte")
115
  txt = _normalize_text(text, lang_hint=lang_hint)
116
 
117
- # Clamp pour limiter la durée
118
  steps = int(min(max(1, steps), 16))
119
  max_seq_len = int(min(max(50, max_seq_len), 600))
120
 
121
- progress(0.16, desc="Paramètres sampling")
122
- # IMPORTANT : signature de VelocityHeadSamplingParams
123
- try:
124
- vel_params = VelocityHeadSamplingParams(
125
- cfg_ref=float(cfg_ref),
126
- cfg=float(cfg),
127
- num_steps=int(steps)
128
- )
129
- except TypeError:
130
- vel_params = VelocityHeadSamplingParams(
131
- cfg_ref=float(cfg_ref),
132
- cfg=float(cfg),
133
- num_steps=int(steps),
134
- temperature=float(temperature)
135
- )
136
-
137
  # Prefix optionnel
138
  maybe_timeout_checkpoint("prefix")
139
  progress(0.22, desc="Prefix (optionnel)")
@@ -143,63 +159,84 @@ def synthesize(
143
  wav, sr = sf.read(ref_audio)
144
  else:
145
  sr, wav = ref_audio
146
- wav = _to_mono_float32(np.array(wav))
147
  wav_t = torch.from_numpy(wav).to(device)
148
- import torchaudio
149
- if sr != pardi.sampling_rate:
150
- wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
 
 
 
151
  wav_t = wav_t.unsqueeze(0)
152
  with torch.inference_mode():
153
  prefix_tokens = pardi.patchvae.encode(wav_t)
154
  prefix = (ref_text or "", prefix_tokens[0])
155
 
156
- print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
157
  maybe_timeout_checkpoint("tts_start")
158
  progress(0.28, desc="Synthèse…")
159
 
160
  if device == "cuda":
161
  torch.cuda.synchronize()
162
 
 
163
  with torch.inference_mode():
164
- # Pas de cache envoyé (GLA “safe” côté modèle)
165
- wavs, _ = pardi.text_to_speech(
166
- [txt],
167
- prefix,
168
- max_seq_len=int(max_seq_len),
169
- velocity_head_sampling_params=vel_params,
170
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  if device == "cuda":
173
  torch.cuda.synchronize()
174
  progress(0.96, desc="Finalisation")
175
 
176
- wav = wavs[0].detach().cpu().numpy()
177
  logs = logbuf.getvalue() if debug else ""
178
- print(f"[debug] synthesize walltime = {time.perf_counter()-t0:.2f}s")
179
  return (_sampling_rate, wav), logs
180
 
181
  except Exception as e:
182
- import traceback as _tb
183
  dur = time.perf_counter() - t0
184
- msg = f"{type(e).__name__}: {e}\\n\\n[walltime={dur:.1f}s]\\n"
185
- logs = msg + logbuf.getvalue() + "\\n" + _tb.format_exc()
186
- if debug:
187
- # On retourne la trace dans l’UI (textbox), sans lever d’exception Gradio
188
- return None, logs
189
- raise gr.Error(msg)
190
 
191
 
192
  def build_demo():
193
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
194
  gr.Markdown(
195
- "## Lina-speech (pardi-speech) – Démo TTS\\n"
196
- "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\\n"
197
- "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
 
198
  )
199
 
200
  with gr.Row():
201
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
202
  debug = gr.Checkbox(value=False, label="Mode debug (afficher la stacktrace)")
 
203
 
204
  with gr.Accordion("Prefix (optionnel)", open=False):
205
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
@@ -218,13 +255,13 @@ def build_demo():
218
 
219
  btn = gr.Button("Synthétiser")
220
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
221
- logs_box = gr.Textbox(label="Logs (debug)", lines=10)
222
 
223
  demo.queue(default_concurrency_limit=1, max_size=32)
224
 
225
  btn.click(
226
  fn=synthesize,
227
- inputs=[text, debug, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
228
  outputs=[out_audio, logs_box],
229
  )
230
  return demo
 
34
 
35
 
36
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
37
+ s = (s or "").strip()
38
  try:
39
  import re
40
  from num2words import num2words
41
+ def repl(m):
42
+ try:
43
+ return num2words(int(m.group()), lang=lang_hint)
44
+ except Exception:
45
+ return m.group()
46
  s = re.sub(r"\d+", repl, s)
47
  except Exception:
48
+ # pas de dépendance dure
49
  pass
50
  return s
51
 
 
54
  global _pardi, _sampling_rate
55
  if _pardi is None:
56
  _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
57
+ _pardi.eval()
58
  _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
59
+ print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).", flush=True)
60
  return _pardi
61
 
62
 
63
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
64
+ arr = np.asarray(arr)
65
  if arr.ndim == 2:
66
  arr = arr.mean(axis=1)
67
+ return arr.astype(np.float32)
68
 
69
 
70
+ def _env_diag() -> str:
71
+ parts = []
72
+ try:
73
+ parts.append(f"torch: {torch.__version__}")
74
+ try:
75
+ import triton # type: ignore
76
+ parts.append(f"triton: {getattr(triton, '__version__', 'unknown')}")
77
+ except Exception as _e:
78
+ parts.append("triton: not importable")
79
+ parts.append(f"cuda.is_available: {torch.cuda.is_available()}")
80
+ if torch.cuda.is_available():
81
+ parts.append(f"cuda.device_count: {torch.cuda.device_count()}")
82
+ parts.append(f"cuda.current_device: {torch.cuda.current_device()}")
83
+ parts.append(f"cuda.get_device_name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
84
+ parts.append(f"cuda.version: {torch.version.cuda}")
85
+ try:
86
+ free, total = torch.cuda.mem_get_info()
87
+ parts.append(f"cuda.mem_free: {free/1e9:.2f} GB / total: {total/1e9:.2f} GB")
88
+ except Exception:
89
+ pass
90
+ except Exception as e:
91
+ parts.append(f"env_diag error: {e}")
92
+ return " | ".join(parts)
93
+
94
+
95
+ @spaces.GPU(duration=200) # 200s pour les autres users
96
  def synthesize(
97
  text: str,
98
  debug: bool,
99
+ adv_sampling: bool, # toggle "Sampling avancé (Velocity Head)"
100
  ref_audio,
101
  ref_text: str,
102
  steps: int,
 
122
 
123
  def maybe_timeout_checkpoint(stage: str):
124
  dur = time.perf_counter() - t0
125
+ print(f"[debug] stage={stage} t={dur:.2f}s", flush=True)
126
  if dur > MAX_WALLTIME_S:
127
  raise TimeoutError(f"Watchdog: dépassement {dur:.1f}s avant kill ZeroGPU (étape: {stage})")
128
 
129
  with redirect_stdout(logbuf), redirect_stderr(logbuf):
130
  try:
131
+ print("[env]", _env_diag(), flush=True)
132
  progress(0.02, desc="Init")
133
  device = "cuda" if torch.cuda.is_available() else "cpu"
134
  torch.manual_seed(int(seed))
135
 
136
+ # Traces CUDA synchrones (erreurs au bon endroit)
137
  os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
138
 
139
  maybe_timeout_checkpoint("load_model")
 
146
  progress(0.12, desc="Préparation du texte")
147
  txt = _normalize_text(text, lang_hint=lang_hint)
148
 
149
+ # Clamp pour limiter la durée (démo)
150
  steps = int(min(max(1, steps), 16))
151
  max_seq_len = int(min(max(50, max_seq_len), 600))
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # Prefix optionnel
154
  maybe_timeout_checkpoint("prefix")
155
  progress(0.22, desc="Prefix (optionnel)")
 
159
  wav, sr = sf.read(ref_audio)
160
  else:
161
  sr, wav = ref_audio
162
+ wav = _to_mono_float32(wav)
163
  wav_t = torch.from_numpy(wav).to(device)
164
+ try:
165
+ import torchaudio
166
+ if sr != pardi.sampling_rate:
167
+ wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
168
+ except Exception as _e:
169
+ print("⚠️ torchaudio not available for resample; using original SR:", sr, flush=True)
170
  wav_t = wav_t.unsqueeze(0)
171
  with torch.inference_mode():
172
  prefix_tokens = pardi.patchvae.encode(wav_t)
173
  prefix = (ref_text or "", prefix_tokens[0])
174
 
175
+ print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}", flush=True)
176
  maybe_timeout_checkpoint("tts_start")
177
  progress(0.28, desc="Synthèse…")
178
 
179
  if device == "cuda":
180
  torch.cuda.synchronize()
181
 
182
+ # ---- FAST PATH (comme le notebook): sans VelocityHead par défaut ----
183
  with torch.inference_mode():
184
+ if adv_sampling:
185
+ # Mode avancé: on passe VelocityHeadSamplingParams
186
+ try:
187
+ vel_params = VelocityHeadSamplingParams(
188
+ cfg_ref=float(cfg_ref),
189
+ cfg=float(cfg),
190
+ num_steps=int(steps)
191
+ )
192
+ except TypeError:
193
+ vel_params = VelocityHeadSamplingParams(
194
+ cfg_ref=float(cfg_ref),
195
+ cfg=float(cfg),
196
+ num_steps=int(steps),
197
+ temperature=float(temperature)
198
+ )
199
+ wavs, _ = pardi.text_to_speech(
200
+ [txt], prefix, max_seq_len=int(max_seq_len),
201
+ velocity_head_sampling_params=vel_params
202
+ )
203
+ else:
204
+ # Fast path (notebook)
205
+ wavs, _ = pardi.text_to_speech(
206
+ [txt], prefix, max_seq_len=int(max_seq_len)
207
+ )
208
+ # --------------------------------------------------------------------
209
 
210
  if device == "cuda":
211
  torch.cuda.synchronize()
212
  progress(0.96, desc="Finalisation")
213
 
214
+ wav = wavs[0].detach().cpu().numpy().astype(np.float32)
215
  logs = logbuf.getvalue() if debug else ""
216
+ print(f"[debug] synthesize walltime = {time.perf_counter()-t0:.2f}s", flush=True)
217
  return (_sampling_rate, wav), logs
218
 
219
  except Exception as e:
 
220
  dur = time.perf_counter() - t0
221
+ msg = f"{type(e).__name__}: {e}\n\n[walltime={dur:.1f}s]\n"
222
+ logs = msg + logbuf.getvalue() + "\n" + traceback.format_exc()
223
+ # 👉 Toujours renvoyer les logs dans l'UI, même si debug = False
224
+ return None, logs
 
 
225
 
226
 
227
  def build_demo():
228
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
229
  gr.Markdown(
230
+ "## Lina-speech (pardi-speech) – Démo TTS\n"
231
+ "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
232
+ "Par défaut, le chemin **rapide** (comme dans le notebook) est utilisé. "
233
+ "Active **Sampling avancé** pour passer par Velocity Head."
234
  )
235
 
236
  with gr.Row():
237
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
238
  debug = gr.Checkbox(value=False, label="Mode debug (afficher la stacktrace)")
239
+ adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")
240
 
241
  with gr.Accordion("Prefix (optionnel)", open=False):
242
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
 
255
 
256
  btn = gr.Button("Synthétiser")
257
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
258
+ logs_box = gr.Textbox(label="Logs (debug)", lines=14)
259
 
260
  demo.queue(default_concurrency_limit=1, max_size=32)
261
 
262
  btn.click(
263
  fn=synthesize,
264
+ inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
265
  outputs=[out_audio, logs_box],
266
  )
267
  return demo