ruslanmv committed
Commit d662d9a · 1 Parent(s): f3fa464

Update app.py

Files changed (1)
  1. app.py +145 -107
app.py CHANGED
@@ -1,20 +1,22 @@
 # ===================================================================================
-# 1. SETUP AND IMPORTS
+# 1) SETUP & IMPORTS
 # ===================================================================================
 from __future__ import annotations
 import os
-import requests
 import base64
 import struct
 import re
 import textwrap
-import uuid
+import requests
 from typing import List, Dict, Tuple, Generator
 
-# Make sure Gradio analytics is off (so we don't need pandas 2.x)
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
+# --- Fast, safe defaults ---
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("COQUI_TOS_AGREED", "1")
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
 
-# --- Load .env early (for HF_TOKEN / SECRET_TOKEN) ---
+# --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
 from dotenv import load_dotenv
 load_dotenv()
 
@@ -22,7 +24,6 @@ load_dotenv()
 try:
     import spaces  # Required for ZeroGPU on HF
 except Exception:
-    # Allow local runs without the spaces package
     class _SpacesShim:
         def GPU(self, *args, **kwargs):
             def _wrap(fn):
@@ -51,19 +52,18 @@ import emoji
 import noisereduce as nr
 
 # ===================================================================================
-# 2. GLOBAL CONFIGURATION & HELPER FUNCTIONS
+# 2) GLOBALS & HELPERS
 # ===================================================================================
 
-# Download NLTK data (punkt)
+# Download NLTK data (punkt) once
 nltk.download("punkt", quiet=True)
 
-os.environ["COQUI_TOS_AGREED"] = "1"
-
-# Cached models
+# Cached models & latents
 tts_model: Xtts | None = None
 llm_model: Llama | None = None
+voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
 
-# Configuration
+# Config
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
 repo_id = "ruslanmv/ai-story-server"
@@ -84,7 +84,7 @@ ROLE_PROMPTS["Pirate"] = (
84
  "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
85
  )
86
 
87
- # --- Audio helpers ---
88
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
89
  if pcm_data.startswith(b"RIFF"):
90
  return pcm_data
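
Review note: the body of pcm_to_wav lies outside this hunk, so only its RIFF guard is visible. For context, a minimal sketch of the standard 44-byte RIFF/WAVE header such a helper prepends to raw 16-bit PCM (field layout per the WAV format; the committed helper may differ in details):

    import struct

    def wav_header(n_bytes: int, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
        # Standard 44-byte RIFF/WAVE header for uncompressed PCM (format tag 1).
        byte_rate = sample_rate * channels * bit_depth // 8
        block_align = channels * bit_depth // 8
        return struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF", 36 + n_bytes, b"WAVE",
            b"fmt ", 16, 1, channels, sample_rate, byte_rate, block_align, bit_depth,
            b"data", n_bytes,
        )
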
@@ -118,68 +118,117 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
     return prompt
 
 # ===================================================================================
-# 3. CORE AI FUNCTIONS (Model Loading & Inference)
+# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
 # ===================================================================================
 
+def precache_assets() -> None:
+    """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
+    # Voices
+    print("Pre-caching voice files...")
+    file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
+    base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
+    os.makedirs("voices", exist_ok=True)
+    for name in file_names:
+        dst = os.path.join("voices", name)
+        if not os.path.exists(dst):
+            try:
+                resp = requests.get(base_url + name, timeout=30)
+                resp.raise_for_status()
+                with open(dst, "wb") as f:
+                    f.write(resp.content)
+            except Exception as e:
+                print(f"Failed to download {name}: {e}")
+
+    # XTTS model files
+    print("Pre-caching XTTS v2 model files...")
+    ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")
+
+    # LLM GGUF
+    print("Pre-caching Zephyr GGUF...")
+    try:
+        hf_hub_download(
+            repo_id="TheBloke/zephyr-7B-beta-GGUF",
+            filename="zephyr-7b-beta.Q5_K_M.gguf",
+            force_download=False
+        )
+    except Exception as e:
+        print(f"Warning: GGUF pre-cache error: {e}")
+
 def _load_xtts(device: str) -> Xtts:
-    print("Loading Coqui XTTS V2 model (first run)...")
+    """Load XTTS from the local cache. Use checkpoint_dir to avoid None path bug."""
+    print("Loading Coqui XTTS V2 model (CPU first)...")
     model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-    ModelManager().download_model(model_name)
-    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    ModelManager().download_model(model_name)  # idempotent
+    model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+
+    cfg = XttsConfig()
+    cfg.load_json(os.path.join(model_dir, "config.json"))
+    model = Xtts.init_from_config(cfg)
 
-    config = XttsConfig()
-    config.load_json(os.path.join(model_path, "config.json"))
-    model = Xtts.init_from_config(config)
-    # NOTE: deepspeed not installed; keep False for Spaces
+    # IMPORTANT: use checkpoint_dir (fixes speakers file path resolution)
     model.load_checkpoint(
-        config,
-        checkpoint_path=os.path.join(model_path, "model.pth"),
-        vocab_path=os.path.join(model_path, "vocab.json"),
+        cfg,
+        checkpoint_dir=model_dir,
         eval=True,
-        use_deepspeed=False,
+        use_deepspeed=False,  # deepspeed not installed in Spaces
     )
     model.to(device)
     print("XTTS model loaded.")
     return model
 
 def _load_llama() -> Llama:
-    print("Loading LLM (Zephyr) (first run)...")
+    """Load Llama (Zephyr GGUF) on CPU so it's ready immediately."""
+    print("Loading LLM (Zephyr GGUF) on CPU...")
     zephyr_model_path = hf_hub_download(
         repo_id="TheBloke/zephyr-7B-beta-GGUF",
         filename="zephyr-7b-beta.Q5_K_M.gguf"
     )
-    # Try GPU offload if available, else CPU
-    for n_gpu_layers in (-1, 0):
-        try:
-            llm = Llama(
-                model_path=zephyr_model_path,
-                n_gpu_layers=n_gpu_layers,
-                n_ctx=4096,
-                n_batch=512,
-                verbose=False
-            )
-            if n_gpu_layers == -1:
-                print("LLM loaded with GPU offload.")
-            else:
-                print("LLM loaded (CPU).")
-            return llm
-        except Exception as e:
-            print(f"LLM init with n_gpu_layers={n_gpu_layers} failed: {e}")
-    raise RuntimeError("Failed to initialize Llama model.")
-
-def load_models() -> Tuple[Xtts, Llama]:
-    global tts_model, llm_model
+    # Initialize CPU instance (n_gpu_layers=0). If you want GPU offload, you can
+    # create a second instance inside the GPU window, but CPU is simpler & ready now.
+    llm = Llama(
+        model_path=zephyr_model_path,
+        n_gpu_layers=0,  # CPU by default to keep it ready without GPU
+        n_ctx=4096,
+        n_batch=512,
+        verbose=False
+    )
+    print("LLM loaded (CPU).")
+    return llm
+
+def init_models_and_latents() -> None:
+    """Preload TTS and LLM on CPU and compute voice latents once."""
+    global tts_model, llm_model, voice_latents
     device = "cuda" if torch.cuda.is_available() else "cpu"
+
     if tts_model is None:
-        tts_model = _load_xtts(device)
+        tts_model = _load_xtts(device="cpu")  # keep on CPU at startup
+
     if llm_model is None:
         llm_model = _load_llama()
-    return tts_model, llm_model
+
+    # Pre-compute latents once (CPU OK)
+    if not voice_latents:
+        print("Computing voice conditioning latents...")
+        for role, filename in [
+            ("Cloée", "cloee-1.wav"),
+            ("Julian", "julian-bedtime-style-1.wav"),
+            ("Pirate", "pirate_by_coqui.wav"),
+            ("Thera", "thera-1.wav"),
+        ]:
+            path = os.path.join("voices", filename)
+            voice_latents[role] = tts_model.get_conditioning_latents(
+                audio_path=path, gpt_cond_len=30, max_ref_length=60
+            )
+        print("Voice latents ready.")
+
+# ===================================================================================
+# 4) INFERENCE HELPERS
+# ===================================================================================
 
 def generate_text_stream(llm_instance: Llama, prompt: str,
                          history: List[Tuple[str, str | None]],
-                         system_message: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, system_message)
+                         system_message_text: str) -> Generator[str, None, None]:
+    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
     stream = llm_instance(
         formatted_prompt,
         temperature=0.7,
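
Review note: this hunk pins llama-cpp to CPU (n_gpu_layers=0) so the LLM is usable before any ZeroGPU window opens. If GPU offload is wanted again, an env-var knob is a smaller change than the old retry loop; a sketch (the LLAMA_N_GPU_LAYERS variable is hypothetical, not part of this commit):

    import os
    from huggingface_hub import hf_hub_download
    from llama_cpp import Llama

    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf",
    )
    # 0 keeps the CPU default above; -1 offloads all layers when
    # llama-cpp-python was built with CUDA support.
    n_gpu_layers = int(os.getenv("LLAMA_N_GPU_LAYERS", "0"))
    llm = Llama(model_path=zephyr_model_path, n_gpu_layers=n_gpu_layers,
                n_ctx=4096, n_batch=512, verbose=False)
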
@@ -190,9 +239,8 @@ def generate_text_stream(llm_instance: Llama, prompt: str,
     )
     for response in stream:
         ch = response["choices"][0]["text"]
-        # Guard against control tokens & isolated emoji artefacts
        try:
-            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))  # emoji>=2.x
+            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
        except Exception:
            is_single_emoji = False
        if "<|user|>" in ch or is_single_emoji:
@@ -214,7 +262,6 @@ def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
             yield chunk.detach().cpu().numpy().squeeze().tobytes()
     except RuntimeError as e:
         print(f"Error during TTS inference: {e}")
-        # Soft-restart if GPU went bad and we can talk to the HF API
         if "device-side assert" in str(e) and api:
             gr.Warning("Critical GPU error. Attempting to restart the Space...")
             try:
@@ -223,38 +270,34 @@
             pass
 
 # ===================================================================================
-# 4. MAIN GRADIO FUNCTION (Decorated for ZeroGPU)
+# 5) ZERO-GPU ENTRYPOINT
 # ===================================================================================
 
-@spaces.GPU(duration=120)  # Request GPU for 120 seconds
+@spaces.GPU(duration=120)  # Request GPU for 120s (can tune later)
 def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []
 
-    # Load models
-    tts, llm = load_models()
-
-    # Pre-compute voice latents
-    latent_map: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
-    for role, filename in [
-        ("Cloée", "cloee-1.wav"),
-        ("Julian", "julian-bedtime-style-1.wav"),
-        ("Pirate", "pirate_by_coqui.wav"),
-        ("Thera", "thera-1.wav"),
-    ]:
-        path = os.path.join("voices", filename)
-        latent_map[role] = tts.get_conditioning_latents(
-            audio_path=path, gpt_cond_len=30, max_ref_length=60
-        )
+    # Models & latents are preloaded at startup; ensure available
+    if tts_model is None or llm_model is None or not voice_latents:
+        init_models_and_latents()
+
+    # If ZeroGPU provided a GPU for this call, move XTTS to CUDA for faster audio
+    try:
+        if torch.cuda.is_available():
+            tts_model.to("cuda")
+        else:
+            tts_model.to("cpu")
+    except Exception:
+        tts_model.to("cpu")
 
     # Generate story text
     history: List[Tuple[str, str | None]] = [(input_text, None)]
     full_story_text = "".join(
-        generate_text_stream(llm, history[-1][0], history[:-1], system_message=ROLE_PROMPTS[chatbot_role])
+        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
     ).strip()
-
     if not full_story_text:
         return []
 
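
Review note: the CUDA move above is undone at the end of the function, but an exception mid-story would leave XTTS on the GPU. A try/finally wrapper makes the hand-back unconditional; a sketch against the tts_model global this commit introduces (run_on_gpu_window is a hypothetical helper, not part of the diff):

    import torch

    def run_on_gpu_window(fn, *args, **kwargs):
        # Move XTTS to CUDA for one ZeroGPU call; hand the GPU back even if fn raises.
        tts_model.to("cuda" if torch.cuda.is_available() else "cpu")
        try:
            return fn(*args, **kwargs)
        finally:
            tts_model.to("cpu")
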
@@ -267,7 +310,7 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         if not any(c.isalnum() for c in sentence):
             continue
 
-        audio_chunks = generate_audio_stream(tts, sentence, lang, latent_map[chatbot_role])
+        audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role])
         pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
 
         # Optional noise reduction (best-effort)
@@ -285,43 +328,38 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
         b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
         results.append({"text": sentence, "audio": b64_wav})
 
+    # Return XTTS to CPU to free GPU instantly after the call
+    try:
+        tts_model.to("cpu")
+    except Exception:
+        pass
+
     return results
 
 # ===================================================================================
-# 5. GRADIO INTERFACE LAUNCH
+# 6) STARTUP: PRECACHE & UI
 # ===================================================================================
 
-# Download voice files on startup
-print("Downloading voice files...")
-file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
-base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
-os.makedirs("voices", exist_ok=True)
-for name in file_names:
-    dst = os.path.join("voices", name)
-    if not os.path.exists(dst):
-        try:
-            resp = requests.get(base_url + name, timeout=30)
-            resp.raise_for_status()
-            with open(dst, "wb") as f:
-                f.write(resp.content)
-        except Exception as e:
-            print(f"Failed to download {name}: {e}")
-
-# Define the Gradio Interface
-demo = gr.Interface(
-    fn=generate_story_and_speech,
-    inputs=[
-        gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
-        gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
-        gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
-    ],
-    outputs=gr.JSON(label="Story and Audio Output"),
-    title="AI Storyteller with ZeroGPU",
-    description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
-    allow_flagging="never",
-    analytics_enabled=False,  # <- keep analytics off to avoid pandas 2.x requirement
-)
+def build_ui() -> gr.Interface:
+    return gr.Interface(
+        fn=generate_story_and_speech,
+        inputs=[
+            gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
+            gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
+            gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
+        ],
+        outputs=gr.JSON(label="Story and Audio Output"),
+        title="AI Storyteller with ZeroGPU",
+        description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
+        flagging_mode="never",  # replaces deprecated allow_flagging
+    )
 
 if __name__ == "__main__":
-    # For Spaces or Docker, these defaults are handy; adjust as needed.
+    print("===== Startup: pre-cache assets and preload models =====")
+    precache_assets()            # 1) download everything to disk
+    init_models_and_latents()    # 2) load models on CPU + compute voice latents
+    print("Models and assets ready. Launching UI...")
+
+    demo = build_ui()
+    # queue + analytics disabled (env) keeps pandas out of the path
     demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
 
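
Review note: the endpoint returns a JSON list of {"text", "audio"} pairs with base64-encoded WAV per sentence. A minimal client-side smoke test, assuming the public Space name below and Gradio's default /predict route (both are assumptions, not part of the diff):

    import base64
    from gradio_client import Client  # pip install gradio_client

    client = Client("ruslanmv/ai-story-server")  # assumed Space name
    results = client.predict(
        "<secret token>",               # Secret Token
        "A dragon who learns to sing",  # Story Prompt
        "Cloée",                        # Select a Storyteller
        api_name="/predict",            # assumed default gr.Interface route
    )
    # results is the parsed JSON output: one dict per narrated sentence
    for i, item in enumerate(results):
        with open(f"sentence_{i:02d}.wav", "wb") as f:
            f.write(base64.b64decode(item["audio"]))
        print(item["text"])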