Commit c3f9f3a · 1 Parent(s): 69e6077
ruslanmv committed: Update app.py
Files changed (1):
  1. app.py +190 -120
app.py CHANGED
@@ -3,28 +3,36 @@
  # ===================================================================================
  from __future__ import annotations
  import os
+ import sys
  import base64
  import struct
  import textwrap
  import requests
  import atexit
- import tempfile
  from typing import List, Dict, Tuple, Generator

  # --- Fast, safe defaults ---
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("COQUI_TOS_AGREED", "1")
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
+ os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio ffmpeg path entirely

- # --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
+ # --- .env early (HF_TOKEN / SECRET_TOKEN) ---
  from dotenv import load_dotenv
  load_dotenv()

- # --- Hugging Face Spaces & ZeroGPU ---
+ # --- NumPy sanity (Torch 2.2.2 expects NumPy 1.x in your stack) ---
+ import numpy as _np
+ if int(_np.__version__.split(".", 1)[0]) >= 2:
+     raise RuntimeError(
+         f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4) for this Space."
+     )
+
+ # --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
  try:
-     import spaces
- except ImportError:
+     import spaces  # Required for ZeroGPU on HF
+ except Exception:
      class _SpacesShim:
          def GPU(self, *args, **kwargs):
              def _wrap(fn):
@@ -34,19 +42,20 @@ except ImportError:

  import gradio as gr

- # --- Core ML & Data Libraries ---
+ # --- Core ML & Data Libraries (after spaces import) ---
  import torch
  import numpy as np
  from huggingface_hub import HfApi, hf_hub_download
  from llama_cpp import Llama
- import torchaudio
- import soundfile as sf
+
+ # --- Audio decoding (we'll use ffmpeg-python to avoid torchaudio/torio) ---
+ import ffmpeg

  # --- TTS Libraries ---
  from TTS.tts.configs.xtts_config import XttsConfig
  from TTS.tts.models.xtts import Xtts
  from TTS.utils.manage import ModelManager
- from TTS.utils.generic_utils import get_user_data_dir
+ import TTS.tts.models.xtts as xtts_module  # for monkey-patching load_audio

  # --- Text & Audio Processing ---
  import nltk
@@ -58,12 +67,15 @@ import noisereduce as nr
  # 2) GLOBALS & HELPERS
  # ===================================================================================

+ # NLTK data
  nltk.download("punkt", quiet=True)

+ # Cached models & latents
  tts_model: Xtts | None = None
  llm_model: Llama | None = None
- voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
+ voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

+ # Config
  HF_TOKEN = os.environ.get("HF_TOKEN")
  api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
  repo_id = "ruslanmv/ai-story-server"
@@ -71,6 +83,7 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
  SENTENCE_SPLIT_LENGTH = 250
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

+ # System prompts and roles
  default_system_message = (
      "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
      "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -83,6 +96,7 @@ ROLE_PROMPTS["Pirate"] = (
      "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
  )

+ # ---------- small utils ----------
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
      if pcm_data.startswith(b"RIFF"):
          return pcm_data
@@ -115,11 +129,44 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
      prompt += f"<|user|>\n{message}</s><|assistant|>"
      return prompt

+ # ---------- robust audio decode (24k mono via ffmpeg) ----------
+ def _decode_audio_ffmpeg_to_24k_mono(path: str) -> Tuple[np.ndarray, int]:
+     """Return float32 waveform in [-1, 1], 24 kHz mono."""
+     try:
+         out, _ = (
+             ffmpeg
+             .input(path)
+             .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=24000)
+             .run(capture_stdout=True, capture_stderr=True, cmd="ffmpeg")
+         )
+         pcm = np.frombuffer(out, dtype=np.int16)
+         if pcm.size == 0:
+             raise RuntimeError("ffmpeg produced empty audio.")
+         wav = (pcm.astype(np.float32) / 32767.0).copy()
+         return wav, 24000
+     except ffmpeg.Error as e:
+         raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e
+
+ # ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
+ def _patched_load_audio(audiopath: str, load_sr: int):
+     wav, sr = _decode_audio_ffmpeg_to_24k_mono(audiopath)
+     # XTTS expects (audio, sr) and will handle truncation/conditioning windows.
+     return wav, sr
+
+ xtts_module.load_audio = _patched_load_audio  # <- critical fix
+
+ # ---------- where Coqui caches models (avoid get_user_data_dir import) ----------
+ def _coqui_cache_dir() -> str:
+     # Matches what TTS uses on Linux: ~/.local/share/tts
+     return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
+
  # ===================================================================================
  # 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
  # ===================================================================================

  def precache_assets() -> None:
+     """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
+     # Voices
      print("Pre-caching voice files...")
      file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
      base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
@@ -135,35 +182,43 @@ def precache_assets() -> None:
      except Exception as e:
          print(f"Failed to download {name}: {e}")

+     # XTTS model files
      print("Pre-caching XTTS v2 model files...")
      ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")

+     # LLM GGUF
      print("Pre-caching Zephyr GGUF...")
      try:
          hf_hub_download(
              repo_id="TheBloke/zephyr-7B-beta-GGUF",
              filename="zephyr-7b-beta.Q5_K_M.gguf",
-             local_dir_use_symlinks=False,
+             force_download=False
          )
      except Exception as e:
          print(f"Warning: GGUF pre-cache error: {e}")

  def _load_xtts(device: str) -> Xtts:
+     """Load XTTS from the local cache. Use checkpoint_dir to avoid None path bugs."""
      print("Loading Coqui XTTS V2 model (CPU first)...")
      model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-     model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
-     if not os.path.exists(model_dir):
-         ModelManager().download_model(model_name)
+     ModelManager().download_model(model_name)  # idempotent
+     model_dir = os.path.join(_coqui_cache_dir(), model_name.replace("/", "--"))

      cfg = XttsConfig()
      cfg.load_json(os.path.join(model_dir, "config.json"))
      model = Xtts.init_from_config(cfg)
-     model.load_checkpoint(cfg, checkpoint_dir=model_dir, eval=True, use_deepspeed=False)
+     model.load_checkpoint(
+         cfg,
+         checkpoint_dir=model_dir,
+         eval=True,
+         use_deepspeed=False,
+     )
      model.to(device)
      print("XTTS model loaded.")
      return model

  def _load_llama() -> Llama:
+     """Load Llama (Zephyr GGUF) on CPU so it's ready immediately."""
      print("Loading LLM (Zephyr GGUF) on CPU...")
      zephyr_model_path = hf_hub_download(
          repo_id="TheBloke/zephyr-7B-beta-GGUF",
@@ -171,126 +226,143 @@ def _load_llama() -> Llama:
      )
      llm = Llama(
          model_path=zephyr_model_path,
-         n_gpu_layers=0, n_ctx=4096, n_batch=512, verbose=False
+         n_gpu_layers=0,  # CPU by default
+         n_ctx=4096,
+         n_batch=512,
+         verbose=False
      )
      print("LLM loaded (CPU).")
      return llm

- def load_and_resample_audio(path: str, target_sr: int = 24000) -> torch.Tensor:
-     """Loads audio, converts to a Torch tensor, and resamples if needed."""
-     try:
-         audio_np, original_sr = sf.read(path, dtype='float32')
-         if audio_np.ndim > 1:
-             audio_np = np.mean(audio_np, axis=1)
-         waveform = torch.from_numpy(audio_np).float()
-
-         if original_sr != target_sr:
-             print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
-             resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
-             waveform = resampler(waveform)
-
-         return waveform.unsqueeze(0)
-     except Exception as e:
-         print(f"Error loading audio file {path}: {e}")
-         raise
-
  def init_models_and_latents() -> None:
-     """Preload models and compute voice latents, using temporary files for compatibility."""
+     """Preload TTS and LLM on CPU and compute voice latents once (using patched audio loader)."""
      global tts_model, llm_model, voice_latents

      if tts_model is None:
-         tts_model = _load_xtts(device="cpu")
+         tts_model = _load_xtts(device="cpu")  # keep on CPU at startup

      if llm_model is None:
          llm_model = _load_llama()

+     # Pre-compute latents once (CPU OK); uses patched loader (ffmpeg) under the hood
      if not voice_latents:
          print("Computing voice conditioning latents...")
-         with tempfile.TemporaryDirectory() as temp_dir:
-             voice_files = {
-                 "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
-                 "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
-             }
-             for role, filename in voice_files.items():
-                 original_path = os.path.join("voices", filename)
-                 resampled_waveform = load_and_resample_audio(original_path)
-
-                 temp_path = os.path.join(temp_dir, f"resampled_{filename}")
-
-                 # --- FIX: Replace torchaudio.save with the more stable soundfile.write ---
-                 numpy_waveform = resampled_waveform.squeeze(0).cpu().numpy()
-                 sf.write(temp_path, numpy_waveform, 24000)
-
-                 voice_latents[role] = tts_model.get_conditioning_latents(
-                     audio_path=temp_path,
-                     gpt_cond_len=30,
-                     max_ref_length=60
-                 )
+         for role, filename in [
+             ("Cloée", "cloee-1.wav"),
+             ("Julian", "julian-bedtime-style-1.wav"),
+             ("Pirate", "pirate_by_coqui.wav"),
+             ("Thera", "thera-1.wav"),
+         ]:
+             path = os.path.join("voices", filename)
+             voice_latents[role] = tts_model.get_conditioning_latents(
+                 audio_path=path, gpt_cond_len=30, max_ref_length=60
+             )
      print("Voice latents ready.")

+ # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
  def _close_llm():
      global llm_model
-     if llm_model is not None:
-         del llm_model
+     try:
+         if llm_model is not None:
+             llm_model.close()
+     except Exception:
+         pass
  atexit.register(_close_llm)

  # ===================================================================================
  # 4) INFERENCE HELPERS
  # ===================================================================================

- def generate_text_stream(llm_instance: Llama, prompt: str, history: List, sys_prompt: str) -> Generator[str, None, None]:
-     formatted_prompt = format_prompt_zephyr(prompt, history, sys_prompt)
+ def generate_text_stream(llm_instance: Llama, prompt: str,
+                          history: List[Tuple[str, str | None]],
+                          system_message_text: str) -> Generator[str, None, None]:
+     formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
      stream = llm_instance(
-         formatted_prompt, temperature=0.7, max_tokens=512, top_p=0.95, stop=LLM_STOP_WORDS, stream=True
+         formatted_prompt,
+         temperature=0.7,
+         max_tokens=512,
+         top_p=0.95,
+         stop=LLM_STOP_WORDS,
+         stream=True
      )
      for response in stream:
-         yield response["choices"][0]["text"]
-
- def generate_audio_stream(tts_instance: Xtts, text: str, lang: str, latents: Tuple) -> Generator[bytes, None, None]:
+         ch = response["choices"][0]["text"]
+         try:
+             is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
+         except Exception:
+             is_single_emoji = False
+         if "<|user|>" in ch or is_single_emoji:
+             continue
+         yield ch
+
+ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
+                           latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
      gpt_cond_latent, speaker_embedding = latents
-     for chunk in tts_instance.inference_stream(
-         text, lang, gpt_cond_latent, speaker_embedding, temperature=0.85,
-     ):
-         if chunk is not None:
-             yield chunk.detach().cpu().numpy().squeeze().tobytes()
+     try:
+         for chunk in tts_instance.inference_stream(
+             text=text,
+             language=language,
+             gpt_cond_latent=gpt_cond_latent,
+             speaker_embedding=speaker_embedding,
+             temperature=0.85,
+         ):
+             if chunk is not None:
+                 yield chunk.detach().cpu().numpy().squeeze().tobytes()
+     except RuntimeError as e:
+         print(f"Error during TTS inference: {e}")
+         if "device-side assert" in str(e) and api:
+             try:
+                 gr.Warning("Critical GPU error. Attempting to restart the Space...")
+                 api.restart_space(repo_id=repo_id)
+             except Exception:
+                 pass

  # ===================================================================================
  # 5) ZERO-GPU ENTRYPOINT
  # ===================================================================================

- @spaces.GPU(duration=120)
+ @spaces.GPU(duration=120)  # Request GPU for up to 120s; adjust as needed
  def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
      if secret_token_input != SECRET_TOKEN:
          raise gr.Error("Invalid secret token provided.")
      if not input_text:
          return []

-     if tts_model is None or llm_model is None:
-         raise gr.Error("Models not initialized. Please restart the Space.")
+     # Ensure models/latents exist
+     if tts_model is None or llm_model is None or not voice_latents:
+         init_models_and_latents()

+     # Move XTTS to CUDA for this call if GPU is available; otherwise stay on CPU
      try:
          if torch.cuda.is_available():
              tts_model.to("cuda")
+         else:
+             tts_model.to("cpu")
+     except Exception:
+         tts_model.to("cpu")

-         history: List[Tuple[str, str | None]] = [(input_text, None)]
-         full_story_text = "".join(
-             generate_text_stream(llm_model, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role])
-         ).strip()
-
-         if not full_story_text:
-             return []
-
-         sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
-         lang = langid.classify(sentences[0])[0] if sentences else "en"
-         results: List[Dict[str, str]] = []
-
-         for sentence in sentences:
-             if not any(c.isalnum() for c in sentence):
-                 continue
-
-             audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
-             pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+     # Generate story text (LLM on CPU)
+     history: List[Tuple[str, str | None]] = [(input_text, None)]
+     full_story_text = "".join(
+         generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
+     ).strip()
+     if not full_story_text:
+         return []

+     # Split into TTS-friendly sentences
+     sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
+     lang = langid.classify(sentences[0])[0] if sentences else "en"

+     results: List[Dict[str, str]] = []
+     for sentence in sentences:
+         if not any(c.isalnum() for c in sentence):
+             continue

+         audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+
+         # Optional noise reduction (best-effort)
+         try:
              data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
              if data_s16.size > 0:
                  float_data = data_s16.astype(np.float32) / 32767.0
@@ -298,45 +370,43 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
                  final_pcm = (reduced * 32767).astype(np.int16).tobytes()
              else:
                  final_pcm = pcm_data
-
-             b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
-             results.append({"text": sentence, "audio": b64_wav})
-
-         return results
-
-     finally:
-         if tts_model is not None:
-             tts_model.to("cpu")
+         except Exception:
+             final_pcm = pcm_data
+
+         b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
+         results.append({"text": sentence, "audio": b64_wav})
+
+     # Release GPU immediately
+     try:
+         tts_model.to("cpu")
+     except Exception:
+         pass
+
+     return results

  # ===================================================================================
  # 6) STARTUP: PRECACHE & UI
  # ===================================================================================

- def build_ui() -> gr.Blocks:
-     with gr.Blocks() as demo:
-         gr.Markdown("# AI Storyteller with ZeroGPU")
-         gr.Markdown("Enter a prompt to generate a short story with voice narration using on-demand GPU.")
-
-         with gr.Row():
-             secret_token = gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN)
-             storyteller = gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée")
-
-         prompt = gr.Textbox(placeholder="What should the story be about?", label="Story Prompt")
-         output = gr.JSON(label="Story and Audio Output")
-
-         prompt.submit(
-             fn=generate_story_and_speech,
-             inputs=[secret_token, prompt, storyteller],
-             outputs=output,
-         )
-
-     return demo
+ def build_ui() -> gr.Interface:
+     return gr.Interface(
+         fn=generate_story_and_speech,
+         inputs=[
+             gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
+             gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
+             gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
+         ],
+         outputs=gr.JSON(label="Story and Audio Output"),
+         title="AI Storyteller with ZeroGPU",
+         description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+         flagging_mode="never",
+     )

  if __name__ == "__main__":
      print("===== Startup: pre-cache assets and preload models =====")
-     precache_assets()
-     init_models_and_latents()
+     precache_assets()          # 1) download everything to disk
+     init_models_and_latents()  # 2) load models on CPU + compute voice latents (via patched loader)
      print("Models and assets ready. Launching UI...")

      demo = build_ui()
-     demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+     demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
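Note: the hunk above elides the body of pcm_to_wav after its RIFF early-return. For reference, the standard 44-byte RIFF/WAVE header for 16-bit PCM can be packed with struct as in the sketch below; wrap_pcm_in_wav is a hypothetical stand-in for illustration, not the function body from this commit.

import struct

def wrap_pcm_in_wav(pcm: bytes, sample_rate: int = 24000,
                    channels: int = 1, bit_depth: int = 16) -> bytes:
    """Prepend a standard 44-byte RIFF/WAVE header to raw PCM samples."""
    byte_rate = sample_rate * channels * bit_depth // 8
    block_align = channels * bit_depth // 8
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + len(pcm),  # chunk size: 44-byte header minus 8, plus payload
        b"WAVE",
        b"fmt ", 16,             # fmt subchunk length for plain PCM
        1,                       # format tag 1 = uncompressed PCM
        channels, sample_rate,
        byte_rate, block_align, bit_depth,
        b"data", len(pcm),       # data subchunk holds the raw samples
    )
    return header + pcm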
 
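Note: the ffmpeg-based decoder and the load_audio monkey-patch replace the former torchaudio/soundfile resampling path. A quick sanity check of the new decode path (a hypothetical local test, not part of the commit; assumes ffmpeg is on PATH and the voices/ directory was populated by precache_assets()):

# Run inside the app module after precache_assets().
wav, sr = _decode_audio_ffmpeg_to_24k_mono("voices/cloee-1.wav")
assert sr == 24000 and wav.ndim == 1 and wav.dtype == np.float32
print(f"Decoded {wav.shape[0] / sr:.2f}s of mono audio at {sr} Hz")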
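Note: since build_ui() now returns a gr.Interface, the handler is exposed through the standard Gradio API. A minimal client sketch using gradio_client, assuming the Space runs locally on port 7860, the default /predict route, and the default "secret" token; the prompt text and output filenames are illustrative:

import base64
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
results = client.predict(
    "secret",                             # Secret Token (SECRET_TOKEN default)
    "A brave cat sails the seven seas.",  # Story Prompt
    "Pirate",                             # Storyteller
    api_name="/predict",
)

# Each result item pairs one sentence with base64-encoded WAV audio.
for i, item in enumerate(results):
    with open(f"sentence_{i}.wav", "wb") as f:
        f.write(base64.b64decode(item["audio"]))
    print(item["text"])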