ruslanmv committed on
Commit 2b38b16 · 1 Parent(s): 427d75a

Update app.py

Files changed (1):
  1. app.py +111 -144
app.py CHANGED
@@ -14,11 +14,7 @@ from typing import List, Dict, Tuple, Generator
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("COQUI_TOS_AGREED", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
-
-# >>> CRITICAL: force torchaudio to avoid FFmpeg/torio path <<<
-# Must be set BEFORE importing torchaudio
-os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
 
 # --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
@@ -26,8 +22,8 @@ load_dotenv()
 
 # --- Hugging Face Spaces & ZeroGPU ---
 try:
-    import spaces  # Required for ZeroGPU on HF
-except Exception:
+    import spaces
+except ImportError:
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
            def _wrap(fn):
@@ -42,12 +38,8 @@ import torch
 import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama
-
-# --- torchaudio (dispatcher: FFmpeg disabled via env above) ---
-try:
-    import torchaudio  # noqa: F401
-except Exception:
-    torchaudio = None  # XTTS will still call torchaudio.load internally; env disables ffmpeg path
+import torchaudio  # still needed for transforms, just not for loading audio
+import soundfile as sf  # <-- FIX: import soundfile for robust audio loading
 
 # --- TTS Libraries ---
 from TTS.tts.configs.xtts_config import XttsConfig
@@ -71,7 +63,7 @@ nltk.download("punkt", quiet=True)
 # Cached models & latents
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
-voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
+voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
 
 # Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -133,7 +125,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
 
 def precache_assets() -> None:
     """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
-    # Voices
     print("Pre-caching voice files...")
     file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
     base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
@@ -149,39 +140,31 @@ def precache_assets() -> None:
     except Exception as e:
         print(f"Failed to download {name}: {e}")
 
-    # XTTS model files
     print("Pre-caching XTTS v2 model files...")
     ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")
 
-    # LLM GGUF
     print("Pre-caching Zephyr GGUF...")
     try:
         hf_hub_download(
             repo_id="TheBloke/zephyr-7B-beta-GGUF",
             filename="zephyr-7b-beta.Q5_K_M.gguf",
-            force_download=False
+            local_dir_use_symlinks=False,
         )
     except Exception as e:
         print(f"Warning: GGUF pre-cache error: {e}")
 
 def _load_xtts(device: str) -> Xtts:
-    """Load XTTS from the local cache. Use checkpoint_dir to avoid None path bug."""
+    """Load XTTS from the local cache."""
     print("Loading Coqui XTTS V2 model (CPU first)...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-    ModelManager().download_model(model_name)  # idempotent
     model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    if not os.path.exists(model_dir):
+        ModelManager().download_model(model_name)
 
     cfg = XttsConfig()
     cfg.load_json(os.path.join(model_dir, "config.json"))
     model = Xtts.init_from_config(cfg)
-
-    # IMPORTANT: use checkpoint_dir (fixes speakers file path resolution)
-    model.load_checkpoint(
-        cfg,
-        checkpoint_dir=model_dir,
-        eval=True,
-        use_deepspeed=False,  # deepspeed not installed in Spaces
-    )
+    model.load_checkpoint(cfg, checkpoint_dir=model_dir, eval=True, use_deepspeed=False)
     model.to(device)
     print("XTTS model loaded.")
     return model
@@ -195,143 +178,125 @@ def _load_llama() -> Llama:
     )
     llm = Llama(
         model_path=zephyr_model_path,
-        n_gpu_layers=0,  # CPU by default to keep it ready without GPU
-        n_ctx=4096,
-        n_batch=512,
-        verbose=False
+        n_gpu_layers=0, n_ctx=4096, n_batch=512, verbose=False
     )
     print("LLM loaded (CPU).")
     return llm
 
+# --- FIX: replaced torchaudio.load with soundfile.read to fix the RuntimeError ---
+def load_audio_for_tts(path: str, target_sr: int = 24000) -> torch.Tensor:
+    """Load audio with soundfile, convert to a Torch tensor, and resample if needed."""
+    try:
+        # Read the audio file into a NumPy array
+        audio_np, original_sr = sf.read(path, dtype="float32")
+
+        # Ensure it's mono
+        if audio_np.ndim > 1:
+            audio_np = np.mean(audio_np, axis=1)
+
+        # Convert to a PyTorch tensor
+        waveform = torch.from_numpy(audio_np).float()
+
+        # Resample if the sample rate is not the target rate
+        if original_sr != target_sr:
+            print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
+            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
+            waveform = resampler(waveform)
+
+        return waveform.unsqueeze(0)  # add batch dimension: shape (1, T)
+    except Exception as e:
+        print(f"Error loading audio file {path}: {e}")
+        raise
+
 def init_models_and_latents() -> None:
     """Preload TTS and LLM on CPU and compute voice latents once."""
     global tts_model, llm_model, voice_latents
 
     if tts_model is None:
-        tts_model = _load_xtts(device="cpu")  # keep on CPU at startup
+        tts_model = _load_xtts(device="cpu")
 
     if llm_model is None:
         llm_model = _load_llama()
 
-    # Pre-compute latents once (CPU OK)
     if not voice_latents:
         print("Computing voice conditioning latents...")
-        for role, filename in [
-            ("Cloée", "cloee-1.wav"),
-            ("Julian", "julian-bedtime-style-1.wav"),
-            ("Pirate", "pirate_by_coqui.wav"),
-            ("Thera", "thera-1.wav"),
-        ]:
+        voice_files = {
+            "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
+            "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
+        }
+        for role, filename in voice_files.items():
             path = os.path.join("voices", filename)
+            # Load audio externally and pass the waveform tensor directly
+            waveform = load_audio_for_tts(path)
             voice_latents[role] = tts_model.get_conditioning_latents(
-                audio_path=path, gpt_cond_len=30, max_ref_length=60
+                waveform=waveform, gpt_cond_len=30, max_ref_length=60
             )
         print("Voice latents ready.")
 
-# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
 def _close_llm():
     global llm_model
-    try:
-        if llm_model is not None:
-            llm_model.close()
-    except Exception:
-        pass
+    if llm_model is not None:
+        del llm_model
 atexit.register(_close_llm)
 
 # ===================================================================================
 # 4) INFERENCE HELPERS
 # ===================================================================================
 
-def generate_text_stream(llm_instance: Llama, prompt: str,
-                         history: List[Tuple[str, str | None]],
-                         system_message_text: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
+def generate_text_stream(llm_instance: Llama, prompt: str, history: List, sys_prompt: str) -> Generator[str, None, None]:
+    formatted_prompt = format_prompt_zephyr(prompt, history, sys_prompt)
     stream = llm_instance(
-        formatted_prompt,
-        temperature=0.7,
-        max_tokens=512,
-        top_p=0.95,
-        stop=LLM_STOP_WORDS,
-        stream=True
+        formatted_prompt, temperature=0.7, max_tokens=512, top_p=0.95, stop=LLM_STOP_WORDS, stream=True
     )
     for response in stream:
-        ch = response["choices"][0]["text"]
-        try:
-            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
-        except Exception:
-            is_single_emoji = False
-        if "<|user|>" in ch or is_single_emoji:
-            continue
-        yield ch
-
-def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
-                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
+        yield response["choices"][0]["text"]
+
+def generate_audio_stream(tts_instance: Xtts, text: str, lang: str, latents: Tuple) -> Generator[bytes, None, None]:
     gpt_cond_latent, speaker_embedding = latents
-    try:
-        for chunk in tts_instance.inference_stream(
-            text=text,
-            language=language,
-            gpt_cond_latent=gpt_cond_latent,
-            speaker_embedding=speaker_embedding,
-            temperature=0.85,
-        ):
-            if chunk is not None:
-                yield chunk.detach().cpu().numpy().squeeze().tobytes()
-    except RuntimeError as e:
-        print(f"Error during TTS inference: {e}")
-        if "device-side assert" in str(e) and api:
-            gr.Warning("Critical GPU error. Attempting to restart the Space...")
-            try:
-                api.restart_space(repo_id=repo_id)
-            except Exception:
-                pass
+    for chunk in tts_instance.inference_stream(
+        text, lang, gpt_cond_latent, speaker_embedding, temperature=0.85,
+    ):
+        if chunk is not None:
+            yield chunk.detach().cpu().numpy().squeeze().tobytes()
 
 # ===================================================================================
 # 5) ZERO-GPU ENTRYPOINT
 # ===================================================================================
 
-@spaces.GPU(duration=120)  # Request GPU for 120s (tune as needed)
+@spaces.GPU(duration=120)
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []
 
-    # Models & latents are preloaded at startup; ensure available
-    if tts_model is None or llm_model is None or not voice_latents:
-        init_models_and_latents()
+    # Models must be preloaded; this is a fallback.
+    if tts_model is None or llm_model is None:
+        raise gr.Error("Models not initialized. Please restart the Space.")
 
-    # If ZeroGPU provided a GPU for this call, move XTTS to CUDA for faster audio
     try:
         if torch.cuda.is_available():
             tts_model.to("cuda")
-        else:
-            tts_model.to("cpu")
-    except Exception:
-        tts_model.to("cpu")
-
-    # Generate story text (LLM runs on CPU, doesn't need ZeroGPU)
-    history: List[Tuple[str, str | None]] = [(input_text, None)]
-    full_story_text = "".join(
-        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
-    ).strip()
-    if not full_story_text:
-        return []
 
-    # Tokenize into shorter sentences for TTS
-    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
-    lang = langid.classify(sentences[0])[0] if sentences else "en"
-
-    results: List[Dict[str, str]] = []
-    for sentence in sentences:
-        if not any(c.isalnum() for c in sentence):
-            continue
-
-        audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
-        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+        history: List[Tuple[str, str | None]] = [(input_text, None)]
+        full_story_text = "".join(
+            generate_text_stream(llm_model, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role])
+        ).strip()
+
+        if not full_story_text:
+            return []
+
+        sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
+        lang = langid.classify(sentences[0])[0] if sentences else "en"
+        results: List[Dict[str, str]] = []
+
+        for sentence in sentences:
+            if not any(c.isalnum() for c in sentence):
+                continue
+
+            audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
+            pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
 
-        # Optional noise reduction (best-effort)
-        try:
         data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
         if data_s16.size > 0:
             float_data = data_s16.astype(np.float32) / 32767.0
@@ -339,44 +304,46 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
             final_pcm = (reduced * 32767).astype(np.int16).tobytes()
         else:
             final_pcm = pcm_data
-        except Exception:
-            final_pcm = pcm_data
-
-        b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
-        results.append({"text": sentence, "audio": b64_wav})
-
-    # Return XTTS to CPU to free GPU instantly after the call
-    try:
-        tts_model.to("cpu")
-    except Exception:
-        pass
-
-    return results
+
+            b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
+            results.append({"text": sentence, "audio": b64_wav})
+
+        return results
+
+    finally:
+        # Crucial for ZeroGPU: ensure the model returns to CPU to free the GPU
+        if tts_model is not None:
+            tts_model.to("cpu")
 
 # ===================================================================================
 # 6) STARTUP: PRECACHE & UI
 # ===================================================================================
 
-def build_ui() -> gr.Interface:
-    return gr.Interface(
-        fn=generate_story_and_speech,
-        inputs=[
-            gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
-            gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
-            gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
-        ],
-        outputs=gr.JSON(label="Story and Audio Output"),
-        title="AI Storyteller with ZeroGPU",
-        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
-        allow_flagging="never",  # Gradio 3.50.2
-        analytics_enabled=False,  # keep analytics fully disabled (pairs with env var)
-    )
+def build_ui() -> gr.Blocks:
+    with gr.Blocks() as demo:
+        gr.Markdown("# AI Storyteller with ZeroGPU")
+        gr.Markdown("Enter a prompt to generate a short story with voice narration using on-demand GPU.")
+
+        with gr.Row():
+            secret_token = gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN)
+            storyteller = gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée")
+
+        prompt = gr.Textbox(placeholder="What should the story be about?", label="Story Prompt")
+        output = gr.JSON(label="Story and Audio Output")
+
+        prompt.submit(
+            fn=generate_story_and_speech,
+            inputs=[secret_token, prompt, storyteller],
+            outputs=output,
+        )
+
+    return demo
 
 if __name__ == "__main__":
     print("===== Startup: pre-cache assets and preload models =====")
-    precache_assets()  # 1) download everything to disk
-    init_models_and_latents()  # 2) load models on CPU + compute voice latents
+    precache_assets()
+    init_models_and_latents()
     print("Models and assets ready. Launching UI...")
 
     demo = build_ui()
-    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
 
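Note: the hunk above truncates the body of the _SpacesShim fallback at `def _wrap(fn):`. A complete no-op shim of this shape would look roughly like the sketch below; the `_wrap` body and the `spaces = _SpacesShim()` binding are assumptions, since the diff context cuts off before them.

# Hedged sketch of the truncated ZeroGPU fallback shim (assumed completion).
class _SpacesShim:
    def GPU(self, *args, **kwargs):
        def _wrap(fn):
            return fn  # no GPU to request outside Spaces; return fn unchanged
        return _wrap

spaces = _SpacesShim()  # hypothetical binding so @spaces.GPU(...) keeps working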
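Note: pcm_to_wav is defined elsewhere in app.py and not shown in this diff. A minimal helper of the shape the call sites assume (raw 16-bit mono PCM wrapped in a WAV header at the 24 kHz rate matching target_sr above) might look like:

# Hypothetical pcm_to_wav sketch; the real helper is in app.py but outside this diff.
import io
import wave

def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
    """Wrap raw 16-bit mono PCM bytes in a WAV container."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)           # mono
        wav.setsampwidth(2)           # 16-bit samples
        wav.setframerate(sample_rate) # 24 kHz, matching XTTS output
        wav.writeframes(pcm_data)
    return buf.getvalue()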
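Note: since generate_story_and_speech returns a JSON list of {"text", "audio"} pairs with the audio base64-encoded as WAV, a caller can decode the output as in the sketch below. This uses gradio_client; the Space id and fn_index are assumptions, and the exact return shape of a gr.JSON output can vary by Gradio version.

# Minimal client sketch (assumed Space id and fn_index).
import base64
from gradio_client import Client

client = Client("user/ai-story-server")  # hypothetical Space id
result = client.predict(
    "my-secret-token",               # must match SECRET_TOKEN on the server
    "A dragon who learns to bake",   # story prompt
    "Cloée",                         # storyteller role
    fn_index=0,                      # assumed index of the prompt.submit event
)

# Each item pairs one sentence with its narration as base64-encoded WAV bytes.
for i, item in enumerate(result):
    with open(f"sentence_{i}.wav", "wb") as f:
        f.write(base64.b64decode(item["audio"]))
    print(item["text"])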