ruslanmv committed
Commit 427d75a · 1 Parent(s): a28e45a

Update app.py

Files changed (1): app.py (+140 −103)
app.py CHANGED
@@ -14,7 +14,11 @@ from typing import List, Dict, Tuple, Generator
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("COQUI_TOS_AGREED", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
+
+# >>> CRITICAL: force torchaudio to avoid FFmpeg/torio path <<<
+# Must be set BEFORE importing torchaudio
+os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")

 # --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
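Note on the new `TORCHAUDIO_USE_FFMPEG` guard: it only helps if it is exported before the first `import torchaudio` anywhere in the process, since the 2.x dispatcher picks up its backends at import time. A minimal sanity check, assuming the dispatcher honors the variable as this commit intends (`list_audio_backends` is the public query API in torchaudio 2.x):

```python
import os

# Must happen before torchaudio is first imported in this process.
os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")

import torchaudio

# With the FFmpeg path disabled, only the remaining backends should be listed,
# e.g. ['soundfile'] when the soundfile package is installed.
print(torchaudio.list_audio_backends())
```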
@@ -22,8 +26,8 @@ load_dotenv()

 # --- Hugging Face Spaces & ZeroGPU ---
 try:
-    import spaces
-except ImportError:
+    import spaces  # Required for ZeroGPU on HF
+except Exception:
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
             def _wrap(fn):
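The shim body is truncated in the hunk above; for reference, a complete no-op version of this pattern would look roughly like the following (the `_wrap` return value and the final `spaces = _SpacesShim()` assignment are assumptions, since they are not visible in the diff):

```python
class _SpacesShim:
    """Fallback when the `spaces` package is unavailable (e.g., local runs)."""
    def GPU(self, *args, **kwargs):
        # Decorator factory: return the function unchanged so that
        # @spaces.GPU(...) is a no-op outside Hugging Face ZeroGPU.
        def _wrap(fn):
            return fn
        return _wrap

spaces = _SpacesShim()
```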
@@ -39,20 +43,11 @@ import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama

-# --- Set torchaudio backend BEFORE it's used ---
-# This attempts to use soundfile or sox_io before falling back, avoiding the buggy ffmpeg backend.
+# --- torchaudio (dispatcher: FFmpeg disabled via env above) ---
 try:
-    import torchaudio
-    # Try to set a more stable backend
-    for backend in ("soundfile", "sox_io"):
-        try:
-            torchaudio.set_audio_backend(backend)
-            print(f"Torchaudio backend set to: {backend}")
-            break
-        except Exception:
-            continue
+    import torchaudio  # noqa: F401
 except Exception:
-    print("Could not import or set torchaudio backend.")
+    torchaudio = None  # XTTS will still call torchaudio.load internally; env disables ffmpeg path

 # --- TTS Libraries ---
 from TTS.tts.configs.xtts_config import XttsConfig
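Worth noting why the deleted block stopped working: `torchaudio.set_audio_backend` was deprecated with the 2.x dispatcher and removed in later releases, so backend selection now happens per call rather than globally. A sketch of the modern equivalent, assuming torchaudio ≥ 2.1 where `load` accepts a `backend` argument (the path is one of the repo's voice files):

```python
import torchaudio

# The 2.x dispatcher selects a backend per call instead of via a global setter;
# passing backend="soundfile" sidesteps the FFmpeg path explicitly.
waveform, sample_rate = torchaudio.load("voices/cloee-1.wav", backend="soundfile")
print(waveform.shape, sample_rate)
```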
@@ -76,7 +71,7 @@ nltk.download("punkt", quiet=True)
 # Cached models & latents
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
-voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
+voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

 # Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -138,6 +133,7 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy

 def precache_assets() -> None:
     """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
+    # Voices
     print("Pre-caching voice files...")
     file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
     base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
@@ -153,31 +149,39 @@ def precache_assets() -> None:
         except Exception as e:
             print(f"Failed to download {name}: {e}")

+    # XTTS model files
     print("Pre-caching XTTS v2 model files...")
     ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")

+    # LLM GGUF
     print("Pre-caching Zephyr GGUF...")
     try:
         hf_hub_download(
             repo_id="TheBloke/zephyr-7B-beta-GGUF",
             filename="zephyr-7b-beta.Q5_K_M.gguf",
-            local_dir_use_symlinks=False,
+            force_download=False
         )
     except Exception as e:
         print(f"Warning: GGUF pre-cache error: {e}")

 def _load_xtts(device: str) -> Xtts:
-    """Load XTTS from the local cache."""
+    """Load XTTS from the local cache. Use checkpoint_dir to avoid None path bug."""
     print("Loading Coqui XTTS V2 model (CPU first)...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+    ModelManager().download_model(model_name)  # idempotent
     model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
-    if not os.path.exists(model_dir):
-        ModelManager().download_model(model_name)

     cfg = XttsConfig()
     cfg.load_json(os.path.join(model_dir, "config.json"))
     model = Xtts.init_from_config(cfg)
-    model.load_checkpoint(cfg, checkpoint_dir=model_dir, eval=True, use_deepspeed=False)
+
+    # IMPORTANT: use checkpoint_dir (fixes speakers file path resolution)
+    model.load_checkpoint(
+        cfg,
+        checkpoint_dir=model_dir,
+        eval=True,
+        use_deepspeed=False,  # deepspeed not installed in Spaces
+    )
     model.to(device)
     print("XTTS model loaded.")
     return model
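A note on the GGUF change above: `local_dir_use_symlinks` is deprecated in recent `huggingface_hub` releases, and `force_download=False` is already the default, so the call now simply warms the shared cache. Repeating the same call later is then the cheap way to resolve the cached file path, for example when `_load_llama` needs `zephyr_model_path` (a sketch; repo and filename taken from the hunk above):

```python
from huggingface_hub import hf_hub_download

# Second call with identical arguments is a cache hit: no network traffic,
# it just returns the local path inside the HF cache.
zephyr_model_path = hf_hub_download(
    repo_id="TheBloke/zephyr-7B-beta-GGUF",
    filename="zephyr-7b-beta.Q5_K_M.gguf",
)
print(zephyr_model_path)  # .../hub/models--TheBloke--zephyr-7B-beta-GGUF/snapshots/...
```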
@@ -191,108 +195,143 @@ def _load_llama() -> Llama:
     )
     llm = Llama(
         model_path=zephyr_model_path,
-        n_gpu_layers=0, n_ctx=4096, n_batch=512, verbose=False
+        n_gpu_layers=0,  # CPU by default to keep it ready without GPU
+        n_ctx=4096,
+        n_batch=512,
+        verbose=False
     )
     print("LLM loaded (CPU).")
     return llm

-def load_audio_for_tts(path: str, target_sr: int = 24000) -> torch.Tensor:
-    """Loads and resamples audio, returning a Torch tensor to avoid TTS internal loading."""
-    waveform, sr = torchaudio.load(path)
-    if sr != target_sr:
-        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
-        waveform = resampler(waveform)
-    return waveform.view(1, -1)  # Ensure shape is (1, T) for TTS model
-
 def init_models_and_latents() -> None:
     """Preload TTS and LLM on CPU and compute voice latents once."""
     global tts_model, llm_model, voice_latents

     if tts_model is None:
-        tts_model = _load_xtts(device="cpu")
+        tts_model = _load_xtts(device="cpu")  # keep on CPU at startup

     if llm_model is None:
         llm_model = _load_llama()

+    # Pre-compute latents once (CPU OK)
     if not voice_latents:
         print("Computing voice conditioning latents...")
-        voice_files = {
-            "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
-            "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
-        }
-        for role, filename in voice_files.items():
+        for role, filename in [
+            ("Cloée", "cloee-1.wav"),
+            ("Julian", "julian-bedtime-style-1.wav"),
+            ("Pirate", "pirate_by_coqui.wav"),
+            ("Thera", "thera-1.wav"),
+        ]:
             path = os.path.join("voices", filename)
-            # --- FIX: Load audio externally and pass the waveform tensor directly ---
-            waveform = load_audio_for_tts(path)
             voice_latents[role] = tts_model.get_conditioning_latents(
-                waveform=waveform, gpt_cond_len=30, max_ref_length=60
+                audio_path=path, gpt_cond_len=30, max_ref_length=60
             )
         print("Voice latents ready.")

+# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
 def _close_llm():
     global llm_model
-    if llm_model is not None:
-        del llm_model
+    try:
+        if llm_model is not None:
+            llm_model.close()
+    except Exception:
+        pass
 atexit.register(_close_llm)

 # ===================================================================================
 # 4) INFERENCE HELPERS
 # ===================================================================================

-def generate_text_stream(llm_instance: Llama, prompt: str, history: List, sys_prompt: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, sys_prompt)
+def generate_text_stream(llm_instance: Llama, prompt: str,
+                         history: List[Tuple[str, str | None]],
+                         system_message_text: str) -> Generator[str, None, None]:
+    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
     stream = llm_instance(
-        formatted_prompt, temperature=0.7, max_tokens=512, top_p=0.95, stop=LLM_STOP_WORDS, stream=True
+        formatted_prompt,
+        temperature=0.7,
+        max_tokens=512,
+        top_p=0.95,
+        stop=LLM_STOP_WORDS,
+        stream=True
     )
     for response in stream:
-        yield response["choices"][0]["text"]
+        ch = response["choices"][0]["text"]
+        try:
+            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
+        except Exception:
+            is_single_emoji = False
+        if "<|user|>" in ch or is_single_emoji:
+            continue
+        yield ch

-def generate_audio_stream(tts_instance: Xtts, text: str, lang: str, latents: Tuple) -> Generator[bytes, None, None]:
+def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
+                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
     gpt_cond_latent, speaker_embedding = latents
-    for chunk in tts_instance.inference_stream(
-        text, lang, gpt_cond_latent, speaker_embedding, temperature=0.85,
-    ):
-        if chunk is not None:
-            yield chunk.detach().cpu().numpy().squeeze().tobytes()
+    try:
+        for chunk in tts_instance.inference_stream(
+            text=text,
+            language=language,
+            gpt_cond_latent=gpt_cond_latent,
+            speaker_embedding=speaker_embedding,
+            temperature=0.85,
+        ):
+            if chunk is not None:
+                yield chunk.detach().cpu().numpy().squeeze().tobytes()
+    except RuntimeError as e:
+        print(f"Error during TTS inference: {e}")
+        if "device-side assert" in str(e) and api:
+            gr.Warning("Critical GPU error. Attempting to restart the Space...")
+            try:
+                api.restart_space(repo_id=repo_id)
+            except Exception:
+                pass

 # ===================================================================================
 # 5) ZERO-GPU ENTRYPOINT
 # ===================================================================================

-@spaces.GPU(duration=120)
+@spaces.GPU(duration=120)  # Request GPU for 120s (tune as needed)
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []

-    # Models must be preloaded, this is a fallback.
-    if tts_model is None or llm_model is None:
-        raise gr.Error("Models not initialized. Please restart the Space.")
+    # Models & latents are preloaded at startup; ensure available
+    if tts_model is None or llm_model is None or not voice_latents:
+        init_models_and_latents()

+    # If ZeroGPU provided a GPU for this call, move XTTS to CUDA for faster audio
     try:
         if torch.cuda.is_available():
             tts_model.to("cuda")
+        else:
+            tts_model.to("cpu")
+    except Exception:
+        tts_model.to("cpu")
+
+    # Generate story text (LLM runs on CPU, doesn't need ZeroGPU)
+    history: List[Tuple[str, str | None]] = [(input_text, None)]
+    full_story_text = "".join(
+        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
+    ).strip()
+    if not full_story_text:
+        return []

-        history: List[Tuple[str, str | None]] = [(input_text, None)]
-        full_story_text = "".join(
-            generate_text_stream(llm_model, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role])
-        ).strip()
-
-        if not full_story_text:
-            return []
+    # Tokenize into shorter sentences for TTS
+    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
+    lang = langid.classify(sentences[0])[0] if sentences else "en"

-        sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
-        lang = langid.classify(sentences[0])[0] if sentences else "en"
-        results: List[Dict[str, str]] = []
+    results: List[Dict[str, str]] = []
+    for sentence in sentences:
+        if not any(c.isalnum() for c in sentence):
+            continue

-        for sentence in sentences:
-            if not any(c.isalnum() for c in sentence):
-                continue
-
-            audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
-            pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+        audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

+        # Optional noise reduction (best-effort)
+        try:
             data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
             if data_s16.size > 0:
                 float_data = data_s16.astype(np.float32) / 32767.0
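One thing to watch in `generate_audio_stream`: XTTS streaming chunks are float32 tensors in [-1, 1], while the consumer reinterprets the joined bytes as int16 (`np.frombuffer(pcm_data, dtype=np.int16)`). If that mismatch surfaces as noise, an explicit conversion at the yield site would look like this (a sketch, not part of this commit; the helper name is hypothetical):

```python
import numpy as np
import torch

def float_chunk_to_int16_bytes(chunk: torch.Tensor) -> bytes:
    """Convert a float32 audio chunk in [-1, 1] to int16 PCM bytes."""
    samples = chunk.detach().cpu().numpy().squeeze()
    samples = np.clip(samples, -1.0, 1.0)  # guard against clipping overflow
    return (samples * 32767).astype(np.int16).tobytes()
```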
@@ -300,46 +339,44 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
                 final_pcm = (reduced * 32767).astype(np.int16).tobytes()
             else:
                 final_pcm = pcm_data
-
-            b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
-            results.append({"text": sentence, "audio": b64_wav})
-
-        return results
-
-    finally:
-        # Crucial for ZeroGPU: ensure model returns to CPU to free the GPU
-        if tts_model is not None:
-            tts_model.to("cpu")
+        except Exception:
+            final_pcm = pcm_data
+
+        b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
+        results.append({"text": sentence, "audio": b64_wav})
+
+    # Return XTTS to CPU to free GPU instantly after the call
+    try:
+        tts_model.to("cpu")
+    except Exception:
+        pass
+
+    return results

 # ===================================================================================
 # 6) STARTUP: PRECACHE & UI
 # ===================================================================================

-def build_ui() -> gr.Blocks:
-    with gr.Blocks() as demo:
-        gr.Markdown("# AI Storyteller with ZeroGPU")
-        gr.Markdown("Enter a prompt to generate a short story with voice narration using on-demand GPU.")
-
-        with gr.Row():
-            secret_token = gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN)
-            storyteller = gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée")
-
-        prompt = gr.Textbox(placeholder="What should the story be about?", label="Story Prompt")
-        output = gr.JSON(label="Story and Audio Output")
-
-        prompt.submit(
-            fn=generate_story_and_speech,
-            inputs=[secret_token, prompt, storyteller],
-            outputs=output,
-        )
-
-    return demo
+def build_ui() -> gr.Interface:
+    return gr.Interface(
+        fn=generate_story_and_speech,
+        inputs=[
+            gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
+            gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
+            gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
+        ],
+        outputs=gr.JSON(label="Story and Audio Output"),
+        title="AI Storyteller with ZeroGPU",
+        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+        allow_flagging="never",  # Gradio 3.50.2
+        analytics_enabled=False,  # keep analytics fully disabled (pairs with env var)
+    )

 if __name__ == "__main__":
     print("===== Startup: pre-cache assets and preload models =====")
-    precache_assets()
-    init_models_and_latents()
+    precache_assets()            # 1) download everything to disk
+    init_models_and_latents()    # 2) load models on CPU + compute voice latents
     print("Models and assets ready. Launching UI...")

     demo = build_ui()
-    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
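For completeness, a client-side sketch of consuming the endpoint's JSON output (field names taken from `results.append(...)` above; the filename prefix is arbitrary):

```python
import base64

# `results` as returned by generate_story_and_speech:
# [{"text": "<sentence>", "audio": "<base64-encoded WAV>"}, ...]
def save_clips(results, prefix="clip"):
    for i, item in enumerate(results):
        with open(f"{prefix}_{i:03d}.wav", "wb") as f:
            f.write(base64.b64decode(item["audio"]))
        print(item["text"])  # sentence paired with the saved clip
```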