ruslanmv committed
Commit 9ae6548 · 1 Parent(s): e09268d

Update app.py

Files changed (1)
  1. app.py +122 -150
app.py CHANGED
@@ -2,55 +2,70 @@
 # 1) SETUP & IMPORTS
 # ===================================================================================
 from __future__ import annotations
-
 import os
 import sys
-import re
 import base64
 import struct
 import textwrap
 import requests
 import atexit
-from typing import List, Dict, Tuple, Generator, Any
+from typing import List, Dict, Tuple, Generator

 # --- Fast, safe defaults ---
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("COQUI_TOS_AGREED", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
-os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio/ffmpeg linkage quirks
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # prevent torchaudio/ffmpeg (torio) path

 # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
 load_dotenv()

-# --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
+# --- NumPy sanity with torch 2.2.x ---
 import numpy as _np
 if int(_np.__version__.split(".", 1)[0]) >= 2:
     raise RuntimeError(
-        f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
+        f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
     )

-# --- Pandas compat shim (Gradio uses 'future.no_silent_downcasting') ---
+# --- Pandas shim for Gradio on pandas<2.2 (no 'future.no_silent_downcasting') ---
 try:
     import pandas as pd
-    from pandas._config.config import register_option
     try:
-        pd.get_option("future.no_silent_downcasting")
+        with pd.option_context("future.no_silent_downcasting", True):
+            pass
     except Exception:
-        # Register the option if missing so pd.option_context(...) won't crash
-        register_option("future.no_silent_downcasting", False, validator=None, doc="compat shim for Gradio")
+        from contextlib import contextmanager
+        _orig_option_context = pd.option_context
+
+        @contextmanager
+        def _patched_option_context(*args, **kwargs):
+            # filter out unsupported option pairs
+            filtered = []
+            i = 0
+            while i < len(args):
+                key = args[i]
+                val = args[i + 1] if i + 1 < len(args) else None
+                if key == "future.no_silent_downcasting":
+                    i += 2
+                    continue
+                filtered.extend([key, val])
+                i += 2
+            with _orig_option_context(*filtered, **kwargs):
+                yield
+
+        pd.option_context = _patched_option_context  # type: ignore[attr-defined]
 except Exception:
-    pd = None  # ok
+    pd = None  # noqa: N816

-# --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
+# --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
 try:
     import spaces  # Required for ZeroGPU on HF
 except Exception:
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
-            def _wrap(fn):
-                return fn
+            def _wrap(fn): return fn
             return _wrap
     spaces = _SpacesShim()
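Note: the hunk above swaps pandas' private `register_option` for a wrapper around the public `pd.option_context`, dropping the `future.no_silent_downcasting` pair on pandas builds that lack it. A minimal standalone sketch of just that filtering step (the name `demo_filter_pairs` is illustrative, not from this commit):

    # Drop option pairs the installed pandas does not support, keep the rest.
    def demo_filter_pairs(args, unsupported=("future.no_silent_downcasting",)):
        filtered = []
        for key, val in zip(args[::2], args[1::2]):
            if key in unsupported:
                continue  # silently skip the unknown option
            filtered.extend([key, val])
        return filtered

    print(demo_filter_pairs(("future.no_silent_downcasting", True, "display.max_rows", 10)))
    # -> ['display.max_rows', 10]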
 
@@ -62,7 +77,7 @@ import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama

-# --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
+# --- Audio decode via ffmpeg-python (no torchaudio.load) ---
 import ffmpeg

 # --- TTS Libraries ---
@@ -77,30 +92,32 @@ import langid
 import emoji
 import noisereduce as nr

-
 # ===================================================================================
 # 2) GLOBALS & HELPERS
 # ===================================================================================

-def _ensure_nltk() -> None:
-    # Newer NLTK splits data into 'punkt' and 'punkt_tab'
-    for pkg in ("punkt", "punkt_tab"):
+# Ensure NLTK resources (both 'punkt' and new 'punkt_tab' on newer NLTK)
+def _ensure_nltk():
+    try:
+        nltk.data.find("tokenizers/punkt")
+    except LookupError:
+        nltk.download("punkt", quiet=True)
+    try:
+        nltk.data.find("tokenizers/punkt_tab/english")
+    except LookupError:
         try:
-            if pkg == "punkt":
-                nltk.data.find("tokenizers/punkt")
-            else:
-                nltk.data.find("tokenizers/punkt_tab")
-        except LookupError:
-            nltk.download(pkg, quiet=True)
+            nltk.download("punkt_tab", quiet=True)
+        except Exception:
+            # fallback: downloading 'punkt' already satisfies older versions
+            pass

 _ensure_nltk()

-# Models & caches
+# Cached models & latents
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
-
-# Store latents as NumPy on CPU for portability; convert to device at inference time
-voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
+# store as torch.Tensors (CPU at startup)
+voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

 # Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -110,7 +127,10 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
 SENTENCE_SPLIT_LENGTH = 250
 LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

-# System prompts and roles
+# IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
+USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
+
+# Roles & prompts
 default_system_message = (
     "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
     "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -123,25 +143,7 @@ ROLE_PROMPTS["Pirate"] = (
     "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
 )

-
-# ---------- tiny utilities ----------
-def _model_device(m: torch.nn.Module) -> torch.device:
-    try:
-        return next(m.parameters()).device
-    except StopIteration:
-        return torch.device("cpu")
-
-def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
-    if isinstance(x, np.ndarray):
-        return torch.from_numpy(x).float().to(device)
-    if torch.is_tensor(x):
-        return x.to(device, dtype=torch.float32)
-    return torch.as_tensor(x, dtype=torch.float32, device=device)
-
-def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-    gpt_cond, spk = latents
-    return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
-
+# ---------- small utils ----------
 def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
     if pcm_data.startswith(b"RIFF"):
         return pcm_data
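Note: the body of `pcm_to_wav` is elided by the diff context; it prepends a standard 44-byte RIFF header to raw little-endian 16-bit PCM. A sketch of that header layout, consistent with the defaults above (assumes uncompressed PCM, format tag 1):

    import struct

    def wav_header(n_bytes: int, sr: int = 24000, ch: int = 1, bits: int = 16) -> bytes:
        # RIFF chunk, then "fmt " subchunk (16 bytes), then "data" subchunk size.
        byte_rate = sr * ch * bits // 8
        block_align = ch * bits // 8
        return (b"RIFF" + struct.pack("<I", 36 + n_bytes) + b"WAVE"
                + b"fmt " + struct.pack("<IHHIIHH", 16, 1, ch, sr, byte_rate, block_align, bits)
                + b"data" + struct.pack("<I", n_bytes))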
@@ -157,19 +159,14 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
     return header + pcm_data

 def split_sentences(text: str, max_len: int) -> List[str]:
-    # Try NLTK; if it fails for any reason, fallback to a simple regex splitter.
-    try:
-        sentences = nltk.sent_tokenize(text)
-    except Exception:
-        sentences = re.split(r"(?<=[\.\!\?])\s+", text)
-    chunks: List[str] = []
-    for sent in sentences:
-        if len(sent) > max_len:
-            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
+    sentences = nltk.sent_tokenize(text)
+    out: List[str] = []
+    for s in sentences:
+        if len(s) > max_len:
+            out.extend(textwrap.wrap(s, max_len, break_long_words=True))
         else:
-            if sent:
-                chunks.append(sent)
-    return chunks
+            out.append(s)
+    return out

 def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
     prompt = f"<|system|>\n{system_message}</s>"
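Note: with an empty history, the Zephyr template rendered by `format_prompt_zephyr` reduces to a single system/user/assistant turn, e.g.:

    >>> format_prompt_zephyr("Tell me a story", [], "You are a storyteller.")
    '<|system|>\nYou are a storyteller.</s><|user|>\nTell me a story</s><|assistant|>'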
@@ -179,7 +176,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
     prompt += f"<|user|>\n{message}</s><|assistant|>"
     return prompt

-
 # ---------- robust audio decode (mono via ffmpeg) ----------
 def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
     """
@@ -196,23 +192,17 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
         pcm = np.frombuffer(out, dtype=np.int16)
         if pcm.size == 0:
             raise RuntimeError("ffmpeg produced empty audio.")
-        wav = (pcm.astype(np.float32) / 32767.0)
-        return wav
+        return (pcm.astype(np.float32) / 32767.0)
     except ffmpeg.Error as e:
         raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

-
-# ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
+# ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
 def _patched_load_audio(audiopath: str, load_sr: int):
     """
-    Match XTTS' expected return type:
-    - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
-      already resampled to `load_sr`.
-    - DO NOT return (audio, sr) tuple.
+    Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
     """
     wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
-    import torch as _torch  # local import to avoid circularities
-    audio = _torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
+    audio = torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
     return audio

 xtts_module.load_audio = _patched_load_audio
@@ -222,14 +212,12 @@ try:
 except Exception:
     pass

-
 def _coqui_cache_dir() -> str:
-    # Matches what TTS uses on Linux: ~/.local/share/tts
+    # Coqui cache default on Linux
     return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")

-
 # ===================================================================================
-# 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
+# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
 # ===================================================================================

 def precache_assets() -> None:
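Note: `precache_assets` (body largely elided here) only warms the on-disk caches so later loads are instant. The GGUF part amounts to a plain `hf_hub_download` call mirroring the loader below; repeated calls are cheap cache hits:

    from huggingface_hub import hf_hub_download

    # Same repo/filename as _load_llama_cpu_only(); downloads once, then no-ops.
    hf_hub_download(repo_id="TheBloke/zephyr-7B-beta-GGUF",
                    filename="zephyr-7b-beta.Q5_K_M.gguf")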
@@ -262,9 +250,8 @@ def precache_assets() -> None:
     except Exception as e:
         print(f"Warning: GGUF pre-cache error: {e}")

-
-def _load_xtts(device: str = "cpu") -> Xtts:
-    """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
+def _load_xtts(device: str) -> Xtts:
+    """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
     print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
     ModelManager().download_model(model_name)  # idempotent
@@ -283,20 +270,16 @@ def _load_xtts(device: str = "cpu") -> Xtts:
     print("XTTS model loaded.")
     return model

-
-def _load_llama() -> Llama:
-    """
-    Load Llama (Zephyr GGUF).
-    Keep simple & robust: default to CPU (works everywhere).
-    """
-    print("Loading LLM (Zephyr GGUF)...")
+def _load_llama_cpu_only() -> Llama:
+    """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
+    print("Loading LLM (Zephyr GGUF) on CPU...")
     zephyr_model_path = hf_hub_download(
         repo_id="TheBloke/zephyr-7B-beta-GGUF",
         filename="zephyr-7b-beta.Q5_K_M.gguf"
     )
     llm = Llama(
         model_path=zephyr_model_path,
-        n_gpu_layers=0,  # CPU-only for reliability across Spaces/ZeroGPU
+        n_gpu_layers=0,  # never touch CUDA at startup
         n_ctx=4096,
         n_batch=512,
         verbose=False
@@ -304,27 +287,26 @@ def _load_llama() -> Llama:
     print("LLM loaded (CPU).")
     return llm

-
 def init_models_and_latents() -> None:
     """
-    Preload models on CPU and compute voice latents on CPU.
-    This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
+    Preload TTS and LLM on CPU and compute voice latents on CPU.
+    This avoids any CUDA tensors outside the @spaces.GPU window.
    """
     global tts_model, llm_model, voice_latents

+    target_device = "cpu"  # FORCE CPU at startup for ZeroGPU compatibility
+
     if tts_model is None:
-        tts_model = _load_xtts(device="cpu")  # always CPU at startup
+        tts_model = _load_xtts(device=target_device)
+    else:
+        tts_model.to("cpu")

     if llm_model is None:
-        llm_model = _load_llama()
+        llm_model = _load_llama_cpu_only()

+    # Pre-compute latents once on CPU (uses our ffmpeg loader)
     if not voice_latents:
         print("Computing voice conditioning latents (CPU)...")
-        # Ensure the TTS model is on CPU while computing latents
-        orig_dev = _model_device(tts_model)
-        if orig_dev.type != "cpu":
-            tts_model.to("cpu")
-
         with torch.no_grad():
             for role, filename in [
                 ("Cloée", "cloee-1.wav"),
@@ -333,21 +315,10 @@ def init_models_and_latents() -> None:
                 ("Thera", "thera-1.wav"),
             ]:
                 path = os.path.join("voices", filename)
-                gpt_lat, spk_emb = tts_model.get_conditioning_latents(
+                voice_latents[role] = tts_model.get_conditioning_latents(
                     audio_path=path, gpt_cond_len=30, max_ref_length=60
                 )
-                # Store as NumPy on CPU; convert to device on demand later
-                voice_latents[role] = (
-                    gpt_lat.detach().cpu().numpy(),
-                    spk_emb.detach().cpu().numpy(),
-                )
-
-        # Return model to original device (keep CPU at startup for safety)
-        if orig_dev.type != "cpu":
-            tts_model.to(orig_dev)
-
-        print("Voice latents ready.")
-
+        print("Voice latents ready (CPU).")

 # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
 def _close_llm():
@@ -359,7 +330,6 @@ def _close_llm():
         pass
 atexit.register(_close_llm)

-
 # ===================================================================================
 # 4) INFERENCE HELPERS
 # ===================================================================================
@@ -367,17 +337,17 @@ atexit.register(_close_llm)
 def generate_text_stream(llm_instance: Llama, prompt: str,
                          history: List[Tuple[str, str | None]],
                          system_message_text: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
+    formatted = format_prompt_zephyr(prompt, history, system_message_text)
     stream = llm_instance(
-        formatted_prompt,
+        formatted,
         temperature=0.7,
         max_tokens=512,
         top_p=0.95,
         stop=LLM_STOP_WORDS,
         stream=True
     )
-    for response in stream:
-        ch = response["choices"][0]["text"]
+    for resp in stream:
+        ch = resp["choices"][0]["text"]
         try:
             is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
         except Exception:
@@ -386,29 +356,31 @@ def generate_text_stream(llm_instance: Llama, prompt: str,
             continue
         yield ch

+def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    g, s = latents
+    if isinstance(g, torch.Tensor):
+        g = g.to(device)
+    if isinstance(s, torch.Tensor):
+        s = s.to(device)
+    return g, s

 def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
-                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
-    # Convert stored CPU NumPy latents to tensors on the model's current device
-    device = _model_device(tts_instance)
-    gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
-
+                          latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
+    gpt_cond_latent, speaker_embedding = latents
     try:
         for chunk in tts_instance.inference_stream(
             text=text,
             language=language,
-            gpt_cond_latent=gpt_cond_latent_t,
-            speaker_embedding=speaker_embedding_t,
+            gpt_cond_latent=gpt_cond_latent,
+            speaker_embedding=speaker_embedding,
             temperature=0.85,
         ):
             if chunk is None:
                 continue
-            # chunk: torch.FloatTensor [N] or [1, N], float32 in [-1, 1]
-            f32 = chunk.detach().cpu().numpy().squeeze()
-            f32 = np.clip(f32, -1.0, 1.0).astype(np.float32)
+            f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
+            f32 = np.clip(f32, -1.0, 1.0)
             s16 = (f32 * 32767.0).astype(np.int16)
             yield s16.tobytes()
-
     except RuntimeError as e:
         print(f"Error during TTS inference: {e}")
         if "device-side assert" in str(e) and api:
@@ -418,32 +390,34 @@ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
     except Exception:
         pass

-
 # ===================================================================================
-# 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
+# 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
 # ===================================================================================

-@spaces.GPU(duration=120)  # GPU ops must occur inside this function when on ZeroGPU
+@spaces.GPU(duration=120)  # ZeroGPU allocates a GPU only for this function call
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []

-    # Ensure models/latents exist (loaded on CPU)
+    # Ensure models/latents exist (CPU)
     if tts_model is None or llm_model is None or not voice_latents:
         init_models_and_latents()

-    # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
+    # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
     try:
         if torch.cuda.is_available():
             tts_model.to("cuda")
+            device = torch.device("cuda")
         else:
             tts_model.to("cpu")
+            device = torch.device("cpu")
     except Exception:
         tts_model.to("cpu")
+        device = torch.device("cpu")

-    # Generate story text (LLM kept CPU for simplicity & reliability)
+    # Generate story text (LLM on CPU)
     history: List[Tuple[str, str | None]] = [(input_text, None)]
     full_story_text = "".join(
         generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
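Note: the decorator change follows the ZeroGPU contract: CUDA exists only while a `@spaces.GPU` function runs, so weights must live on CPU between calls. A minimal sketch of the pattern (`make_model` is a hypothetical stand-in, not from this commit):

    import spaces, torch

    model = make_model()  # hypothetical; loaded on CPU at startup

    @spaces.GPU(duration=60)  # GPU is attached only for the duration of this call
    def infer(x):
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(dev)
        try:
            return model(x.to(dev)).cpu()
        finally:
            model.to("cpu")  # hand the GPU back promptly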
@@ -460,10 +434,13 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
         if not any(c.isalnum() for c in sentence):
             continue

-        audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+        # Move cached latents to the same device as the model for this call
+        lat_dev = _latents_to_device(voice_latents[chatbot_role], device)
+
+        audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

-        # Optional noise reduction (best-effort)
+        # Optional noise reduction (best-effort, CPU)
         try:
             data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
             if data_s16.size > 0:
@@ -478,7 +455,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
         b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
         results.append({"text": sentence, "audio": b64_wav})

-    # Leave model on CPU after the ZeroGPU window
+    # Return XTTS to CPU to release GPU instantly
     try:
         tts_model.to("cpu")
     except Exception:
@@ -486,7 +463,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:

     return results

-
 # ===================================================================================
 # 6) STARTUP: PRECACHE & UI
 # ===================================================================================
@@ -501,21 +477,17 @@ def build_ui() -> gr.Interface:
         ],
         outputs=gr.JSON(label="Story and Audio Output"),
         title="AI Storyteller with ZeroGPU",
-        description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
-        flagging_mode="never",
-        # removed allow_flagging (deprecated) to silence warning
+        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+        flagging_mode="never",  # avoid deprecated allow_flagging path
+        analytics_enabled=False,
     )

 if __name__ == "__main__":
-    print("===== Startup: pre-cache assets and preload models (CPU) =====")
-    print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
-    precache_assets()  # 1) download everything to disk
-    init_models_and_latents()  # 2) load models on CPU + compute voice latents on CPU
+    print("===== Startup: pre-cache assets and preload models =====")
+    print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
+    precache_assets()  # 1) download everything to disk
+    init_models_and_latents()  # 2) load on CPU + compute voice latents on CPU
     print("Models and assets ready. Launching UI...")

     demo = build_ui()
-    demo.queue().launch(
-        server_name="0.0.0.0",
-        server_port=int(os.getenv("PORT", "7860")),
-        ssr_mode=False,  # disable experimental SSR noise
-    )
+    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
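Note: each element of the endpoint's JSON output pairs a sentence with a base64-encoded WAV, so a client can persist a clip like this (sketch; `item` is one element of the returned list):

    import base64

    def save_clip(item: dict, path: str) -> None:
        # item == {"text": ..., "audio": <base64 WAV>} per the results.append above
        with open(path, "wb") as f:
            f.write(base64.b64decode(item["audio"]))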