ruslanmv committed on
Commit caa34c2 · 1 Parent(s): 9ae6548

Update app.py

Files changed (1)
  1. app.py +193 -140
app.py CHANGED
@@ -2,70 +2,68 @@
# 1) SETUP & IMPORTS
# ===================================================================================
from __future__ import annotations
import os
import sys
import base64
import struct
import textwrap
import requests
import atexit
- from typing import List, Dict, Tuple, Generator

# --- Fast, safe defaults ---
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("COQUI_TOS_AGREED", "1")
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
- os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # prevent torchaudio/ffmpeg (torio) path

# --- .env early (HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv
load_dotenv()

- # --- NumPy sanity with torch 2.2.x ---
import numpy as _np
if int(_np.__version__.split(".", 1)[0]) >= 2:
    raise RuntimeError(
-         f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
    )

- # --- Panda shim for Gradio on pandas<2.2 (no 'future.no_silent_downcasting') ---
try:
    import pandas as pd
    try:
-         with pd.option_context("future.no_silent_downcasting", True):
-             pass
    except Exception:
-         from contextlib import contextmanager
-         _orig_option_context = pd.option_context
-
-         @contextmanager
-         def _patched_option_context(*args, **kwargs):
-             # filter out unsupported option pairs
-             filtered = []
-             i = 0
-             while i < len(args):
-                 key = args[i]
-                 val = args[i + 1] if i + 1 < len(args) else None
-                 if key == "future.no_silent_downcasting":
-                     i += 2
-                     continue
-                 filtered.extend([key, val])
-                 i += 2
-             with _orig_option_context(*filtered, **kwargs):
-                 yield
-
-         pd.option_context = _patched_option_context # type: ignore[attr-defined]
except Exception:
-     pd = None # noqa: N816

- # --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
try:
    import spaces # Required for ZeroGPU on HF
except Exception:
    class _SpacesShim:
        def GPU(self, *args, **kwargs):
-             def _wrap(fn): return fn
            return _wrap
    spaces = _SpacesShim()

@@ -77,7 +75,7 @@ import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama

- # --- Audio decode via ffmpeg-python (no torchaudio.load) ---
import ffmpeg

# --- TTS Libraries ---
@@ -92,32 +90,30 @@ import langid
import emoji
import noisereduce as nr

# ===================================================================================
# 2) GLOBALS & HELPERS
# ===================================================================================

- # Ensure NLTK resources (both 'punkt' and new 'punkt_tab' on newer NLTK)
- def _ensure_nltk():
-     try:
-         nltk.data.find("tokenizers/punkt")
-     except LookupError:
-         nltk.download("punkt", quiet=True)
-     try:
-         nltk.data.find("tokenizers/punkt_tab/english")
-     except LookupError:
        try:
-             nltk.download("punkt_tab", quiet=True)
-         except Exception:
-             # fallback: downloading 'punkt' already satisfies older versions
-             pass

_ensure_nltk()

- # Cached models & latents
tts_model: Xtts | None = None
llm_model: Llama | None = None
- # store as torch.Tensors (CPU at startup)
- voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

# Config
HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -127,10 +123,7 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

- # IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
- USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
-
- # Roles & prompts
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -143,7 +136,25 @@ ROLE_PROMPTS["Pirate"] = (
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)

- # ---------- small utils ----------
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    if pcm_data.startswith(b"RIFF"):
        return pcm_data
@@ -159,14 +170,19 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit
    return header + pcm_data

def split_sentences(text: str, max_len: int) -> List[str]:
-     sentences = nltk.sent_tokenize(text)
-     out: List[str] = []
-     for s in sentences:
-         if len(s) > max_len:
-             out.extend(textwrap.wrap(s, max_len, break_long_words=True))
        else:
-             out.append(s)
-     return out

def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    prompt = f"<|system|>\n{system_message}</s>"
@@ -176,6 +192,7 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
    prompt += f"<|user|>\n{message}</s><|assistant|>"
    return prompt

# ---------- robust audio decode (mono via ffmpeg) ----------
def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
    """
@@ -192,17 +209,23 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
        pcm = np.frombuffer(out, dtype=np.int16)
        if pcm.size == 0:
            raise RuntimeError("ffmpeg produced empty audio.")
-         return (pcm.astype(np.float32) / 32767.0)
    except ffmpeg.Error as e:
        raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

- # ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
def _patched_load_audio(audiopath: str, load_sr: int):
    """
-     Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
    """
    wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
-     audio = torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU
    return audio

xtts_module.load_audio = _patched_load_audio
@@ -212,12 +235,14 @@ try:
except Exception:
    pass

def _coqui_cache_dir() -> str:
-     # Coqui cache default on Linux
    return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")

# ===================================================================================
- # 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
# ===================================================================================

def precache_assets() -> None:
@@ -250,8 +275,9 @@ def precache_assets() -> None:
    except Exception as e:
        print(f"Warning: GGUF pre-cache error: {e}")

- def _load_xtts(device: str) -> Xtts:
-     """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
    print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name) # idempotent
@@ -270,16 +296,20 @@ def _load_xtts(device: str) -> Xtts:
    print("XTTS model loaded.")
    return model

- def _load_llama_cpu_only() -> Llama:
-     """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
-     print("Loading LLM (Zephyr GGUF) on CPU...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf"
    )
    llm = Llama(
        model_path=zephyr_model_path,
-         n_gpu_layers=0, # never touch CUDA at startup
        n_ctx=4096,
        n_batch=512,
        verbose=False
@@ -287,26 +317,27 @@ def _load_llama_cpu_only() -> Llama:
    print("LLM loaded (CPU).")
    return llm

def init_models_and_latents() -> None:
    """
-     Preload TTS and LLM on CPU and compute voice latents on CPU.
-     This avoids any CUDA tensors outside the @spaces.GPU window.
    """
    global tts_model, llm_model, voice_latents

-     target_device = "cpu" # FORCE CPU at startup for ZeroGPU compatibility
-
    if tts_model is None:
-         tts_model = _load_xtts(device=target_device)
-     else:
-         tts_model.to("cpu")

    if llm_model is None:
-         llm_model = _load_llama_cpu_only()

-     # Pre-compute latents once on CPU (uses our ffmpeg loader)
    if not voice_latents:
        print("Computing voice conditioning latents (CPU)...")
        with torch.no_grad():
            for role, filename in [
                ("Cloée", "cloee-1.wav"),
@@ -315,10 +346,21 @@ def init_models_and_latents() -> None:
                ("Thera", "thera-1.wav"),
            ]:
                path = os.path.join("voices", filename)
-                 voice_latents[role] = tts_model.get_conditioning_latents(
                    audio_path=path, gpt_cond_len=30, max_ref_length=60
                )
-     print("Voice latents ready (CPU).")

# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
def _close_llm():
@@ -330,6 +372,7 @@ def _close_llm():
        pass
atexit.register(_close_llm)

# ===================================================================================
# 4) INFERENCE HELPERS
# ===================================================================================
@@ -337,17 +380,17 @@ atexit.register(_close_llm)
def generate_text_stream(llm_instance: Llama, prompt: str,
                         history: List[Tuple[str, str | None]],
                         system_message_text: str) -> Generator[str, None, None]:
-     formatted = format_prompt_zephyr(prompt, history, system_message_text)
    stream = llm_instance(
-         formatted,
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
        stop=LLM_STOP_WORDS,
        stream=True
    )
-     for resp in stream:
-         ch = resp["choices"][0]["text"]
        try:
            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
        except Exception:
@@ -356,68 +399,77 @@
            continue
        yield ch

- def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-     g, s = latents
-     if isinstance(g, torch.Tensor):
-         g = g.to(device)
-     if isinstance(s, torch.Tensor):
-         s = s.to(device)
-     return g, s
-
- def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
-                           latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
-     gpt_cond_latent, speaker_embedding = latents
-     try:
-         for chunk in tts_instance.inference_stream(
            text=text,
            language=language,
-             gpt_cond_latent=gpt_cond_latent,
-             speaker_embedding=speaker_embedding,
            temperature=0.85,
-         ):
-             if chunk is None:
-                 continue
-             f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
-             f32 = np.clip(f32, -1.0, 1.0)
-             s16 = (f32 * 32767.0).astype(np.int16)
-             yield s16.tobytes()
-     except RuntimeError as e:
-         print(f"Error during TTS inference: {e}")
-         if "device-side assert" in str(e) and api:
-             try:
-                 gr.Warning("Critical GPU error. Attempting to restart the Space...")
-                 api.restart_space(repo_id=repo_id)
-             except Exception:
-                 pass

# ===================================================================================
- # 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
# ===================================================================================

- @spaces.GPU(duration=120) # ZeroGPU allocates a GPU only for this function call
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
    if secret_token_input != SECRET_TOKEN:
        raise gr.Error("Invalid secret token provided.")
    if not input_text:
        return []

-     # Ensure models/latents exist (CPU)
    if tts_model is None or llm_model is None or not voice_latents:
        init_models_and_latents()

-     # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
    try:
        if torch.cuda.is_available():
            tts_model.to("cuda")
-             device = torch.device("cuda")
        else:
            tts_model.to("cpu")
-             device = torch.device("cpu")
    except Exception:
        tts_model.to("cpu")
-         device = torch.device("cpu")

-     # Generate story text (LLM on CPU)
    history: List[Tuple[str, str | None]] = [(input_text, None)]
    full_story_text = "".join(
        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
@@ -434,13 +486,10 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
        if not any(c.isalnum() for c in sentence):
            continue

-         # Move cached latents to the same device as the model for this call
-         lat_dev = _latents_to_device(voice_latents[chatbot_role], device)

-         audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
-         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
-
-         # Optional noise reduction (best-effort, CPU)
        try:
            data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
            if data_s16.size > 0:
@@ -455,7 +504,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
        b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
        results.append({"text": sentence, "audio": b64_wav})

-     # Return XTTS to CPU to release GPU instantly
    try:
        tts_model.to("cpu")
    except Exception:
@@ -463,6 +512,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_

    return results

# ===================================================================================
# 6) STARTUP: PRECACHE & UI
# ===================================================================================
@@ -477,17 +527,20 @@ def build_ui() -> gr.Interface:
        ],
        outputs=gr.JSON(label="Story and Audio Output"),
        title="AI Storyteller with ZeroGPU",
-         description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
-         flagging_mode="never", # avoid deprecated allow_flagging path
-         analytics_enabled=False,
    )

if __name__ == "__main__":
-     print("===== Startup: pre-cache assets and preload models =====")
-     print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
-     precache_assets() # 1) download everything to disk
-     init_models_and_latents() # 2) load on CPU + compute voice latents on CPU
    print("Models and assets ready. Launching UI...")

    demo = build_ui()
-     demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
app.py after the change (right-hand side of the diff; added lines are marked "+"):

@@ -2,70 +2,68 @@
# 1) SETUP & IMPORTS
# ===================================================================================
from __future__ import annotations
+
import os
import sys
+ import re
import base64
import struct
import textwrap
import requests
import atexit
+ import inspect
+ from typing import List, Dict, Tuple, Generator, Any

# --- Fast, safe defaults ---
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("COQUI_TOS_AGREED", "1")
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false") # truly disable analytics
+ os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # avoid torchaudio/ffmpeg linkage quirks

# --- .env early (HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv
load_dotenv()

+ # --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
import numpy as _np
if int(_np.__version__.split(".", 1)[0]) >= 2:
    raise RuntimeError(
+         f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
    )

+ # --- Pandas compat shims (Gradio & mixed versions) ---
try:
    import pandas as pd
+     from pandas._config.config import register_option
+     # Option shim
    try:
+         pd.get_option("future.no_silent_downcasting")
    except Exception:
+         register_option("future.no_silent_downcasting", False, validator=None, doc="compat shim for Gradio")
+     # infer_objects(copy=...) shim (older pandas lacks 'copy' kwarg)
+     if hasattr(pd, "DataFrame"):
+         try:
+             sig = inspect.signature(pd.DataFrame.infer_objects)
+             if "copy" not in sig.parameters:
+                 _orig_infer_objects = pd.DataFrame.infer_objects
+                 def _infer_objects_compat(self, *args, **kwargs):
+                     kwargs.pop("copy", None)
+                     return _orig_infer_objects(self, *args, **kwargs)
+                 pd.DataFrame.infer_objects = _infer_objects_compat
+         except Exception:
+             pass
except Exception:
+     pd = None # ok if pandas is unavailable

+ # --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
try:
    import spaces # Required for ZeroGPU on HF
except Exception:
    class _SpacesShim:
        def GPU(self, *args, **kwargs):
+             def _wrap(fn):
+                 return fn
            return _wrap
    spaces = _SpacesShim()

@@ -77,7 +75,7 @@ import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama

+ # --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
import ffmpeg

# --- TTS Libraries ---
@@ -92,32 +90,30 @@ import langid
import emoji
import noisereduce as nr

+
# ===================================================================================
# 2) GLOBALS & HELPERS
# ===================================================================================

+ def _ensure_nltk() -> None:
+     # Newer NLTK splits data into 'punkt' and 'punkt_tab'
+     for pkg in ("punkt", "punkt_tab"):
        try:
+             if pkg == "punkt":
+                 nltk.data.find("tokenizers/punkt")
+             else:
+                 nltk.data.find("tokenizers/punkt_tab")
+         except LookupError:
+             nltk.download(pkg, quiet=True)

_ensure_nltk()

+ # Models & caches
tts_model: Xtts | None = None
llm_model: Llama | None = None
+
+ # Store latents as NumPy on CPU for portability; convert to device at inference time
+ voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

# Config
HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -127,10 +123,7 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

+ # System prompts and roles
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -143,7 +136,25 @@ ROLE_PROMPTS["Pirate"] = (
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)

+
+ # ---------- tiny utilities ----------
+ def _model_device(m: torch.nn.Module) -> torch.device:
+     try:
+         return next(m.parameters()).device
+     except StopIteration:
+         return torch.device("cpu")
+
+ def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
+     if isinstance(x, np.ndarray):
+         return torch.from_numpy(x).float().to(device)
+     if torch.is_tensor(x):
+         return x.to(device, dtype=torch.float32)
+     return torch.as_tensor(x, dtype=torch.float32, device=device)
+
+ def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     gpt_cond, spk = latents
+     return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
+
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    if pcm_data.startswith(b"RIFF"):
        return pcm_data
@@ -159,14 +170,19 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit
    return header + pcm_data

def split_sentences(text: str, max_len: int) -> List[str]:
+     # Try NLTK; if it fails for any reason, fallback to a simple regex splitter.
+     try:
+         sentences = nltk.sent_tokenize(text)
+     except Exception:
+         sentences = re.split(r"(?<=[\.\!\?])\s+", text)
+     chunks: List[str] = []
+     for sent in sentences:
+         if len(sent) > max_len:
+             chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
        else:
+             if sent:
+                 chunks.append(sent)
+     return chunks

def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    prompt = f"<|system|>\n{system_message}</s>"
@@ -176,6 +192,7 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
    prompt += f"<|user|>\n{message}</s><|assistant|>"
    return prompt

+
# ---------- robust audio decode (mono via ffmpeg) ----------
def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
    """
@@ -192,17 +209,23 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
        pcm = np.frombuffer(out, dtype=np.int16)
        if pcm.size == 0:
            raise RuntimeError("ffmpeg produced empty audio.")
+         wav = (pcm.astype(np.float32) / 32767.0)
+         return wav
    except ffmpeg.Error as e:
        raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

+
+ # ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
def _patched_load_audio(audiopath: str, load_sr: int):
    """
+     Match XTTS' expected return type:
+       - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
+         already resampled to `load_sr`.
+       - DO NOT return (audio, sr) tuple.
    """
    wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
+     import torch as _torch # local import to avoid circularities
+     audio = _torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU
    return audio

xtts_module.load_audio = _patched_load_audio
@@ -212,12 +235,14 @@ try:
except Exception:
    pass

+
def _coqui_cache_dir() -> str:
+     # Matches what TTS uses on Linux: ~/.local/share/tts
    return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")

+
# ===================================================================================
+ # 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
# ===================================================================================

def precache_assets() -> None:
@@ -250,8 +275,9 @@ def precache_assets() -> None:
    except Exception as e:
        print(f"Warning: GGUF pre-cache error: {e}")

+
+ def _load_xtts(device: str = "cpu") -> Xtts:
+     """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
    print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name) # idempotent
@@ -270,16 +296,20 @@ def _load_xtts(device: str) -> Xtts:
    print("XTTS model loaded.")
    return model

+
+ def _load_llama() -> Llama:
+     """
+     Load Llama (Zephyr GGUF).
+     Keep simple & robust: default to CPU (works everywhere).
+     """
+     print("Loading LLM (Zephyr GGUF)...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf"
    )
    llm = Llama(
        model_path=zephyr_model_path,
+         n_gpu_layers=0, # CPU-only for reliability across Spaces/ZeroGPU
        n_ctx=4096,
        n_batch=512,
        verbose=False
@@ -287,26 +317,27 @@ def _load_llama_cpu_only() -> Llama:
    print("LLM loaded (CPU).")
    return llm

+
def init_models_and_latents() -> None:
    """
+     Preload models on CPU and compute voice latents on CPU.
+     This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
    """
    global tts_model, llm_model, voice_latents

    if tts_model is None:
+         tts_model = _load_xtts(device="cpu") # always CPU at startup

    if llm_model is None:
+         llm_model = _load_llama()

    if not voice_latents:
        print("Computing voice conditioning latents (CPU)...")
+         # Ensure the TTS model is on CPU while computing latents
+         orig_dev = _model_device(tts_model)
+         if orig_dev.type != "cpu":
+             tts_model.to("cpu")
+
        with torch.no_grad():
            for role, filename in [
                ("Cloée", "cloee-1.wav"),
@@ -315,10 +346,21 @@ def init_models_and_latents() -> None:
                ("Thera", "thera-1.wav"),
            ]:
                path = os.path.join("voices", filename)
+                 gpt_lat, spk_emb = tts_model.get_conditioning_latents(
                    audio_path=path, gpt_cond_len=30, max_ref_length=60
                )
+                 # Store as NumPy on CPU; convert to device on demand later
+                 voice_latents[role] = (
+                     gpt_lat.detach().cpu().numpy(),
+                     spk_emb.detach().cpu().numpy(),
+                 )
+
+         # Return model to original device (keep CPU at startup for safety)
+         if orig_dev.type != "cpu":
+             tts_model.to(orig_dev)
+
+     print("Voice latents ready.")
+

# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
def _close_llm():
@@ -330,6 +372,7 @@ def _close_llm():
        pass
atexit.register(_close_llm)

+
# ===================================================================================
# 4) INFERENCE HELPERS
# ===================================================================================
@@ -337,17 +380,17 @@ atexit.register(_close_llm)
def generate_text_stream(llm_instance: Llama, prompt: str,
                         history: List[Tuple[str, str | None]],
                         system_message_text: str) -> Generator[str, None, None]:
+     formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
    stream = llm_instance(
+         formatted_prompt,
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
        stop=LLM_STOP_WORDS,
        stream=True
    )
+     for response in stream:
+         ch = response["choices"][0]["text"]
        try:
            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
        except Exception:
@@ -356,68 +399,77 @@
            continue
        yield ch

+
+ def _extract_waveform_to_numpy(wav_any: Any) -> np.ndarray:
+     """
+     Normalize various XTTS inference() return shapes/types to 1-D float32 numpy in [-1, 1].
+     """
+     if isinstance(wav_any, dict):
+         for k in ("wav", "audio", "samples", "waveform"):
+             if k in wav_any:
+                 wav_any = wav_any[k]
+                 break
+     if torch.is_tensor(wav_any):
+         arr = wav_any.detach().cpu().numpy()
+     else:
+         arr = np.asarray(wav_any)
+     arr = np.squeeze(arr).astype(np.float32)
+     # If not already normalized, attempt to scale if max > 1 (heuristic)
+     maxabs = np.max(np.abs(arr)) if arr.size else 1.0
+     if maxabs > 1.5: # likely int16 or higher-amplitude float
+         arr = arr / 32767.0
+     arr = np.clip(arr, -1.0, 1.0)
+     return arr
+
+
+ def synthesize_sentence_pcm16(tts_instance: Xtts, text: str, language: str,
+                               latents: Tuple[np.ndarray, np.ndarray]) -> bytes:
+     """
+     Use non-streaming XTTS inference() to avoid GPT2InferenceModel streaming bug.
+     Returns PCM16 bytes at 24 kHz mono.
+     """
+     device = _model_device(tts_instance)
+     gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
+
+     with torch.no_grad():
+         out = tts_instance.inference(
            text=text,
            language=language,
+             gpt_cond_latent=gpt_cond_latent_t,
+             speaker_embedding=speaker_embedding_t,
            temperature=0.85,
+         )
+
+     f32 = _extract_waveform_to_numpy(out)
+     s16 = (np.clip(f32, -1.0, 1.0) * 32767.0).astype(np.int16)
+     return s16.tobytes()
+

# ===================================================================================
+ # 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
# ===================================================================================

+ @spaces.GPU(duration=120) # GPU ops happen inside when on ZeroGPU
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
    if secret_token_input != SECRET_TOKEN:
        raise gr.Error("Invalid secret token provided.")
    if not input_text:
        return []

+     # Ensure models/latents exist (loaded on CPU)
    if tts_model is None or llm_model is None or not voice_latents:
        init_models_and_latents()

+     # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
    try:
        if torch.cuda.is_available():
            tts_model.to("cuda")
        else:
            tts_model.to("cpu")
    except Exception:
        tts_model.to("cpu")

+     # Generate story text (LLM kept CPU for simplicity & reliability)
    history: List[Tuple[str, str | None]] = [(input_text, None)]
    full_story_text = "".join(
        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
@@ -434,13 +486,10 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
        if not any(c.isalnum() for c in sentence):
            continue

+         # Synthesize whole sentence (non-stream) to avoid streaming bug
+         pcm_data = synthesize_sentence_pcm16(tts_model, sentence, lang, voice_latents[chatbot_role])

+         # Optional noise reduction (best-effort)
        try:
            data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
            if data_s16.size > 0:
@@ -455,7 +504,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
        b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
        results.append({"text": sentence, "audio": b64_wav})

+     # Leave model on CPU after the ZeroGPU window
    try:
        tts_model.to("cpu")
    except Exception:
@@ -463,6 +512,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_

    return results

+
# ===================================================================================
# 6) STARTUP: PRECACHE & UI
# ===================================================================================
@@ -477,17 +527,20 @@ def build_ui() -> gr.Interface:
        ],
        outputs=gr.JSON(label="Story and Audio Output"),
        title="AI Storyteller with ZeroGPU",
+         description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
+         flagging_mode="never",
    )

if __name__ == "__main__":
+     print("===== Startup: pre-cache assets and preload models (CPU) =====")
+     print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
+     precache_assets() # 1) download everything to disk
+     init_models_and_latents() # 2) load models on CPU + compute voice latents on CPU
    print("Models and assets ready. Launching UI...")

    demo = build_ui()
+     demo.queue().launch(
+         server_name="0.0.0.0",
+         server_port=int(os.getenv("PORT", "7860")),
+         ssr_mode=False, # disable experimental SSR noise
+     )
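For reference, a minimal client-side sketch of how the changed interface could be exercised after this commit. The Space id and api_name below are placeholders (not taken from this commit), and depending on the gradio_client version the gr.JSON output may come back already parsed or as a path to a JSON file:

import base64
from gradio_client import Client

client = Client("user/space-id")  # placeholder Space id, not from this commit
results = client.predict(
    "secret",                                    # SECRET_TOKEN expected by the app
    "Tell a short story about a brave turtle.",  # input_text
    "Cloée",                                     # chatbot_role (one of the voices above)
    api_name="/predict",                         # assumed default endpoint name
)

# The app returns a list of {"text": ..., "audio": <base64-encoded WAV>} entries.
for i, item in enumerate(results):
    print(item["text"])
    with open(f"sentence_{i}.wav", "wb") as f:
        f.write(base64.b64decode(item["audio"]))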