mehdi999 commited on
Commit
3d734f0
·
1 Parent(s): 7208504

added watch

Browse files
Files changed (1) hide show
  1. app.py +137 -70
app.py CHANGED
@@ -8,6 +8,14 @@ import spaces
8
  # FLA: forcer les convolutions en backend PyTorch (pas de Triton)
9
  os.environ.setdefault("FLA_CONV_BACKEND", "torch")
10
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
 
 
 
 
 
 
 
 
11
  from huggingface_hub import login
12
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
13
 
@@ -24,17 +32,20 @@ if HF_TOKEN:
24
  _pardi = None
25
  _sampling_rate = 24000
26
 
 
27
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
28
  s = (s or "").strip().lower()
29
  try:
30
  import re
31
  from num2words import num2words
 
32
  def repl(m): return num2words(int(m.group()), lang=lang_hint)
33
  s = re.sub(r"\d+", repl, s)
34
  except Exception:
35
  pass
36
  return s
37
 
 
38
  def _load_model(device: str = "cuda"):
39
  global _pardi, _sampling_rate
40
  if _pardi is None:
@@ -43,15 +54,18 @@ def _load_model(device: str = "cuda"):
43
  print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
44
  return _pardi
45
 
 
46
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
47
  arr = arr.astype(np.float32)
48
  if arr.ndim == 2:
49
  arr = arr.mean(axis=1)
50
  return arr
51
 
52
- @spaces.GPU(duration=200)
 
53
  def synthesize(
54
  text: str,
 
55
  ref_audio,
56
  ref_text: str,
57
  steps: int,
@@ -60,83 +74,137 @@ def synthesize(
60
  temperature: float,
61
  max_seq_len: int,
62
  seed: int,
63
- lang_hint: str
 
64
  ):
65
- device = "cuda" if torch.cuda.is_available() else "cpu"
66
- torch.manual_seed(int(seed))
67
-
68
- pardi = _load_model(device)
69
- txt = _normalize_text(text, lang_hint=lang_hint)
70
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
- # --- IMPORTANT : signature de VelocityHeadSamplingParams ---
74
- # Dans ton notebook d’inférence, la classe attend (cfg_ref, cfg, num_steps) SANS 'temperature'.
75
- # On essaie d’abord sans temperature, puis fallback si la classe en accepte une.
76
- try:
77
- vel_params = VelocityHeadSamplingParams(
78
- cfg_ref=float(cfg_ref),
79
- cfg=float(cfg),
80
- num_steps=int(steps)
81
- )
82
- except TypeError:
83
- vel_params = VelocityHeadSamplingParams(
84
- cfg_ref=float(cfg_ref),
85
- cfg=float(cfg),
86
- num_steps=int(steps),
87
- temperature=float(temperature)
88
- )
89
-
90
- # Prefix optionnel
91
- prefix = None
92
- if ref_audio is not None:
93
- if isinstance(ref_audio, str):
94
- wav, sr = sf.read(ref_audio)
95
- else:
96
- sr, wav = ref_audio
97
- wav = _to_mono_float32(np.array(wav))
98
- wav_t = torch.from_numpy(wav).to(device)
99
- import torchaudio
100
- if sr != pardi.sampling_rate:
101
- wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
102
- wav_t = wav_t.unsqueeze(0)
103
- with torch.inference_mode():
104
- prefix_tokens = pardi.patchvae.encode(wav_t)
105
- prefix = (ref_text or "", prefix_tokens[0])
106
-
107
- print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
108
-
109
- try:
110
- with torch.inference_mode():
111
- wavs, _ = pardi.text_to_speech(
112
- [txt],
113
- prefix,
114
- max_seq_len=int(max_seq_len),
115
- velocity_head_sampling_params=vel_params,
116
-
117
- )
118
- except Exception as e:
119
- import traceback, sys
120
- print("❌ text_to_speech failed:", e, file=sys.stderr)
121
- traceback.print_exc()
122
- raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}")
123
-
124
- wav = wavs[0].detach().cpu().numpy()
125
- return (_sampling_rate, wav)
126
-
127
  def build_demo():
128
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
129
  gr.Markdown(
130
- "## Lina-speech (pardi-speech) – Démo TTS\n"
131
- "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
132
  "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
133
  )
134
 
135
  with gr.Row():
136
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
 
 
137
  with gr.Accordion("Prefix (optionnel)", open=False):
138
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
139
- ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
 
140
  with gr.Accordion("Options avancées", open=False):
141
  with gr.Row():
142
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
@@ -150,19 +218,18 @@ def build_demo():
150
 
151
  btn = gr.Button("Synthétiser")
152
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
 
153
 
154
  demo.queue(default_concurrency_limit=1, max_size=32)
155
 
156
  btn.click(
157
  fn=synthesize,
158
- inputs=[text, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
159
- outputs=[out_audio]
160
  )
161
  return demo
162
 
 
163
  if __name__ == "__main__":
164
  demo = build_demo()
165
  demo.launch()
166
- # retrigger 2025-10-29T16:27:55+01:00
167
- # retrigger 2025-10-29T17:44:57+01:00
168
- # retrigger 2025-10-29T18:59:12+01:00
 
8
  # FLA: forcer les convolutions en backend PyTorch (pas de Triton)
9
  os.environ.setdefault("FLA_CONV_BACKEND", "torch")
10
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
11
+
12
+ # Meilleure perf FP32 sur GPU compatibles
13
+ torch.backends.cuda.matmul.allow_tf32 = True
14
+ try:
15
+ torch.set_float32_matmul_precision("high")
16
+ except Exception:
17
+ pass
18
+
19
  from huggingface_hub import login
20
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
21
 
 
32
  _pardi = None
33
  _sampling_rate = 24000
34
 
35
+
36
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
37
  s = (s or "").strip().lower()
38
  try:
39
  import re
40
  from num2words import num2words
41
+
42
  def repl(m): return num2words(int(m.group()), lang=lang_hint)
43
  s = re.sub(r"\d+", repl, s)
44
  except Exception:
45
  pass
46
  return s
47
 
48
+
49
  def _load_model(device: str = "cuda"):
50
  global _pardi, _sampling_rate
51
  if _pardi is None:
 
54
  print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
55
  return _pardi
56
 
57
+
58
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
59
  arr = arr.astype(np.float32)
60
  if arr.ndim == 2:
61
  arr = arr.mean(axis=1)
62
  return arr
63
 
64
+
65
+ @spaces.GPU(duration=200) # 200s pour les autres users (peut être augmenté si besoin)
66
  def synthesize(
67
  text: str,
68
+ debug: bool,
69
  ref_audio,
70
  ref_text: str,
71
  steps: int,
 
74
  temperature: float,
75
  max_seq_len: int,
76
  seed: int,
77
+ lang_hint: str,
78
+ progress=gr.Progress(track_tqdm=True),
79
  ):
80
+ import io
81
+ import time
82
+ import traceback
83
+ from contextlib import redirect_stdout, redirect_stderr
84
+
85
+ # --- capture logs UI ---
86
+ logbuf = io.StringIO()
87
+ t0 = time.perf_counter()
88
+
89
+ # Watchdog: lève une erreur lisible avant un éventuel kill ZeroGPU
90
+ MAX_WALLTIME_S = 110
91
+
92
+ def maybe_timeout_checkpoint(stage: str):
93
+ dur = time.perf_counter() - t0
94
+ print(f"[debug] stage={stage} t={dur:.2f}s")
95
+ if dur > MAX_WALLTIME_S:
96
+ raise TimeoutError(f"Watchdog: dépassement {dur:.1f}s avant kill ZeroGPU (étape: {stage})")
97
+
98
+ with redirect_stdout(logbuf), redirect_stderr(logbuf):
99
+ try:
100
+ progress(0.02, desc="Init")
101
+ device = "cuda" if torch.cuda.is_available() else "cpu"
102
+ torch.manual_seed(int(seed))
103
+
104
+ # Pour des traces CUDA synchrones (erreurs au bon endroit)
105
+ os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
106
+
107
+ maybe_timeout_checkpoint("load_model")
108
+ progress(0.08, desc="Chargement du modèle")
109
+ pardi = _load_model(device)
110
+ if device == "cuda":
111
+ torch.cuda.synchronize()
112
+
113
+ maybe_timeout_checkpoint("normalize")
114
+ progress(0.12, desc="Préparation du texte")
115
+ txt = _normalize_text(text, lang_hint=lang_hint)
116
+
117
+ # Clamp pour limiter la durée
118
+ steps = int(min(max(1, steps), 16))
119
+ max_seq_len = int(min(max(50, max_seq_len), 600))
120
+
121
+ progress(0.16, desc="Paramètres sampling")
122
+ # IMPORTANT : signature de VelocityHeadSamplingParams
123
+ try:
124
+ vel_params = VelocityHeadSamplingParams(
125
+ cfg_ref=float(cfg_ref),
126
+ cfg=float(cfg),
127
+ num_steps=int(steps)
128
+ )
129
+ except TypeError:
130
+ vel_params = VelocityHeadSamplingParams(
131
+ cfg_ref=float(cfg_ref),
132
+ cfg=float(cfg),
133
+ num_steps=int(steps),
134
+ temperature=float(temperature)
135
+ )
136
+
137
+ # Prefix optionnel
138
+ maybe_timeout_checkpoint("prefix")
139
+ progress(0.22, desc="Prefix (optionnel)")
140
+ prefix = None
141
+ if ref_audio is not None:
142
+ if isinstance(ref_audio, str):
143
+ wav, sr = sf.read(ref_audio)
144
+ else:
145
+ sr, wav = ref_audio
146
+ wav = _to_mono_float32(np.array(wav))
147
+ wav_t = torch.from_numpy(wav).to(device)
148
+ import torchaudio
149
+ if sr != pardi.sampling_rate:
150
+ wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
151
+ wav_t = wav_t.unsqueeze(0)
152
+ with torch.inference_mode():
153
+ prefix_tokens = pardi.patchvae.encode(wav_t)
154
+ prefix = (ref_text or "", prefix_tokens[0])
155
+
156
+ print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
157
+ maybe_timeout_checkpoint("tts_start")
158
+ progress(0.28, desc="Synthèse…")
159
+
160
+ if device == "cuda":
161
+ torch.cuda.synchronize()
162
+
163
+ with torch.inference_mode():
164
+ # Pas de cache envoyé (GLA “safe” côté modèle)
165
+ wavs, _ = pardi.text_to_speech(
166
+ [txt],
167
+ prefix,
168
+ max_seq_len=int(max_seq_len),
169
+ velocity_head_sampling_params=vel_params,
170
+ )
171
+
172
+ if device == "cuda":
173
+ torch.cuda.synchronize()
174
+ progress(0.96, desc="Finalisation")
175
+
176
+ wav = wavs[0].detach().cpu().numpy()
177
+ logs = logbuf.getvalue() if debug else ""
178
+ print(f"[debug] synthesize walltime = {time.perf_counter()-t0:.2f}s")
179
+ return (_sampling_rate, wav), logs
180
+
181
+ except Exception as e:
182
+ import traceback as _tb
183
+ dur = time.perf_counter() - t0
184
+ msg = f"{type(e).__name__}: {e}\\n\\n[walltime={dur:.1f}s]\\n"
185
+ logs = msg + logbuf.getvalue() + "\\n" + _tb.format_exc()
186
+ if debug:
187
+ # On retourne la trace dans l’UI (textbox), sans lever d’exception Gradio
188
+ return None, logs
189
+ raise gr.Error(msg)
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  def build_demo():
193
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
194
  gr.Markdown(
195
+ "## Lina-speech (pardi-speech) – Démo TTS\\n"
196
+ "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\\n"
197
  "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
198
  )
199
 
200
  with gr.Row():
201
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
202
+ debug = gr.Checkbox(value=False, label="Mode debug (afficher la stacktrace)")
203
+
204
  with gr.Accordion("Prefix (optionnel)", open=False):
205
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
206
+ ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
207
+
208
  with gr.Accordion("Options avancées", open=False):
209
  with gr.Row():
210
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
 
218
 
219
  btn = gr.Button("Synthétiser")
220
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
221
+ logs_box = gr.Textbox(label="Logs (debug)", lines=10)
222
 
223
  demo.queue(default_concurrency_limit=1, max_size=32)
224
 
225
  btn.click(
226
  fn=synthesize,
227
+ inputs=[text, debug, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
228
+ outputs=[out_audio, logs_box],
229
  )
230
  return demo
231
 
232
+
233
  if __name__ == "__main__":
234
  demo = build_demo()
235
  demo.launch()