mehdi999 committed on
Commit
a175cfa
·
1 Parent(s): bb32e4f

Force FLA mode=chunk to avoid Triton fused kernels on ZeroGPU

Browse files
Files changed (3) hide show
  1. app.py.bak +163 -0
  2. tts/model/simple_gla.py +1 -1
  3. tts/model/simple_gla.py.bak +295 -0
app.py.bak ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ import soundfile as sf
6
+ import spaces
7
+
8
+ from huggingface_hub import login
9
+ from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
10
+
11
# Hub repository the model is loaded from; override with the MODEL_REPO_ID
# environment variable.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")

# Optional access token. When present, log in to the Hub at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged to Hugging Face Hub.")
    except Exception as login_error:
        # Best effort: a failed login is reported but does not stop the app.
        print("⚠️ HF login failed:", login_error)

# Lazily created model singleton and its output sample rate; both are filled
# in by _load_model() on first use.
_pardi = None
_sampling_rate = 24000
23
+
24
+ def _normalize_text(s: str, lang_hint: str = "fr") -> str:
25
+ s = (s or "").strip().lower()
26
+ try:
27
+ import re
28
+ from num2words import num2words
29
+ def repl(m): return num2words(int(m.group()), lang=lang_hint)
30
+ s = re.sub(r"\d+", repl, s)
31
+ except Exception:
32
+ pass
33
+ return s
34
+
35
def _load_model(device: str = "cuda"):
    """Return the shared PardiSpeech instance, creating it on first call.

    The first call loads ``MODEL_REPO_ID`` through
    ``PardiSpeech.from_pretrained`` with ``map_location=device`` and stores
    the model's ``sampling_rate`` attribute (24000 when the attribute is
    missing) in the module-level ``_sampling_rate``. Later calls return the
    cached instance and do not use ``device``.
    """
    global _pardi, _sampling_rate

    if _pardi is not None:
        return _pardi

    _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
    _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
    print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
    return _pardi
42
+
43
+ def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
44
+ arr = arr.astype(np.float32)
45
+ if arr.ndim == 2:
46
+ arr = arr.mean(axis=1)
47
+ return arr
48
+
49
@spaces.GPU(duration=120)
def synthesize(
    text: str,
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str,
):
    """Synthesize speech for *text*, optionally conditioned on a reference clip.

    Args:
        text: Text to speak; normalized with ``_normalize_text``.
        ref_audio: ``None``, a path readable by ``soundfile``, or a
            ``(sample_rate, samples)`` pair.
        ref_text: Transcript of the reference clip; ``None`` becomes ``""``.
        steps: Passed to ``VelocityHeadSamplingParams`` as ``num_steps``.
        cfg: Passed to ``VelocityHeadSamplingParams`` as ``cfg``.
        cfg_ref: Passed to ``VelocityHeadSamplingParams`` as ``cfg_ref``.
        temperature: Passed as ``temperature`` when the params class accepts it.
        max_seq_len: Passed to ``init_cache`` and ``text_to_speech``.
        seed: Seed for ``torch.manual_seed``; ``None`` is treated as 0.
        lang_hint: Language code used for number normalization.

    Returns:
        A ``(sample_rate, waveform)`` tuple where ``waveform`` is a NumPy array.

    Raises:
        gr.Error: If ``text_to_speech`` raises; the original exception is
            chained as the cause.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A cleared numeric field can arrive as None; fall back to seed 0.
    torch.manual_seed(int(seed or 0))

    pardi = _load_model(device)
    # _load_model tolerates a model without `sampling_rate`; stay consistent
    # with it instead of reading the attribute unconditionally.
    target_sr = getattr(pardi, "sampling_rate", _sampling_rate)
    txt = _normalize_text(text, lang_hint=lang_hint)

    cache = pardi.tts.audio_decoder.init_cache(int(max_seq_len), device)

    # VelocityHeadSamplingParams may or may not take `temperature` depending
    # on the installed version. Offer it first so the UI value is honoured
    # whenever it is supported, and retry without it if the constructor
    # rejects the keyword.
    try:
        vel_params = VelocityHeadSamplingParams(
            cfg_ref=float(cfg_ref),
            cfg=float(cfg),
            num_steps=int(steps),
            temperature=float(temperature),
        )
    except TypeError:
        vel_params = VelocityHeadSamplingParams(
            cfg_ref=float(cfg_ref),
            cfg=float(cfg),
            num_steps=int(steps),
        )

    # Optional prefix: (transcript, encoded reference audio).
    prefix = None
    if ref_audio is not None:
        if isinstance(ref_audio, str):
            wav, sr = sf.read(ref_audio)
        else:
            sr, wav = ref_audio
        wav = _to_mono_float32(np.array(wav))
        wav_t = torch.from_numpy(wav).to(device)
        if sr != target_sr:
            # Imported lazily: only needed when the clip must be resampled.
            import torchaudio

            wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
        wav_t = wav_t.unsqueeze(0)
        with torch.inference_mode():
            prefix_tokens = pardi.patchvae.encode(wav_t)
        prefix = (ref_text or "", prefix_tokens[0])

    print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")

    try:
        with torch.inference_mode():
            wavs, _ = pardi.text_to_speech(
                [txt],
                prefix,
                max_seq_len=int(max_seq_len),
                velocity_head_sampling_params=vel_params,
                cache=cache,
            )
    except Exception as e:
        import sys
        import traceback

        # Log the full traceback server-side, then surface a short message
        # in the UI while keeping the original exception as the cause.
        print("❌ text_to_speech failed:", e, file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}") from e

    wav = wavs[0].detach().cpu().numpy()
    return (target_sr, wav)
123
+
124
def build_demo():
    """Build the Gradio interface and connect the button to ``synthesize``.

    Returns:
        The ``gr.Blocks`` application, with its queue configured but not
        launched.
    """
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(
            "## Lina-speech (pardi-speech) – Démo TTS\n"
            "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
            "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
        )

        # Main text input.
        with gr.Row():
            text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")

        # Optional reference clip and its transcript.
        with gr.Accordion("Prefix (optionnel)", open=False):
            ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
            ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")

        # Sampling and generation settings.
        with gr.Accordion("Options avancées", open=False):
            with gr.Row():
                steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (réf.)")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
                max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
            seed = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
            lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")

        btn = gr.Button("Synthétiser")
        out_audio = gr.Audio(label="Sortie audio", type="numpy")

        demo.queue(default_concurrency_limit=1, max_size=32)

        # Order must match the positional parameters of synthesize().
        synth_inputs = [
            text,
            ref_audio,
            ref_text,
            steps,
            cfg,
            cfg_ref,
            temperature,
            max_seq_len,
            seed,
            lang_hint,
        ]
        btn.click(fn=synthesize, inputs=synth_inputs, outputs=[out_audio])

    return demo
159
+
160
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    demo = build_demo()
    demo.launch()
163
+ # retrigger 2025-10-29T16:27:55+01:00
tts/model/simple_gla.py CHANGED
@@ -43,7 +43,7 @@ class SimpleGLABlock(nn.Module):
43
  ffn_expansion_factor: int,
44
  ):
45
  super().__init__()
46
- self.tmix = SimpleGatedLinearAttention(
47
  hidden_size=dim,
48
  num_heads=num_heads,
49
  layer_idx=layer_idx,
 
43
  ffn_expansion_factor: int,
44
  ):
45
  super().__init__()
46
+ self.tmix = SimpleGatedLinearAttention(mode='chunk',
47
  hidden_size=dim,
48
  num_heads=num_heads,
49
  layer_idx=layer_idx,
tts/model/simple_gla.py.bak ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ #simple-gla
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from fla.layers.simple_gla import SimpleGatedLinearAttention
7
+ from fla.models.utils import Cache
8
+ from sympy import num_digits
9
+ from torch import nn
10
+
11
+ from tts.layers.attention import CrossAttention
12
+ from tts.layers.ffn import SwiGLU
13
+
14
+ from .cache_utils import FLACache
15
+ from .config import SimpleGLADecoderConfig
16
+ from .registry import register_decoder
17
+ from .shortconv import ShortConvBlock
18
+
19
# Evaluated once at import time: the mere presence of the GRAD_CKPT
# environment variable (any value) turns checkpointing on.
_GRAD_CKPT_ENABLED = "GRAD_CKPT" in os.environ


def maybe_grad_ckpt(f):
    """Return *f* wrapped in activation checkpointing when GRAD_CKPT is set.

    Args:
        f: The callable to run (in this file, a decoder layer).

    Returns:
        ``f`` itself when checkpointing is disabled; otherwise a function
        that forwards its arguments to ``torch.utils.checkpoint.checkpoint``
        with ``use_reentrant=False``.
    """
    if not _GRAD_CKPT_ENABLED:
        return f

    # Import the submodule explicitly instead of relying on `import torch`
    # having loaded `torch.utils.checkpoint` as a side effect.
    from torch.utils.checkpoint import checkpoint

    def grad_ckpt_f(*args, **kwargs):
        return checkpoint(f, *args, **kwargs, use_reentrant=False)

    return grad_ckpt_f
32
+
33
+
34
class SimpleGLABlock(nn.Module):
    """Residual block: layer-normed gated linear attention, then layer-normed SwiGLU.

    ``expand_k``, ``expand_v`` and ``use_short_conv`` are part of the
    constructor signature but are not forwarded to the attention layer here.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int,
        expand_k: float,
        expand_v: float,
        use_short_conv: bool,
        ffn_expansion_factor: int,
    ):
        super().__init__()
        # Token mixer (attention) and channel mixer (feed-forward).
        self.tmix = SimpleGatedLinearAttention(
            hidden_size=dim,
            num_heads=num_heads,
            layer_idx=layer_idx,
        )
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
    ):
        """Apply attention and feed-forward, each with a residual connection.

        ``freqs`` and ``text_freqs`` are accepted for interface compatibility
        and are not used by this block.
        """
        # Only enable the cache when it is usable, i.e. a dict whose
        # "conv_state" entry is neither None nor an empty list.
        cache_usable = isinstance(cache, dict) and cache.get("conv_state", None) not in (None, [])
        past_key_values = cache if cache_usable else None

        attn_out = self.tmix(
            self.norm1(x),
            past_key_values=past_key_values,
            use_cache=cache_usable,
        )[0]
        x = attn_out + x

        ffn_out = self.cmix(self.norm2(x))
        return ffn_out + x
76
+
77
+
78
class DecoderBlockWithOptionalCrossAttention(nn.Module):
    """Run a decoder block, then add a cross-attention residual if configured."""

    def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
        super().__init__()

        self.decoder_block = decoder_block
        self.crossatt = crossatt

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor | None = None,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
        selfatt_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Apply the wrapped block and the optional cross-attention.

        Args:
            x: Decoder hidden states.
            encoder_output: Passed to the cross-attention as ``k``.
            freqs: Forwarded to the decoder block.
            text_freqs: Forwarded to the cross-attention.
            cache: Forwarded to both the decoder block and the cross-attention.
            selfatt_mask: Accepted but not used by this wrapper.
            crossatt_mask: One mask, or a list of masks indexed by
                ``decoder_block.tmix.layer_idx``.

        Returns:
            The updated hidden states.
        """
        x = self.decoder_block(
            x,
            freqs=freqs,
            cache=cache,
        )
        if self.crossatt is None:
            return x

        # Resolve the per-layer mask only when it is actually consumed, so
        # layers without cross-attention never index the list or touch
        # `.tmix.layer_idx`.
        if isinstance(crossatt_mask, list):
            crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]

        return x + self.crossatt(
            x,
            k=encoder_output,
            text_freqs=text_freqs,
            mask=crossatt_mask,
            cache=cache,
        )
112
+
113
+
114
@register_decoder("simple_gla")
class SimpleGLADecoder(nn.Module):
    """Decoder stack of SimpleGLA / short-conv blocks with optional cross-attention.

    Layer ``i`` is a ``ShortConvBlock`` when ``i`` is listed in
    ``cfg.conv_layers`` and a ``SimpleGLABlock`` otherwise; it additionally
    gets a ``CrossAttention`` when ``i`` is in ``cfg.crossatt_layer_idx``.
    """

    config = SimpleGLADecoderConfig

    def __init__(self, cfg: SimpleGLADecoderConfig):
        super().__init__()

        assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim"
        assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
            "at most one specialized cross-attention"
        )

        self.head_dim = cfg.dim // cfg.num_heads
        self.num_heads = cfg.num_heads

        # Resolved once instead of on every layer construction.
        conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers

        def simple_gla_block(i):
            # Build the sequence-mixing block for layer i.
            if i in conv_layers:
                return ShortConvBlock(
                    dim=cfg.dim,
                    kernel_size=4,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                    layer_idx=i,
                    use_fast_conv1d=True,
                )
            return SimpleGLABlock(
                dim=cfg.dim,
                num_heads=cfg.num_heads,
                layer_idx=i,
                expand_k=cfg.expand_k,
                expand_v=cfg.expand_v,
                use_short_conv=cfg.use_short_conv,
                ffn_expansion_factor=cfg.ffn_expansion_factor,
            )

        def crossatt_block(i):
            # Build the cross-attention for layer i, or None if it has none.
            if i in cfg.crossatt_layer_idx:
                return CrossAttention(
                    dim=cfg.dim,
                    num_heads=cfg.crossatt_num_heads,
                    dropout=cfg.crossatt_dropout,
                    layer_idx=i,
                )
            return None

        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlockWithOptionalCrossAttention(
                    simple_gla_block(i),
                    crossatt_block(i),
                )
                for i in range(cfg.num_layers)
            ]
        )

    def forward(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
        text_ids: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        """Run ``decoder_input`` through every layer.

        ``text_ids`` is accepted but unused, and ``text_freqs`` is always
        passed to the layers as ``None``.
        """
        x = decoder_input
        text_freqs = None

        for layer in self.decoder_layers:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x

    def init_cache(self, max_seq_len, device):
        """Create an ``FLACache`` with one state per layer plus one.

        ``max_seq_len`` and ``device`` are accepted but not used here.
        """
        return FLACache(num_states=len(self.decoder_layers) + 1)

    def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
        """Return one random ``nn.Parameter`` per layer.

        Each parameter has shape ``(batch_size, num_heads, head_dim, head_dim)``
        and holds standard-normal samples multiplied by ``scale``.
        """
        return tuple(
            nn.Parameter(
                torch.randn(
                    batch_size,
                    self.num_heads,
                    self.head_dim,
                    self.head_dim,
                    device=device,
                )
                * scale
            )
            for _ in range(len(self.decoder_layers))
        )

    def init_initial_state_lora(
        self,
        lora: int = 1,
        batch_size: int = 1,
        scale: float = 1e-2,
        device: str = "cpu",
    ):
        """Return one pair of random ``nn.Parameter`` factors per layer.

        The pair has shapes ``(batch_size, num_heads, head_dim, lora)`` and
        ``(batch_size, num_heads, lora, head_dim)``; both are standard-normal
        samples multiplied by ``scale``.
        """
        return tuple(
            (
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        self.head_dim,
                        lora,
                        device=device,
                    )
                    * scale
                ),
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        lora,
                        self.head_dim,
                        device=device,
                    )
                    * scale
                ),
            )
            for _ in range(len(self.decoder_layers))
        )

    def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
        """Return the cross-attention query of layer ``layer_idx``."""
        assert self.decoder_layers[layer_idx].crossatt is not None
        x = audio_inputs
        # NOTE(review): this runs only the first `layer_idx - 1` layers and
        # skips layer `layer_idx`'s own decoder block before `_query`; it
        # looks like a possible off-by-one - confirm the intended depth.
        for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
            x = layer(x, None)
        return self.decoder_layers[layer_idx].crossatt._query(x)

    def forward_first_n_layers(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        n_first_layers: int,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        """Run ``decoder_input`` through the first ``n_first_layers`` layers only."""
        x = decoder_input

        # `text_freqs_embd` is not created in __init__, and nn.Module raises
        # AttributeError for unknown attributes, so look it up defensively:
        # when absent, behave as if no text positional embedding exists.
        text_freqs_embd = getattr(self, "text_freqs_embd", None)
        if text_freqs_embd is not None:
            positions = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
            text_freqs = text_freqs_embd(positions)
        else:
            text_freqs = None

        for layer in self.decoder_layers[:n_first_layers]:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x

    def prefill(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        """Run a full forward pass, forwarding ``cache`` and ``crossatt_mask``."""
        return self(encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)

    def decode_one(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        cache: Cache,
        text_freqs: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
    ):
        """Run ``decoder_input`` through every layer without gradient checkpointing."""
        x = decoder_input
        for layer in self.decoder_layers:
            x = layer(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x