ruslanmv committed
Commit bba59ca · 1 Parent(s): caa34c2

Fixes versions

Files changed (2)
  1. app.py +149 -193
  2. requirements.txt +1 -1
app.py CHANGED
@@ -2,68 +2,79 @@
 # 1) SETUP & IMPORTS
 # ===================================================================================
 from __future__ import annotations
-
 import os
 import sys
-import re
 import base64
 import struct
 import textwrap
 import requests
 import atexit
-import inspect
-from typing import List, Dict, Tuple, Generator, Any
+from typing import List, Dict, Tuple, Generator
 
 # --- Fast, safe defaults ---
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("COQUI_TOS_AGREED", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
-os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio/ffmpeg linkage quirks
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # prevent torchaudio/ffmpeg (torio) path
 
 # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
 load_dotenv()
 
-# --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
+# --- NumPy sanity with torch 2.2.x ---
 import numpy as _np
 if int(_np.__version__.split(".", 1)[0]) >= 2:
     raise RuntimeError(
-        f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
+        f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
+    )
+
+# --- Transformers sanity for TTS streaming ---
+import transformers as _transformers
+if _transformers.__version__ != "4.36.2":
+    raise RuntimeError(
+        f"Detected transformers=={_transformers.__version__}. "
+        "Please pin transformers==4.36.2 for compatibility with Coqui TTS streaming."
     )
 
-# --- Pandas compat shims (Gradio & mixed versions) ---
+
+# --- Pandas shim for Gradio on pandas<2.2 (no 'future.no_silent_downcasting') ---
 try:
     import pandas as pd
-    from pandas._config.config import register_option
-    # Option shim
     try:
-        pd.get_option("future.no_silent_downcasting")
-    except Exception:
-        register_option("future.no_silent_downcasting", False, validator=None, doc="compat shim for Gradio")
-    # infer_objects(copy=...) shim (older pandas lacks 'copy' kwarg)
-    if hasattr(pd, "DataFrame"):
-        try:
-            sig = inspect.signature(pd.DataFrame.infer_objects)
-            if "copy" not in sig.parameters:
-                _orig_infer_objects = pd.DataFrame.infer_objects
-                def _infer_objects_compat(self, *args, **kwargs):
-                    kwargs.pop("copy", None)
-                    return _orig_infer_objects(self, *args, **kwargs)
-                pd.DataFrame.infer_objects = _infer_objects_compat
-        except Exception:
+        with pd.option_context("future.no_silent_downcasting", True):
             pass
+    except Exception:
+        from contextlib import contextmanager
+        _orig_option_context = pd.option_context
+
+        @contextmanager
+        def _patched_option_context(*args, **kwargs):
+            # filter out unsupported option pairs
+            filtered = []
+            i = 0
+            while i < len(args):
+                key = args[i]
+                val = args[i + 1] if i + 1 < len(args) else None
+                if key == "future.no_silent_downcasting":
+                    i += 2
+                    continue
+                filtered.extend([key, val])
+                i += 2
+            with _orig_option_context(*filtered, **kwargs):
+                yield
+
+        pd.option_context = _patched_option_context  # type: ignore[attr-defined]
 except Exception:
-    pd = None  # ok if pandas is unavailable
+    pd = None  # noqa: N816
 
-# --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
+# --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
 try:
     import spaces  # Required for ZeroGPU on HF
 except Exception:
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
-            def _wrap(fn):
-                return fn
+            def _wrap(fn): return fn
             return _wrap
     spaces = _SpacesShim()
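Note on the pandas shim in the hunk above: per its own comment, Gradio passes pd.option_context("future.no_silent_downcasting", ...), an option that only exists from pandas 2.2 onward. The probe below is illustrative (it is not part of app.py) and shows the failure the patched option_context is there to absorb:

    import pandas as pd

    # On pandas < 2.2 the option name is unknown and option_context raises
    # pandas.errors.OptionError; on pandas >= 2.2 the block runs cleanly.
    try:
        with pd.option_context("future.no_silent_downcasting", True):
            pass
        print("native support: no shim needed")
    except Exception as exc:
        print(f"shim required: {exc!r}")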
 
@@ -75,7 +86,7 @@ import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama
 
-# --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
+# --- Audio decode via ffmpeg-python (no torchaudio.load) ---
 import ffmpeg
 
 # --- TTS Libraries ---
@@ -90,30 +101,32 @@ import langid
 import emoji
 import noisereduce as nr
 
-
 # ===================================================================================
 # 2) GLOBALS & HELPERS
 # ===================================================================================
 
-def _ensure_nltk() -> None:
-    # Newer NLTK splits data into 'punkt' and 'punkt_tab'
-    for pkg in ("punkt", "punkt_tab"):
+# Ensure NLTK resources (both 'punkt' and new 'punkt_tab' on newer NLTK)
+def _ensure_nltk():
+    try:
+        nltk.data.find("tokenizers/punkt")
+    except LookupError:
+        nltk.download("punkt", quiet=True)
+    try:
+        nltk.data.find("tokenizers/punkt_tab/english")
+    except LookupError:
         try:
-            if pkg == "punkt":
-                nltk.data.find("tokenizers/punkt")
-            else:
-                nltk.data.find("tokenizers/punkt_tab")
-        except LookupError:
-            nltk.download(pkg, quiet=True)
+            nltk.download("punkt_tab", quiet=True)
+        except Exception:
+            # fallback: downloading 'punkt' already satisfies older versions
+            pass
 
 _ensure_nltk()
 
-# Models & caches
+# Cached models & latents
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
-
-# Store latents as NumPy on CPU for portability; convert to device at inference time
-voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
+# store as torch.Tensors (CPU at startup)
+voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
 
 # Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
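The hunk above also reverts voice_latents from NumPy arrays back to CPU torch.Tensors, so each per-call device move becomes a plain .to(device) instead of a NumPy round-trip. A minimal sketch of the two schemes (the latent shape here is a placeholder, not XTTS's real one):

    import torch

    gpt_cond = torch.zeros(1, 32, 1024)  # placeholder conditioning latent on CPU

    # Old scheme: serialize to NumPy, rebuild a tensor on every call.
    as_np = gpt_cond.detach().cpu().numpy()
    rebuilt = torch.from_numpy(as_np).float()

    # New scheme: keep the CPU tensor; move it only inside the GPU window.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    on_device = gpt_cond.to(device)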
@@ -123,7 +136,10 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
 SENTENCE_SPLIT_LENGTH = 250
 LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
 
-# System prompts and roles
+# IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
+USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
+
+# Roles & prompts
 default_system_message = (
     "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
     "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -136,25 +152,7 @@ ROLE_PROMPTS["Pirate"] = (
     "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
 )
 
-
-# ---------- tiny utilities ----------
-def _model_device(m: torch.nn.Module) -> torch.device:
-    try:
-        return next(m.parameters()).device
-    except StopIteration:
-        return torch.device("cpu")
-
-def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
-    if isinstance(x, np.ndarray):
-        return torch.from_numpy(x).float().to(device)
-    if torch.is_tensor(x):
-        return x.to(device, dtype=torch.float32)
-    return torch.as_tensor(x, dtype=torch.float32, device=device)
-
-def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-    gpt_cond, spk = latents
-    return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
-
+# ---------- small utils ----------
 def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
     if pcm_data.startswith(b"RIFF"):
         return pcm_data
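pcm_to_wav itself is untouched by this commit; most of its body sits in the elided context below. For reference, the 44-byte RIFF header it prepends to raw 16-bit mono PCM can be built like this (a sketch consistent with the struct import above, not the file's exact code):

    import struct

    def wav_header(data_size: int, sample_rate: int = 24000, channels: int = 1, bits: int = 16) -> bytes:
        byte_rate = sample_rate * channels * bits // 8
        block_align = channels * bits // 8
        return b"".join([
            b"RIFF", struct.pack("<I", 36 + data_size), b"WAVE",
            b"fmt ", struct.pack("<IHHIIHH", 16, 1, channels, sample_rate,
                                 byte_rate, block_align, bits),
            b"data", struct.pack("<I", data_size),
        ])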
@@ -170,19 +168,14 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit
     return header + pcm_data
 
 def split_sentences(text: str, max_len: int) -> List[str]:
-    # Try NLTK; if it fails for any reason, fallback to a simple regex splitter.
-    try:
-        sentences = nltk.sent_tokenize(text)
-    except Exception:
-        sentences = re.split(r"(?<=[\.\!\?])\s+", text)
-    chunks: List[str] = []
-    for sent in sentences:
-        if len(sent) > max_len:
-            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
+    sentences = nltk.sent_tokenize(text)
+    out: List[str] = []
+    for s in sentences:
+        if len(s) > max_len:
+            out.extend(textwrap.wrap(s, max_len, break_long_words=True))
         else:
-            if sent:
-                chunks.append(sent)
-    return chunks
+            out.append(s)
+    return out
 
 def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
     prompt = f"<|system|>\n{system_message}</s>"
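format_prompt_zephyr is unchanged by this commit; from the lines shown, with an empty history the assembled Zephyr-style prompt is simply the system, user, and assistant markers concatenated (history turns, when present, are spliced in between with the same token scheme):

    system = "You're a storyteller crafting a short tale for young listeners."
    message = "Tell me about a brave turtle."
    prompt = f"<|system|>\n{system}</s><|user|>\n{message}</s><|assistant|>"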
@@ -192,7 +185,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
     prompt += f"<|user|>\n{message}</s><|assistant|>"
     return prompt
 
-
 # ---------- robust audio decode (mono via ffmpeg) ----------
 def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
     """
@@ -209,23 +201,17 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
         pcm = np.frombuffer(out, dtype=np.int16)
         if pcm.size == 0:
             raise RuntimeError("ffmpeg produced empty audio.")
-        wav = (pcm.astype(np.float32) / 32767.0)
-        return wav
+        return (pcm.astype(np.float32) / 32767.0)
     except ffmpeg.Error as e:
         raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e
 
-
-# ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
+# ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
 def _patched_load_audio(audiopath: str, load_sr: int):
     """
-    Match XTTS' expected return type:
-    - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
-      already resampled to `load_sr`.
-    - DO NOT return (audio, sr) tuple.
+    Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
     """
     wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
-    import torch as _torch  # local import to avoid circularities
-    audio = _torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
+    audio = torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
     return audio
 
 xtts_module.load_audio = _patched_load_audio
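The body of _decode_audio_ffmpeg_to_mono is mostly unchanged context; the ffmpeg-python pipeline it relies on looks roughly like this (a sketch of the same idea, not the exact elided code):

    import ffmpeg  # ffmpeg-python

    def decode_to_s16le_mono(path: str, target_sr: int) -> bytes:
        # Ask ffmpeg for raw signed 16-bit little-endian PCM, one channel, resampled.
        out, _err = (
            ffmpeg.input(path)
            .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=target_sr)
            .run(capture_stdout=True, capture_stderr=True)
        )
        return out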
@@ -235,14 +221,12 @@ try:
 except Exception:
     pass
 
-
 def _coqui_cache_dir() -> str:
-    # Matches what TTS uses on Linux: ~/.local/share/tts
+    # Coqui cache default on Linux
     return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
 
-
 # ===================================================================================
-# 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
+# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
 # ===================================================================================
 
 def precache_assets() -> None:
@@ -275,9 +259,8 @@ def precache_assets() -> None:
     except Exception as e:
         print(f"Warning: GGUF pre-cache error: {e}")
 
-
-def _load_xtts(device: str = "cpu") -> Xtts:
-    """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
+def _load_xtts(device: str) -> Xtts:
+    """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
     print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
     ModelManager().download_model(model_name)  # idempotent
@@ -296,20 +279,16 @@ def _load_xtts(device: str = "cpu") -> Xtts:
     print("XTTS model loaded.")
     return model
 
-
-def _load_llama() -> Llama:
-    """
-    Load Llama (Zephyr GGUF).
-    Keep simple & robust: default to CPU (works everywhere).
-    """
-    print("Loading LLM (Zephyr GGUF)...")
+def _load_llama_cpu_only() -> Llama:
+    """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
+    print("Loading LLM (Zephyr GGUF) on CPU...")
     zephyr_model_path = hf_hub_download(
         repo_id="TheBloke/zephyr-7B-beta-GGUF",
         filename="zephyr-7b-beta.Q5_K_M.gguf"
     )
     llm = Llama(
         model_path=zephyr_model_path,
-        n_gpu_layers=0,  # CPU-only for reliability across Spaces/ZeroGPU
+        n_gpu_layers=0,  # never touch CUDA at startup
         n_ctx=4096,
         n_batch=512,
         verbose=False
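Once _load_llama_cpu_only returns, the Llama object is simply called like a function, exactly as generate_text_stream does below. A minimal non-streaming call (hypothetical prompt, same response shape the streaming path reads):

    llm = _load_llama_cpu_only()
    out = llm("Q: Name one primary color. A:", max_tokens=8, temperature=0.0)
    print(out["choices"][0]["text"])  # completion text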
@@ -317,27 +296,26 @@ def _load_llama() -> Llama:
     print("LLM loaded (CPU).")
     return llm
 
-
 def init_models_and_latents() -> None:
     """
-    Preload models on CPU and compute voice latents on CPU.
-    This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
+    Preload TTS and LLM on CPU and compute voice latents on CPU.
+    This avoids any CUDA tensors outside the @spaces.GPU window.
     """
     global tts_model, llm_model, voice_latents
 
+    target_device = "cpu"  # FORCE CPU at startup for ZeroGPU compatibility
+
     if tts_model is None:
-        tts_model = _load_xtts(device="cpu")  # always CPU at startup
+        tts_model = _load_xtts(device=target_device)
+    else:
+        tts_model.to("cpu")
 
     if llm_model is None:
-        llm_model = _load_llama()
+        llm_model = _load_llama_cpu_only()
 
+    # Pre-compute latents once on CPU (uses our ffmpeg loader)
     if not voice_latents:
         print("Computing voice conditioning latents (CPU)...")
-        # Ensure the TTS model is on CPU while computing latents
-        orig_dev = _model_device(tts_model)
-        if orig_dev.type != "cpu":
-            tts_model.to("cpu")
-
         with torch.no_grad():
             for role, filename in [
                 ("Cloée", "cloee-1.wav"),
@@ -346,21 +324,10 @@ def init_models_and_latents() -> None:
                 ("Thera", "thera-1.wav"),
             ]:
                 path = os.path.join("voices", filename)
-                gpt_lat, spk_emb = tts_model.get_conditioning_latents(
+                voice_latents[role] = tts_model.get_conditioning_latents(
                     audio_path=path, gpt_cond_len=30, max_ref_length=60
                 )
-                # Store as NumPy on CPU; convert to device on demand later
-                voice_latents[role] = (
-                    gpt_lat.detach().cpu().numpy(),
-                    spk_emb.detach().cpu().numpy(),
-                )
-
-        # Return model to original device (keep CPU at startup for safety)
-        if orig_dev.type != "cpu":
-            tts_model.to(orig_dev)
-
-        print("Voice latents ready.")
-
+        print("Voice latents ready (CPU).")
 
 # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
 def _close_llm():
@@ -372,7 +339,6 @@ def _close_llm():
         pass
 atexit.register(_close_llm)
 
-
 # ===================================================================================
 # 4) INFERENCE HELPERS
 # ===================================================================================
@@ -380,17 +346,17 @@ atexit.register(_close_llm)
 def generate_text_stream(llm_instance: Llama, prompt: str,
                          history: List[Tuple[str, str | None]],
                          system_message_text: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
+    formatted = format_prompt_zephyr(prompt, history, system_message_text)
     stream = llm_instance(
-        formatted_prompt,
+        formatted,
        temperature=0.7,
         max_tokens=512,
         top_p=0.95,
         stop=LLM_STOP_WORDS,
         stream=True
     )
-    for response in stream:
-        ch = response["choices"][0]["text"]
+    for resp in stream:
+        ch = resp["choices"][0]["text"]
         try:
             is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
         except Exception:
@@ -399,77 +365,68 @@ def generate_text_stream(llm_instance: Llama, prompt: str,
             continue
         yield ch
 
-
-def _extract_waveform_to_numpy(wav_any: Any) -> np.ndarray:
-    """
-    Normalize various XTTS inference() return shapes/types to 1-D float32 numpy in [-1, 1].
-    """
-    if isinstance(wav_any, dict):
-        for k in ("wav", "audio", "samples", "waveform"):
-            if k in wav_any:
-                wav_any = wav_any[k]
-                break
-    if torch.is_tensor(wav_any):
-        arr = wav_any.detach().cpu().numpy()
-    else:
-        arr = np.asarray(wav_any)
-    arr = np.squeeze(arr).astype(np.float32)
-    # If not already normalized, attempt to scale if max > 1 (heuristic)
-    maxabs = np.max(np.abs(arr)) if arr.size else 1.0
-    if maxabs > 1.5:  # likely int16 or higher-amplitude float
-        arr = arr / 32767.0
-    arr = np.clip(arr, -1.0, 1.0)
-    return arr
-
-
-def synthesize_sentence_pcm16(tts_instance: Xtts, text: str, language: str,
-                              latents: Tuple[np.ndarray, np.ndarray]) -> bytes:
-    """
-    Use non-streaming XTTS inference() to avoid GPT2InferenceModel streaming bug.
-    Returns PCM16 bytes at 24 kHz mono.
-    """
-    device = _model_device(tts_instance)
-    gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
-
-    with torch.no_grad():
-        out = tts_instance.inference(
+def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    g, s = latents
+    if isinstance(g, torch.Tensor):
+        g = g.to(device)
+    if isinstance(s, torch.Tensor):
+        s = s.to(device)
+    return g, s
+
+def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
+                          latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
+    gpt_cond_latent, speaker_embedding = latents
+    try:
+        for chunk in tts_instance.inference_stream(
             text=text,
             language=language,
-            gpt_cond_latent=gpt_cond_latent_t,
-            speaker_embedding=speaker_embedding_t,
+            gpt_cond_latent=gpt_cond_latent,
+            speaker_embedding=speaker_embedding,
             temperature=0.85,
-        )
-
-    f32 = _extract_waveform_to_numpy(out)
-    s16 = (np.clip(f32, -1.0, 1.0) * 32767.0).astype(np.int16)
-    return s16.tobytes()
-
+        ):
+            if chunk is None:
+                continue
+            f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
+            f32 = np.clip(f32, -1.0, 1.0)
+            s16 = (f32 * 32767.0).astype(np.int16)
+            yield s16.tobytes()
+    except RuntimeError as e:
+        print(f"Error during TTS inference: {e}")
+        if "device-side assert" in str(e) and api:
+            try:
+                gr.Warning("Critical GPU error. Attempting to restart the Space...")
+                api.restart_space(repo_id=repo_id)
+            except Exception:
+                pass
 
 # ===================================================================================
-# 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
+# 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
 # ===================================================================================
 
-@spaces.GPU(duration=120)  # GPU ops happen inside when on ZeroGPU
+@spaces.GPU(duration=120)  # ZeroGPU allocates a GPU only for this function call
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []
 
-    # Ensure models/latents exist (loaded on CPU)
+    # Ensure models/latents exist (CPU)
     if tts_model is None or llm_model is None or not voice_latents:
         init_models_and_latents()
 
-    # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
+    # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
     try:
         if torch.cuda.is_available():
             tts_model.to("cuda")
+            device = torch.device("cuda")
         else:
             tts_model.to("cpu")
+            device = torch.device("cpu")
     except Exception:
         tts_model.to("cpu")
+        device = torch.device("cpu")
 
-    # Generate story text (LLM kept CPU for simplicity & reliability)
+    # Generate story text (LLM on CPU)
     history: List[Tuple[str, str | None]] = [(input_text, None)]
     full_story_text = "".join(
         generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
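The @spaces.GPU window above is the only place CUDA may be touched on ZeroGPU. Reduced to a sketch, the pattern this entrypoint follows is (run_inference is a hypothetical stand-in for the real work):

    import spaces
    import torch

    @spaces.GPU(duration=120)  # a GPU exists only while this call runs
    def gpu_entrypoint(model, payload):
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        try:
            return run_inference(model, payload)  # hypothetical work function
        finally:
            model.to("cpu")  # hand the GPU back immediately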
@@ -486,10 +443,13 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         if not any(c.isalnum() for c in sentence):
             continue
 
-        # Synthesize whole sentence (non-stream) to avoid streaming bug
-        pcm_data = synthesize_sentence_pcm16(tts_model, sentence, lang, voice_latents[chatbot_role])
+        # Move cached latents to the same device as the model for this call
+        lat_dev = _latents_to_device(voice_latents[chatbot_role], device)
 
-        # Optional noise reduction (best-effort)
+        audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
+        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+
+        # Optional noise reduction (best-effort, CPU)
         try:
             data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
             if data_s16.size > 0:
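The noise-reduction call itself lives in unchanged context just after this hunk; a typical PCM16 round-trip through noisereduce looks like this (a sketch with assumed parameters, not the file's exact code):

    import numpy as np
    import noisereduce as nr

    def denoise_pcm16(pcm: bytes, sr: int = 24000) -> bytes:
        f32 = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32767.0
        reduced = nr.reduce_noise(y=f32, sr=sr)  # spectral gating
        return (np.clip(reduced, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes()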
@@ -504,7 +464,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
         results.append({"text": sentence, "audio": b64_wav})
 
-    # Leave model on CPU after the ZeroGPU window
+    # Return XTTS to CPU to release GPU instantly
     try:
         tts_model.to("cpu")
     except Exception:
@@ -512,7 +472,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
 
     return results
 
-
 # ===================================================================================
 # 6) STARTUP: PRECACHE & UI
 # ===================================================================================
@@ -527,20 +486,17 @@ def build_ui() -> gr.Interface:
         ],
         outputs=gr.JSON(label="Story and Audio Output"),
         title="AI Storyteller with ZeroGPU",
-        description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
-        flagging_mode="never",
+        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+        flagging_mode="never",  # avoid deprecated allow_flagging path
+        analytics_enabled=False,
     )
 
 if __name__ == "__main__":
-    print("===== Startup: pre-cache assets and preload models (CPU) =====")
-    print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
-    precache_assets()  # 1) download everything to disk
-    init_models_and_latents()  # 2) load models on CPU + compute voice latents on CPU
+    print("===== Startup: pre-cache assets and preload models =====")
+    print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
+    precache_assets()  # 1) download everything to disk
+    init_models_and_latents()  # 2) load on CPU + compute voice latents on CPU
     print("Models and assets ready. Launching UI...")
 
     demo = build_ui()
-    demo.queue().launch(
-        server_name="0.0.0.0",
-        server_port=int(os.getenv("PORT", "7860")),
-        ssr_mode=False,  # disable experimental SSR noise
-    )
+    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
 
requirements.txt CHANGED
@@ -23,7 +23,7 @@ langid
 nltk
 emoji
 ffmpeg-python
-
+transformers==4.36.2
 # Japanese Text (optional)
 mecab-python3==1.0.9
 unidic-lite==1.0.8
  unidic-lite==1.0.8