ruslanmv committed
Commit e09268d · 1 Parent(s): 239225b

Update app.py

Files changed (1)
  1. app.py +158 -87
app.py CHANGED
@@ -2,40 +2,55 @@
 # 1) SETUP & IMPORTS
 # ===================================================================================
 from __future__ import annotations
+
 import os
 import sys
+import re
 import base64
 import struct
 import textwrap
 import requests
 import atexit
-from typing import List, Dict, Tuple, Generator
+from typing import List, Dict, Tuple, Generator, Any

 # --- Fast, safe defaults ---
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("COQUI_TOS_AGREED", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
-os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # prevent torchaudio/ffmpeg (torio) path
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
+os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio/ffmpeg linkage quirks

 # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
 load_dotenv()

-# --- NumPy sanity with torch 2.2.x ---
+# --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
 import numpy as _np
 if int(_np.__version__.split(".", 1)[0]) >= 2:
     raise RuntimeError(
-        f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
+        f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
     )

-# --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
+# --- Pandas compat shim (Gradio uses 'future.no_silent_downcasting') ---
+try:
+    import pandas as pd
+    from pandas._config.config import register_option
+    try:
+        pd.get_option("future.no_silent_downcasting")
+    except Exception:
+        # Register the option if missing so pd.option_context(...) won't crash
+        register_option("future.no_silent_downcasting", False, validator=None, doc="compat shim for Gradio")
+except Exception:
+    pd = None  # ok
+
+# --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
 try:
     import spaces  # Required for ZeroGPU on HF
 except Exception:
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
-            def _wrap(fn): return fn
+            def _wrap(fn):
+                return fn
             return _wrap
     spaces = _SpacesShim()
 
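For context on the shim above: when the `spaces` package is absent, `@spaces.GPU(...)` must still work as a decorator factory. A minimal standalone sketch of that contract (an illustration, not part of the commit):

    class _SpacesShim:
        def GPU(self, *args, **kwargs):
            def _wrap(fn):
                return fn
            return _wrap

    spaces = _SpacesShim()

    @spaces.GPU(duration=120)  # resolves to a no-op wrapper off-Spaces
    def bump(x: int) -> int:
        return x + 1

    assert bump(1) == 2  # the decorated function still runs on CPU-only hosts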
@@ -47,7 +62,7 @@ import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama

-# --- Audio decode via ffmpeg-python (no torchaudio.load) ---
+# --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
 import ffmpeg

 # --- TTS Libraries ---
@@ -62,17 +77,30 @@ import langid
 import emoji
 import noisereduce as nr

+
 # ===================================================================================
 # 2) GLOBALS & HELPERS
 # ===================================================================================

-# NLTK data
-nltk.download("punkt", quiet=True)
+def _ensure_nltk() -> None:
+    # Newer NLTK splits data into 'punkt' and 'punkt_tab'
+    for pkg in ("punkt", "punkt_tab"):
+        try:
+            if pkg == "punkt":
+                nltk.data.find("tokenizers/punkt")
+            else:
+                nltk.data.find("tokenizers/punkt_tab")
+        except LookupError:
+            nltk.download(pkg, quiet=True)

-# Cached models & latents
+_ensure_nltk()
+
+# Models & caches
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
-voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+# Store latents as NumPy on CPU for portability; convert to device at inference time
+voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

 # Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -82,9 +110,6 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
 SENTENCE_SPLIT_LENGTH = 250
 LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

-# IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
-USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
-
 # System prompts and roles
 default_system_message = (
     "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
@@ -98,7 +123,25 @@ ROLE_PROMPTS["Pirate"] = (
     "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
 )

-# ---------- small utils ----------
+
+# ---------- tiny utilities ----------
+def _model_device(m: torch.nn.Module) -> torch.device:
+    try:
+        return next(m.parameters()).device
+    except StopIteration:
+        return torch.device("cpu")
+
+def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
+    if isinstance(x, np.ndarray):
+        return torch.from_numpy(x).float().to(device)
+    if torch.is_tensor(x):
+        return x.to(device, dtype=torch.float32)
+    return torch.as_tensor(x, dtype=torch.float32, device=device)
+
+def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+    gpt_cond, spk = latents
+    return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
+
 def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
     if pcm_data.startswith(b"RIFF"):
         return pcm_data
@@ -114,14 +157,19 @@ def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit
     return header + pcm_data

 def split_sentences(text: str, max_len: int) -> List[str]:
-    sentences = nltk.sent_tokenize(text)
-    out: List[str] = []
-    for s in sentences:
-        if len(s) > max_len:
-            out.extend(textwrap.wrap(s, max_len, break_long_words=True))
+    # Try NLTK; if it fails for any reason, fall back to a simple regex splitter.
+    try:
+        sentences = nltk.sent_tokenize(text)
+    except Exception:
+        sentences = re.split(r"(?<=[\.\!\?])\s+", text)
+    chunks: List[str] = []
+    for sent in sentences:
+        if len(sent) > max_len:
+            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
         else:
-            out.append(s)
-    return out
+            if sent:
+                chunks.append(sent)
+    return chunks

 def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
     prompt = f"<|system|>\n{system_message}</s>"
@@ -131,6 +179,7 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
     prompt += f"<|user|>\n{message}</s><|assistant|>"
     return prompt

+
 # ---------- robust audio decode (mono via ffmpeg) ----------
 def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
     """
@@ -147,17 +196,23 @@ def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
         pcm = np.frombuffer(out, dtype=np.int16)
         if pcm.size == 0:
             raise RuntimeError("ffmpeg produced empty audio.")
-        return (pcm.astype(np.float32) / 32767.0)
+        wav = (pcm.astype(np.float32) / 32767.0)
+        return wav
     except ffmpeg.Error as e:
         raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

-# ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
+
+# ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
 def _patched_load_audio(audiopath: str, load_sr: int):
     """
-    Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
+    Match XTTS' expected return type:
+      - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
+        already resampled to `load_sr`.
+      - DO NOT return (audio, sr) tuple.
     """
     wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
-    audio = torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
+    import torch as _torch  # local import to avoid circularities
+    audio = _torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
     return audio

 xtts_module.load_audio = _patched_load_audio
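The patched loader's contract can be sanity-checked in place; a minimal sketch, reusing a reference clip shipped with the Space (`voices/cloee-1.wav`, per the latents loop later in the file):

    import torch

    audio = _patched_load_audio("voices/cloee-1.wav", load_sr=24000)
    assert torch.is_tensor(audio) and audio.dtype == torch.float32
    assert audio.dim() == 2 and audio.shape[0] == 1  # [1, N], not an (audio, sr) tuple
    assert float(audio.abs().max()) <= 1.0           # normalized to [-1, 1]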
@@ -167,12 +222,14 @@ try:
 except Exception:
     pass

+
 def _coqui_cache_dir() -> str:
-    # Coqui cache default on Linux
+    # Matches what TTS uses on Linux: ~/.local/share/tts
     return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")

+
 # ===================================================================================
-# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
+# 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
 # ===================================================================================

 def precache_assets() -> None:
@@ -205,8 +262,9 @@ def precache_assets() -> None:
     except Exception as e:
         print(f"Warning: GGUF pre-cache error: {e}")

-def _load_xtts(device: str) -> Xtts:
-    """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
+
+def _load_xtts(device: str = "cpu") -> Xtts:
+    """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
     print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
     ModelManager().download_model(model_name)  # idempotent
@@ -225,16 +283,20 @@ def _load_xtts(device: str) -> Xtts:
     print("XTTS model loaded.")
     return model

-def _load_llama_cpu_only() -> Llama:
-    """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
-    print("Loading LLM (Zephyr GGUF) on CPU...")
+
+def _load_llama() -> Llama:
+    """
+    Load Llama (Zephyr GGUF).
+    Keep simple & robust: default to CPU (works everywhere).
+    """
+    print("Loading LLM (Zephyr GGUF)...")
     zephyr_model_path = hf_hub_download(
         repo_id="TheBloke/zephyr-7B-beta-GGUF",
         filename="zephyr-7b-beta.Q5_K_M.gguf"
     )
     llm = Llama(
         model_path=zephyr_model_path,
-        n_gpu_layers=0,  # never touch CUDA at startup
+        n_gpu_layers=0,  # CPU-only for reliability across Spaces/ZeroGPU
         n_ctx=4096,
         n_batch=512,
         verbose=False
@@ -242,27 +304,27 @@ def _load_llama_cpu_only() -> Llama:
     print("LLM loaded (CPU).")
     return llm

+
 def init_models_and_latents() -> None:
     """
-    Preload TTS and LLM on CPU and compute voice latents on CPU.
-    This avoids any CUDA tensors outside the @spaces.GPU window.
+    Preload models on CPU and compute voice latents on CPU.
+    This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
     """
     global tts_model, llm_model, voice_latents

-    # Always CPU here (ZeroGPU rule)
-    target_device = "cpu"
-
     if tts_model is None:
-        tts_model = _load_xtts(device=target_device)
-    else:
-        tts_model.to("cpu")
+        tts_model = _load_xtts(device="cpu")  # always CPU at startup

     if llm_model is None:
-        llm_model = _load_llama_cpu_only()
+        llm_model = _load_llama()

-    # Pre-compute latents once on CPU (uses our ffmpeg loader)
     if not voice_latents:
         print("Computing voice conditioning latents (CPU)...")
+        # Ensure the TTS model is on CPU while computing latents
+        orig_dev = _model_device(tts_model)
+        if orig_dev.type != "cpu":
+            tts_model.to("cpu")
+
         with torch.no_grad():
             for role, filename in [
                 ("Cloée", "cloee-1.wav"),
@@ -271,11 +333,21 @@ def init_models_and_latents() -> None:
                 ("Thera", "thera-1.wav"),
             ]:
                 path = os.path.join("voices", filename)
-                # Returns torch tensors; keep them on CPU
-                voice_latents[role] = tts_model.get_conditioning_latents(
+                gpt_lat, spk_emb = tts_model.get_conditioning_latents(
                     audio_path=path, gpt_cond_len=30, max_ref_length=60
                 )
-        print("Voice latents ready (CPU).")
+                # Store as NumPy on CPU; convert to device on demand later
+                voice_latents[role] = (
+                    gpt_lat.detach().cpu().numpy(),
+                    spk_emb.detach().cpu().numpy(),
+                )
+
+        # Return model to original device (keep CPU at startup for safety)
+        if orig_dev.type != "cpu":
+            tts_model.to(orig_dev)
+
+        print("Voice latents ready.")
+

 # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
 def _close_llm():
@@ -287,6 +359,7 @@ def _close_llm():
         pass
 atexit.register(_close_llm)

+
 # ===================================================================================
 # 4) INFERENCE HELPERS
 # ===================================================================================
@@ -294,17 +367,17 @@ atexit.register(_close_llm)
 def generate_text_stream(llm_instance: Llama, prompt: str,
                          history: List[Tuple[str, str | None]],
                          system_message_text: str) -> Generator[str, None, None]:
-    formatted = format_prompt_zephyr(prompt, history, system_message_text)
+    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
     stream = llm_instance(
-        formatted,
+        formatted_prompt,
         temperature=0.7,
         max_tokens=512,
         top_p=0.95,
         stop=LLM_STOP_WORDS,
         stream=True
     )
-    for resp in stream:
-        ch = resp["choices"][0]["text"]
+    for response in stream:
+        ch = response["choices"][0]["text"]
         try:
             is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
         except Exception:
@@ -313,31 +386,29 @@ def generate_text_stream(llm_instance: Llama, prompt: str,
             continue
         yield ch

-def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-    g, s = latents
-    if isinstance(g, torch.Tensor):
-        g = g.to(device)
-    if isinstance(s, torch.Tensor):
-        s = s.to(device)
-    return g, s

 def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
-                          latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
-    gpt_cond_latent, speaker_embedding = latents
+                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
+    # Convert stored CPU NumPy latents to tensors on the model's current device
+    device = _model_device(tts_instance)
+    gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
+
     try:
         for chunk in tts_instance.inference_stream(
             text=text,
             language=language,
-            gpt_cond_latent=gpt_cond_latent,
-            speaker_embedding=speaker_embedding,
+            gpt_cond_latent=gpt_cond_latent_t,
+            speaker_embedding=speaker_embedding_t,
             temperature=0.85,
         ):
            if chunk is None:
                continue
-            f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
-            f32 = np.clip(f32, -1.0, 1.0)
+            # chunk: torch.FloatTensor [N] or [1, N], float32 in [-1, 1]
+            f32 = chunk.detach().cpu().numpy().squeeze()
+            f32 = np.clip(f32, -1.0, 1.0).astype(np.float32)
             s16 = (f32 * 32767.0).astype(np.int16)
             yield s16.tobytes()
+
     except RuntimeError as e:
         print(f"Error during TTS inference: {e}")
         if "device-side assert" in str(e) and api:
@@ -347,34 +418,32 @@ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
             except Exception:
                 pass

+
 # ===================================================================================
-# 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
+# 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
 # ===================================================================================

-@spaces.GPU(duration=120)  # ZeroGPU allocates a GPU only for this function call
+@spaces.GPU(duration=120)  # GPU ops must occur inside this function when on ZeroGPU
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []

-    # Ensure models/latents exist (CPU)
+    # Ensure models/latents exist (loaded on CPU)
     if tts_model is None or llm_model is None or not voice_latents:
         init_models_and_latents()

-    # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
+    # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
     try:
         if torch.cuda.is_available():
             tts_model.to("cuda")
-            device = torch.device("cuda")
         else:
             tts_model.to("cpu")
-            device = torch.device("cpu")
     except Exception:
         tts_model.to("cpu")
-        device = torch.device("cpu")

-    # Generate story text (LLM on CPU)
+    # Generate story text (LLM kept CPU for simplicity & reliability)
     history: List[Tuple[str, str | None]] = [(input_text, None)]
     full_story_text = "".join(
         generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
@@ -391,13 +460,10 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         if not any(c.isalnum() for c in sentence):
             continue

-        # Move cached latents to the same device as the model for this call
-        lat_dev = _latents_to_device(voice_latents[chatbot_role], device)
-
-        audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
+        audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

-        # Optional noise reduction (best-effort, CPU)
+        # Optional noise reduction (best-effort)
         try:
             data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
             if data_s16.size > 0:
@@ -412,7 +478,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
         results.append({"text": sentence, "audio": b64_wav})

-    # Return XTTS to CPU to release GPU instantly
+    # Leave model on CPU after the ZeroGPU window
     try:
         tts_model.to("cpu")
     except Exception:
@@ -420,6 +486,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_

     return results

+
 # ===================================================================================
 # 6) STARTUP: PRECACHE & UI
 # ===================================================================================
@@ -434,17 +501,21 @@ def build_ui() -> gr.Interface:
         ],
         outputs=gr.JSON(label="Story and Audio Output"),
         title="AI Storyteller with ZeroGPU",
-        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
-        allow_flagging="never",
-        analytics_enabled=False,
+        description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
+        flagging_mode="never",
+        # removed allow_flagging (deprecated) to silence warning
     )

 if __name__ == "__main__":
-    print("===== Startup: pre-cache assets and preload models =====")
-    print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
+    print("===== Startup: pre-cache assets and preload models (CPU) =====")
+    print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
     precache_assets()            # 1) download everything to disk
-    init_models_and_latents()    # 2) load on CPU + compute voice latents on CPU
+    init_models_and_latents()    # 2) load models on CPU + compute voice latents on CPU
     print("Models and assets ready. Launching UI...")

     demo = build_ui()
-    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+    demo.queue().launch(
+        server_name="0.0.0.0",
+        server_port=int(os.getenv("PORT", "7860")),
+        ssr_mode=False,  # disable experimental SSR noise
+    )
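For reference, the header that `pcm_to_wav` prepends (unchanged in this commit) is the standard 44-byte RIFF/WAVE layout; a minimal sketch of that construction under the same defaults (standard format arithmetic, not code from the commit):

    import struct

    def wav_header(n_pcm_bytes: int, sr: int = 24000, ch: int = 1, bits: int = 16) -> bytes:
        byte_rate = sr * ch * bits // 8
        block_align = ch * bits // 8
        return (b"RIFF" + struct.pack("<I", 36 + n_pcm_bytes) + b"WAVE"
                + b"fmt " + struct.pack("<IHHIIHH", 16, 1, ch, sr, byte_rate, block_align, bits)
                + b"data" + struct.pack("<I", n_pcm_bytes))

    assert len(wav_header(0)) == 44  # canonical PCM WAV header size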
 
 
 
 
 
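End to end, the endpoint returns a JSON list of {"text", "audio"} pairs with base64-encoded WAV; a minimal client sketch (the Space id is hypothetical, and "/predict" assumes the default gr.Interface route):

    import base64
    from gradio_client import Client  # pip install gradio_client

    client = Client("user/space")  # hypothetical Space id
    results = client.predict("secret", "Tell a short tale about the sea", "Pirate",
                             api_name="/predict")
    for i, item in enumerate(results):
        with open(f"sentence_{i}.wav", "wb") as f:
            f.write(base64.b64decode(item["audio"]))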