"""Gradio demo Space for Lina-speech / pardi-speech TTS.

Weights are prefetched and loaded on CPU in a background thread at startup;
synthesis itself runs in a GPU worker via the `synthesize` callback below.
"""
import os
import re
import json
import sys
import time
import threading
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import spaces
from huggingface_hub import login, snapshot_download

# --------- Environment / stability ----------
os.environ.setdefault("FLA_CONV_BACKEND", "torch")   # avoid Triton kernels
os.environ.setdefault("FLA_USE_FAST_OPS", "0")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # shipped with this repo

MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
HF_TOKEN = os.environ.get("HF_TOKEN")

# --------- Global cache (model preloaded at startup) ----------
# pardi: loaded model | sr: output sample rate | err: preload exception | logs: preload log lines | thread: preload thread
_MODEL = {"pardi": None, "sr": 24000, "err": None, "logs": [], "thread": None}

def _log(msg: str):
    _MODEL["logs"].append(str(msg))
    # cap the log buffer size
    if len(_MODEL["logs"]) > 2000:
        _MODEL["logs"] = _MODEL["logs"][-2000:]

def _env_diag() -> str:
    parts = []
    try:
        parts.append(f"torch={torch.__version__}")
        try:
            import triton  # type: ignore
            parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
        except Exception:
            parts.append("triton=not_importable")
        parts.append(f"cuda.is_available={torch.cuda.is_available()}")
        if torch.cuda.is_available():
            parts.append(f"cuda.version={torch.version.cuda}")
            try:
                free, total = torch.cuda.mem_get_info()
                parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
            except Exception:
                pass
    except Exception as e:
        parts.append(f"env_diag_error={e}")
    return " | ".join(parts)

def _normalize_text(s: str, lang_hint: str = "fr") -> str:
    """Spell out digit runs with num2words when available; otherwise return the text as-is."""
    s = (s or "").strip()
    try:
        from num2words import num2words
        def repl(m):
            try:
                return num2words(int(m.group()), lang=lang_hint)
            except Exception:
                return m.group()
        s = re.sub(r"\d+", repl, s)
    except Exception:
        pass
    return s
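# e.g. _normalize_text("J'ai 3 chats", "fr") -> "J'ai trois chats" (when num2words is installed)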

def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
    arr = np.asarray(arr)
    # Gradio's numpy audio arrives as int16 PCM; rescale integer data to [-1, 1].
    if np.issubdtype(arr.dtype, np.integer):
        arr = arr.astype(np.float32) / float(np.iinfo(arr.dtype).max)
    if arr.ndim == 2:
        arr = arr.mean(axis=1)  # downmix stereo to mono
    return arr.astype(np.float32)

def _extract_repo_ids_from_config(config_path: str):
    """Scan config.json for strings that look like nested 'org/name' repo ids to prefetch."""
    repo_ids = set()
    preview = None
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$")  # org/name
        def rec(obj):
            if isinstance(obj, dict):
                for v in obj.values():
                    rec(v)
            elif isinstance(obj, list):
                for v in obj:
                    rec(v)
            elif isinstance(obj, str) and pattern.match(obj):
                repo_ids.add(obj)
        rec(cfg)
        try:
            subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
            preview = json.dumps({k: cfg[k] for k in subset_keys}, ensure_ascii=False)[:600]
        except Exception:
            pass
    except Exception:
        pass
    return sorted(repo_ids), preview
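# e.g. a config value like "some-org/some-codec" (hypothetical) would be returned for prefetching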

def _prefetch_and_load_cpu():
    """Exécuté dans un thread au démarrage du Space (hors worker GPU)."""
    try:
        _log("[prefetch] snapshot_download (main)...")
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=None,
            local_files_only=False,
        )
        _log(f"[prefetch] main done -> {local_dir}")

        cfg_path = os.path.join(local_dir, "config.json")
        nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
        if cfg_preview:
            _log(f"[config] preview: {cfg_preview}")
        for rid in nested:
            if rid == MODEL_REPO_ID:
                continue
            _log(f"[prefetch] nested repo: {rid} ...")
            snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
            _log(f"[prefetch] nested repo: {rid} done")

        # Force offline mode during the actual load so nothing re-hits the Hub
        old_off = os.environ.get("HF_HUB_OFFLINE")
        os.environ["HF_HUB_OFFLINE"] = "1"
        try:
            _log("[load] from_pretrained(map_location='cpu')...")
            m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
            m.eval()
            _MODEL["pardi"] = m
            _MODEL["sr"] = getattr(m, "sampling_rate", 24000)
            _log(f"[load] cpu OK (sr={_MODEL['sr']})")
        finally:
            if old_off is None:
                os.environ.pop("HF_HUB_OFFLINE", None)
            else:
                os.environ["HF_HUB_OFFLINE"] = old_off

    except BaseException as e:  # BaseException so the daemon thread records even SystemExit / KeyboardInterrupt
        _MODEL["err"] = e
        _log(f"[EXC@preload] {type(e).__name__}: {e}")
        _log(traceback.format_exc())

# Kick off the CPU-only preload as soon as the module is imported,
# so the Gradio UI comes up while weights download in the background.
if _MODEL["thread"] is None:
    _MODEL["thread"] = threading.Thread(target=_prefetch_and_load_cpu, daemon=True)
    _MODEL["thread"].start()

def _move_to_cuda_if_available(m, logs_acc):
    def L(msg): logs_acc.append(str(msg))
    if torch.cuda.is_available():
        L("[move] moving model to cuda...")
        try:
            m = m.to("cuda")  # type: ignore[attr-defined]
            L("[move] cuda OK")
        except Exception as e:
            L(f"[move] .to('cuda') failed: {e}. Keeping on CPU.")
    else:
        L("[move] cuda not available, keep CPU")
    return m

# --------- UI callback (GPU) ----------
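# spaces.GPU runs the decorated callback in a GPU-attached worker; `duration` caps each call at 200 s.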
@spaces.GPU(duration=200)
def synthesize(
    text: str,
    debug: bool,          # accepted from the UI but currently unused
    adv_sampling: bool,   # Velocity Head sampling
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str,
):
    logs = []
    def LOG(msg: str):
        logs.append(str(msg))
        joined = "\n".join(logs + _MODEL["logs"][-50:])  # interleave the latest preload logs
        if len(joined) > 12000:
            joined = joined[-12000:]
        return joined

    try:
        if HF_TOKEN:
            try:
                login(token=HF_TOKEN)
                yield None, LOG("✅ HF login ok")
            except Exception as e:
                yield None, LOG(f"⚠️ HF login failed: {e}")

        yield None, LOG("[env] " + _env_diag())
        torch.manual_seed(int(seed))
        os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

        # If the model is not ready yet, wait here for up to 180 s
        t0 = time.perf_counter()
        while _MODEL["pardi"] is None and _MODEL["err"] is None:
            elapsed = time.perf_counter() - t0
            yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
            if elapsed > 180:
                # dump the preload thread's stack to help debug the hang
                tid = _MODEL["thread"].ident if _MODEL["thread"] else None
                if tid is not None:
                    frame = sys._current_frames().get(tid)
                    if frame is not None:
                        stack_txt = "".join(traceback.format_stack(frame))
                        yield None, LOG("[stack-final]\n" + stack_txt)
                raise TimeoutError("Preload timeout (>180s)")
            time.sleep(2.0)

        if _MODEL["err"]:
            raise _MODEL["err"]

        pardi = _MODEL["pardi"]
        sr_out = _MODEL["sr"]

        # Move the model to CUDA when available
        pardi = _move_to_cuda_if_available(pardi, logs)
        yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={sr_out}")

        # ---- Text + optional prefix ----
        txt = _normalize_text(text or "", lang_hint=lang_hint)
        yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")

        steps = int(min(max(1, int(steps)), 16))                 # server-side cap, tighter than the UI slider (1-50)
        max_seq_len = int(min(max(50, int(max_seq_len)), 600))   # server-side cap, tighter than the UI slider (50-1200)

        prefix = None
        if ref_audio is not None:
            yield None, LOG("[prefix] encoding reference audio...")
            if isinstance(ref_audio, str):
                wav, sr = sf.read(ref_audio)
            else:
                sr, wav = ref_audio  # Gradio numpy audio: (sample_rate, data)
            wav = _to_mono_float32(wav)
            if sr != sr_out:
                yield None, LOG(f"[prefix] warning: reference sr={sr} != model sr={sr_out}; no resampling is applied")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            wav_t = torch.from_numpy(wav).to(device).unsqueeze(0)  # shape (1, num_samples)
            with torch.inference_mode():
                prefix_tokens = pardi.patchvae.encode(wav_t)  # type: ignore[attr-defined]
            prefix = (ref_text or "", prefix_tokens[0])  # (transcript, audio-token sequence) pair
            yield None, LOG("[prefix] done.")

        yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
                        f"T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")

        # ---- Fast path (as in the notebook) ----
        with torch.inference_mode():
            if adv_sampling:
                # Pass temperature when the signature accepts it, so the UI slider takes effect.
                try:
                    vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg),
                                                         num_steps=int(steps), temperature=float(temperature))
                except TypeError:
                    vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
                wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len),
                                               velocity_head_sampling_params=vparams)
            else:
                wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))

        wav = wavs[0].detach().cpu().numpy().astype(np.float32)  # single-item batch -> first waveform
        yield (sr_out, wav), LOG("[ok] done.")

    except Exception as e:
        tb = traceback.format_exc()
        yield None, LOG(f"[EXC] {type(e).__name__}: {e}\n{tb}")

# --------- UI ----------
def build_demo():
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(
            "### Lina-speech (pardi-speech) – TTS demo\n"
            "Generates audio from text, with or without a prefix (reference audio).\n"
            "Uses the fast path by default (same as the notebook)."
        )
        with gr.Row():
            text = gr.Textbox(label="Text to synthesize", lines=4, placeholder="Type your text here…")
        with gr.Accordion("Prefix (optionnel)", open=False):
            ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
            ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
        with gr.Accordion("Options avancées", open=False):
            with gr.Row():
                steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (réf.)")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
                max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
                seed = gr.Number(value=0, precision=0, label="Seed")
                lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
        with gr.Row():
            debug = gr.Checkbox(value=False, label="Debug mode")
            adv_sampling = gr.Checkbox(value=False, label="Advanced sampling (Velocity Head)")

        btn = gr.Button("Synthesize")
        out_audio = gr.Audio(label="Output audio", type="numpy")
        logs_box = gr.Textbox(label="Logs (live)", lines=28)

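        # Serialize GPU work: one synthesis at a time, with up to 32 queued requests.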
        demo.queue(default_concurrency_limit=1, max_size=32)
        btn.click(
            fn=synthesize,
            inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
            outputs=[out_audio, logs_box],
            api_name="synthesize",
        )
    return demo

if __name__ == "__main__":
    build_demo().launch(ssr_mode=False)
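# A minimal client-side sketch (hypothetical URL; assumes the Space is up and gradio_client is installed).
# client.predict returns the final yield of the streaming endpoint; inputs follow the order wired in btn.click:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   audio, logs = client.predict(
#       "Bonjour tout le monde", False, False, None, "", 10, 1.4, 1.0, 1.0, 300, 0, "fr",
#       api_name="/synthesize",
#   )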