ruslanmv commited on
Commit
207e4c3
·
1 Parent(s): 7741539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -189
app.py CHANGED
@@ -1,211 +1,168 @@
1
  # ===================================================================================
2
- # 1) SETUP & IMPORTS
3
  # ===================================================================================
4
  from __future__ import annotations
5
- import os, base64, struct, textwrap, re
6
  import requests
7
- from typing import List, Tuple, Dict, Generator
8
-
9
- # Load .env early (HF_TOKEN / SECRET_TOKEN)
 
 
 
 
 
10
  from dotenv import load_dotenv
11
  load_dotenv()
12
 
13
- # Fast downloads & stable behavior
14
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") # faster HF downloads
15
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
16
- os.environ.setdefault("COQUI_TOS_AGREED", "1")
17
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false") # avoid pandas analytics path
18
- os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
19
-
20
- # HF Spaces / Gradio
21
  try:
22
- import spaces # ZeroGPU decorator
23
  except Exception:
 
24
  class _SpacesShim:
25
- def GPU(self, *a, **k):
26
- def _wrap(fn): return fn
 
27
  return _wrap
28
  spaces = _SpacesShim()
29
 
30
  import gradio as gr
31
 
32
- # Core ML
33
  import torch
34
  import numpy as np
35
  from huggingface_hub import HfApi, hf_hub_download
36
  from llama_cpp import Llama
37
 
38
- # Coqui TTS (XTTS v2)
39
  from TTS.tts.configs.xtts_config import XttsConfig
40
  from TTS.tts.models.xtts import Xtts
41
  from TTS.utils.manage import ModelManager
42
  from TTS.utils.generic_utils import get_user_data_dir
43
 
44
- # Text / audio processing
45
- import nltk, langid, emoji, noisereduce as nr
46
-
47
- # Download NLTK data once
48
- nltk.download("punkt", quiet=True)
49
 
50
  # ===================================================================================
51
- # 2) GLOBALS & HELPERS
52
  # ===================================================================================
53
- HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
54
  api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
55
  repo_id = "ruslanmv/ai-story-server"
56
-
57
  SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
58
  SENTENCE_SPLIT_LENGTH = 250
59
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
60
 
61
- # Cached models & latents
62
- tts_model: Xtts | None = None
63
- llm_model: Llama | None = None
64
- voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
65
-
66
- ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
67
  default_system_message = (
68
  "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
69
  "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
70
  )
71
  system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)
72
- ROLE_PROMPTS = {r: system_message for r in ROLES}
 
73
  ROLE_PROMPTS["Pirate"] = (
74
  "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
75
  "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
76
  )
77
 
78
- def pcm_to_wav(pcm: bytes, sr: int = 24000, ch: int = 1, bit: int = 16) -> bytes:
79
- if pcm.startswith(b"RIFF"): # already WAV
80
- return pcm
81
- chunk = 36 + len(pcm)
82
- hdr = struct.pack(
 
83
  "<4sI4s4sIHHIIHH4sI",
84
- b"RIFF", chunk, b"WAVE", b"fmt ", 16, 1, ch, sr,
85
- sr * ch * bit // 8, ch * bit // 8, bit, b"data", len(pcm)
 
 
 
86
  )
87
- return hdr + pcm
88
 
89
  def split_sentences(text: str, max_len: int) -> List[str]:
90
- out: List[str] = []
91
- for sent in nltk.sent_tokenize(text):
 
92
  if len(sent) > max_len:
93
- out.extend(textwrap.wrap(sent, max_len, break_long_words=True))
94
  else:
95
- out.append(sent)
96
- return out
97
-
98
- def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sys_msg: str) -> str:
99
- prompt = f"<|system|>\n{sys_msg}</s>"
100
- for u, a in history:
101
- if a:
102
- prompt += f"<|user|>\n{u}</s><|assistant|>\n{a}</s>"
103
  prompt += f"<|user|>\n{message}</s><|assistant|>"
104
  return prompt
105
 
106
  # ===================================================================================
107
- # 3) PRE-CACHE (FIRST-RUN DOWNLOADS ONLY)
108
- # ===================================================================================
109
-
110
- def _xtts_paths() -> Tuple[str, str, str, str]:
111
- """
112
- Returns (model_dir, model_pth, vocab_json, speakers_pth) for XTTS v2.
113
- Ensures the model is downloaded.
114
- """
115
- model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
116
- ModelManager().download_model(model_name) # idempotent
117
- model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
118
- return (
119
- model_dir,
120
- os.path.join(model_dir, "model.pth"),
121
- os.path.join(model_dir, "vocab.json"),
122
- os.path.join(model_dir, "speakers_xtts.pth"),
123
- )
124
-
125
- def precache_assets() -> None:
126
- """Download all large artifacts so the first inference is fast."""
127
- # Voices
128
- print("Pre-caching voice files...")
129
- base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
130
- os.makedirs("voices", exist_ok=True)
131
- for name in ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]:
132
- dst = os.path.join("voices", name)
133
- if not os.path.exists(dst):
134
- try:
135
- r = requests.get(base_url + name, timeout=30)
136
- r.raise_for_status()
137
- with open(dst, "wb") as f:
138
- f.write(r.content)
139
- except Exception as e:
140
- print(f"Warning: failed to fetch {name}: {e}")
141
-
142
- # XTTS weights (CPU-safe: just files)
143
- print("Pre-caching XTTS model files...")
144
- model_dir, model_pth, vocab_json, speakers_pth = _xtts_paths()
145
- for p in [model_pth, vocab_json, speakers_pth, os.path.join(model_dir, "config.json")]:
146
- if not os.path.exists(p):
147
- print(f"Warning: missing expected XTTS file: {p}")
148
-
149
- # Llama GGUF
150
- print("Pre-caching LLM (Zephyr GGUF)...")
151
- try:
152
- hf_hub_download(
153
- repo_id="TheBloke/zephyr-7B-beta-GGUF",
154
- filename="zephyr-7b-beta.Q5_K_M.gguf",
155
- force_download=False
156
- )
157
- except Exception as e:
158
- print(f"Warning: GGUF download error: {e}")
159
-
160
- # Run pre-cache at import time (downloads only; no GPU needed)
161
- precache_assets()
162
-
163
- # ===================================================================================
164
- # 4) MODEL LOADERS
165
  # ===================================================================================
166
 
167
  def _load_xtts(device: str) -> Xtts:
168
  print("Loading Coqui XTTS V2 model (first run)...")
169
  model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
170
- ModelManager().download_model(model_name) # idempotent
171
- model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
172
 
173
- cfg = XttsConfig()
174
- cfg.load_json(os.path.join(model_dir, "config.json"))
175
-
176
- model = Xtts.init_from_config(cfg)
177
- # Use checkpoint_dir so the library finds model.pth, vocab.json and speakers_xtts.pth itself
178
  model.load_checkpoint(
179
- cfg,
180
- checkpoint_dir=model_dir,
 
181
  eval=True,
182
- use_deepspeed=False, # deepspeed not installed in your Space
183
  )
184
  model.to(device)
185
- print("XTTS model ready.")
186
  return model
187
 
188
  def _load_llama() -> Llama:
189
  print("Loading LLM (Zephyr) (first run)...")
190
- gguf = hf_hub_download(
191
  repo_id="TheBloke/zephyr-7B-beta-GGUF",
192
  filename="zephyr-7b-beta.Q5_K_M.gguf"
193
  )
194
- # Try GPU offload then CPU
195
  for n_gpu_layers in (-1, 0):
196
  try:
197
  llm = Llama(
198
- model_path=gguf,
199
  n_gpu_layers=n_gpu_layers,
200
  n_ctx=4096,
201
  n_batch=512,
202
  verbose=False
203
  )
204
- print("LLM loaded with " + ("GPU offload" if n_gpu_layers == -1 else "CPU"))
 
 
 
205
  return llm
206
  except Exception as e:
207
- print(f"LLM init failed (n_gpu_layers={n_gpu_layers}): {e}")
208
- raise RuntimeError("Failed to initialize Llama.")
209
 
210
  def load_models() -> Tuple[Xtts, Llama]:
211
  global tts_model, llm_model
@@ -216,118 +173,134 @@ def load_models() -> Tuple[Xtts, Llama]:
216
  llm_model = _load_llama()
217
  return tts_model, llm_model
218
 
219
- # ===================================================================================
220
- # 5) GENERATION
221
- # ===================================================================================
222
-
223
- def generate_text_stream(llm: Llama, prompt: str,
224
  history: List[Tuple[str, str | None]],
225
- sys_msg: str) -> Generator[str, None, None]:
226
- formatted = format_prompt_zephyr(prompt, history, sys_msg)
227
- stream = llm(
228
- formatted,
229
  temperature=0.7,
230
  max_tokens=512,
231
  top_p=0.95,
232
  stop=LLM_STOP_WORDS,
233
  stream=True
234
  )
235
- for resp in stream:
236
- ch = resp["choices"][0]["text"]
 
237
  if "<|user|>" in ch or (len(ch) == 1 and emoji.is_emoji(ch)):
238
  continue
239
  yield ch
240
 
241
- def generate_audio_stream(tts: Xtts, text: str, lang: str,
242
  latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
243
- gpt_lat, spk_emb = latents
244
  try:
245
- for chunk in tts.inference_stream(
246
  text=text,
247
- language=lang,
248
- gpt_cond_latent=gpt_lat,
249
- speaker_embedding=spk_emb,
250
  temperature=0.85,
251
  ):
252
  if chunk is not None:
253
  yield chunk.detach().cpu().numpy().squeeze().tobytes()
254
  except RuntimeError as e:
255
- print(f"TTS inference error: {e}")
 
256
  if "device-side assert" in str(e) and api:
 
257
  try:
258
- gr.Warning("Critical GPU error. Attempting to restart the Space...")
259
  api.restart_space(repo_id=repo_id)
260
- except Exception:
261
  pass
262
 
263
  # ===================================================================================
264
- # 6) ZERO-GPU MAIN FUNCTION
265
  # ===================================================================================
266
 
267
- @spaces.GPU(duration=120)
268
- def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str):
269
  if secret_token_input != SECRET_TOKEN:
270
  raise gr.Error("Invalid secret token provided.")
271
  if not input_text:
272
  return []
273
 
 
274
  tts, llm = load_models()
275
 
276
- # Pre-compute & cache voice latents once per worker
277
- global voice_latents
278
- if not voice_latents:
279
- for role, fname in [
280
- ("Cloée", "cloee-1.wav"),
281
- ("Julian", "julian-bedtime-style-1.wav"),
282
- ("Pirate", "pirate_by_coqui.wav"),
283
- ("Thera", "thera-1.wav"),
284
- ]:
285
- path = os.path.join("voices", fname)
286
- voice_latents[role] = tts.get_conditioning_latents(
287
- audio_path=path, gpt_cond_len=30, max_ref_length=60
288
- )
 
 
 
 
 
289
 
290
- # Generate story
291
- history = [(input_text, None)]
292
- story = "".join(generate_text_stream(llm, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role])).strip()
293
- if not story:
294
  return []
295
 
296
- # Clean & split
297
- story = re.sub(r"([^\x00-\x7F]|\w)([.?!]+)", r"\1 \2", story)
298
- sentences = split_sentences(story, SENTENCE_SPLIT_LENGTH)
299
  lang = langid.classify(sentences[0])[0] if sentences else "en"
300
 
301
- results = []
302
- for s in sentences:
303
- if not any(c.isalnum() for c in s):
304
  continue
305
 
306
- pcm_chunks = generate_audio_stream(tts, s, lang, voice_latents[chatbot_role])
307
- pcm = b"".join(ch for ch in pcm_chunks if ch)
308
 
309
- # Best-effort noise reduction
310
  try:
311
- arr = np.frombuffer(pcm, dtype=np.int16)
312
- if arr.size:
313
- wav_f32 = arr.astype(np.float32) / 32767.0
314
- denoised = nr.reduce_noise(y=wav_f32, sr=24000)
315
- pcm = (denoised * 32767).astype(np.int16).tobytes()
 
 
316
  except Exception:
317
- pass
318
 
319
- b64 = base64.b64encode(pcm_to_wav(pcm)).decode("utf-8")
320
- results.append({"text": s, "audio": b64})
321
 
322
  return results
323
 
324
  # ===================================================================================
325
- # 7) UI
326
  # ===================================================================================
327
 
328
- print("Downloading voice files (idempotent)...")
329
- # Already handled in precache, but keep for local dev logs
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
 
331
  demo = gr.Interface(
332
  fn=generate_story_and_speech,
333
  inputs=[
@@ -338,8 +311,8 @@ demo = gr.Interface(
338
  outputs=gr.JSON(label="Story and Audio Output"),
339
  title="AI Storyteller with ZeroGPU",
340
  description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
341
- flagging_mode="never", # replaces deprecated allow_flagging
342
  )
343
 
344
  if __name__ == "__main__":
345
- demo.queue().launch(analytics_enabled=False)
 
1
  # ===================================================================================
2
+ # 1. SETUP AND IMPORTS
3
  # ===================================================================================
4
  from __future__ import annotations
5
+ import os
6
  import requests
7
+ import base64
8
+ import struct
9
+ import re
10
+ import textwrap
11
+ import uuid
12
+ from typing import List, Dict, Tuple, Generator
13
+
14
+ # --- Load .env early (for HF_TOKEN / SECRET_TOKEN) ---
15
  from dotenv import load_dotenv
16
  load_dotenv()
17
 
18
+ # --- Hugging Face Spaces & ZeroGPU ---
 
 
 
 
 
 
 
19
try:
    import spaces  # Required for ZeroGPU on HF
except Exception:
    # Fallback so the app still runs locally where the `spaces` package is
    # absent: expose a GPU(...) decorator factory that leaves functions
    # completely untouched.
    class _SpacesShim:
        def GPU(self, *args, **kwargs):
            """Stand-in for spaces.GPU: accepts any arguments, returns a no-op decorator."""
            return lambda fn: fn

    spaces = _SpacesShim()
29
 
30
  import gradio as gr
31
 
32
+ # --- Core ML & Data Libraries ---
33
  import torch
34
  import numpy as np
35
  from huggingface_hub import HfApi, hf_hub_download
36
  from llama_cpp import Llama
37
 
38
+ # --- TTS Libraries ---
39
  from TTS.tts.configs.xtts_config import XttsConfig
40
  from TTS.tts.models.xtts import Xtts
41
  from TTS.utils.manage import ModelManager
42
  from TTS.utils.generic_utils import get_user_data_dir
43
 
44
+ # --- Text & Audio Processing ---
45
+ import nltk
46
+ import langid
47
+ import emoji
48
+ import noisereduce as nr
49
 
50
# ===================================================================================
# 2. GLOBAL CONFIGURATION & HELPER FUNCTIONS
# ===================================================================================

# Download NLTK sentence-tokenizer data. NLTK >= 3.8.2 looks up the "punkt_tab"
# resource instead of "punkt" in sent_tokenize, so fetch both to stay compatible
# across NLTK versions (quiet download is a no-op when already present).
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

os.environ["COQUI_TOS_AGREED"] = "1"  # accept Coqui TTS license non-interactively

# Cached model singletons, populated lazily by load_models() (one per worker).
tts_model: Xtts | None = None
llm_model: Llama | None = None

# Configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None  # used to soft-restart the Space on GPU faults
repo_id = "ruslanmv/ai-story-server"
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250  # max characters handed to the TTS per chunk
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]

# System prompts and roles
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
)
system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)
ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
ROLE_PROMPTS = {role: system_message for role in ROLES}
ROLE_PROMPTS["Pirate"] = (
    "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)
83
 
84
+ # --- Audio helpers ---
85
# --- Audio helpers ---
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    """Wrap raw PCM samples in a minimal 44-byte RIFF/WAVE header.

    Data that already starts with a RIFF header is returned unchanged.
    """
    if pcm_data.startswith(b"RIFF"):
        return pcm_data

    bytes_per_frame = channels * bit_depth // 8
    byte_rate = sample_rate * bytes_per_frame
    wav_header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        36 + len(pcm_data),  # RIFF chunk size: header remainder + payload
        b"WAVE",
        b"fmt ",
        16,                  # fmt sub-chunk size for plain PCM
        1,                   # audio format 1 = linear PCM
        channels,
        sample_rate,
        byte_rate,
        bytes_per_frame,     # block align
        bit_depth,
        b"data",
        len(pcm_data),
    )
    return wav_header + pcm_data
98
 
99
def split_sentences(text: str, max_len: int) -> List[str]:
    """Tokenize *text* into sentences, wrapping any sentence longer than *max_len* characters."""
    out: List[str] = []
    for sentence in nltk.sent_tokenize(text):
        pieces = (
            textwrap.wrap(sentence, max_len, break_long_words=True)
            if len(sentence) > max_len
            else [sentence]
        )
        out.extend(pieces)
    return out
108
+
109
def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    """Assemble a Zephyr chat-format prompt from system message, history and the new turn.

    History turns whose bot response is missing or empty are skipped.
    """
    parts = [f"<|system|>\n{system_message}</s>"]
    parts.extend(
        f"<|user|>\n{user_turn}</s><|assistant|>\n{bot_turn}</s>"
        for user_turn, bot_turn in history
        if bot_turn
    )
    parts.append(f"<|user|>\n{message}</s><|assistant|>")
    return "".join(parts)
116
 
117
  # ===================================================================================
118
+ # 3. CORE AI FUNCTIONS (Model Loading & Inference)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  # ===================================================================================
120
 
121
def _load_xtts(device: str) -> Xtts:
    """Download (if needed) and load the Coqui XTTS v2 model onto *device*.

    NOTE(review): loads checkpoint via explicit checkpoint_path/vocab_path;
    assumes the downloaded model dir contains model.pth, vocab.json and
    config.json — verify against the installed TTS version.
    """
    print("Loading Coqui XTTS V2 model (first run)...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name)  # idempotent: skips files already on disk
    xtts_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))

    xtts_config = XttsConfig()
    xtts_config.load_json(os.path.join(xtts_dir, "config.json"))
    model = Xtts.init_from_config(xtts_config)
    # NOTE: deepspeed not installed; keep False for Spaces
    model.load_checkpoint(
        xtts_config,
        checkpoint_path=os.path.join(xtts_dir, "model.pth"),
        vocab_path=os.path.join(xtts_dir, "vocab.json"),
        eval=True,
        use_deepspeed=False,
    )
    model.to(device)
    print("XTTS model loaded.")
    return model
141
 
142
def _load_llama() -> Llama:
    """Download the Zephyr-7B GGUF weights and build a Llama instance.

    Tries full GPU offload first (n_gpu_layers=-1) and falls back to a
    CPU-only load (n_gpu_layers=0).

    Raises:
        RuntimeError: if neither configuration initializes; chained to the
            last underlying exception so the real cause is not lost.
    """
    print("Loading LLM (Zephyr) (first run)...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf"
    )
    # Try GPU offload if available, else CPU
    last_error: Exception | None = None
    for n_gpu_layers in (-1, 0):
        try:
            llm = Llama(
                model_path=zephyr_model_path,
                n_gpu_layers=n_gpu_layers,
                n_ctx=4096,
                n_batch=512,
                verbose=False
            )
            if n_gpu_layers == -1:
                print("LLM loaded with GPU offload.")
            else:
                print("LLM loaded (CPU).")
            return llm
        except Exception as e:
            last_error = e
            print(f"LLM init with n_gpu_layers={n_gpu_layers} failed: {e}")
    # Both attempts failed: surface the last failure as the cause.
    raise RuntimeError("Failed to initialize Llama model.") from last_error
166
 
167
  def load_models() -> Tuple[Xtts, Llama]:
168
  global tts_model, llm_model
 
173
  llm_model = _load_llama()
174
  return tts_model, llm_model
175
 
176
def generate_text_stream(llm_instance: Llama, prompt: str,
                         history: List[Tuple[str, str | None]],
                         system_message: str) -> Generator[str, None, None]:
    """Stream text fragments from the LLM for *prompt*, filtering artefacts.

    Fragments containing the <|user|> control token, or consisting of a
    single emoji, are dropped; everything else is yielded as it arrives.
    """
    zephyr_prompt = format_prompt_zephyr(prompt, history, system_message)
    token_stream = llm_instance(
        zephyr_prompt,
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
        stop=LLM_STOP_WORDS,
        stream=True,
    )
    for event in token_stream:
        fragment = event["choices"][0]["text"]
        is_control = "<|user|>" in fragment
        is_lone_emoji = len(fragment) == 1 and emoji.is_emoji(fragment)
        if not (is_control or is_lone_emoji):
            yield fragment
194
 
195
def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
    """Stream raw PCM byte chunks synthesized by XTTS for *text*.

    *latents* is the (gpt_cond_latent, speaker_embedding) pair for the voice.
    On a "device-side assert" RuntimeError, attempts a soft restart of the
    Space via the HF API when a token is configured.
    """
    cond_latent, spk_embedding = latents
    try:
        stream = tts_instance.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=cond_latent,
            speaker_embedding=spk_embedding,
            temperature=0.85,
        )
        for chunk in stream:
            if chunk is None:
                continue
            yield chunk.detach().cpu().numpy().squeeze().tobytes()
    except RuntimeError as e:
        print(f"Error during TTS inference: {e}")
        # Soft-restart if GPU went bad and we can talk to the HF API
        if "device-side assert" in str(e) and api:
            gr.Warning("Critical GPU error. Attempting to restart the Space...")
            try:
                api.restart_space(repo_id=repo_id)
            except Exception:
                # Best-effort: the restart call itself may fail; nothing to do.
                pass
217
 
218
  # ===================================================================================
219
+ # 4. MAIN GRADIO FUNCTION (Decorated for ZeroGPU)
220
  # ===================================================================================
221
 
222
# Reference voice file for each supported narrator role.
_ROLE_VOICE_FILES: Dict[str, str] = {
    "Cloée": "cloee-1.wav",
    "Julian": "julian-bedtime-style-1.wav",
    "Pirate": "pirate_by_coqui.wav",
    "Thera": "thera-1.wav",
}

# Per-worker cache of (gpt_cond_latent, speaker_embedding) keyed by role, so
# the expensive conditioning pass over the reference WAV runs once per role
# instead of for all four roles on every single request.
_VOICE_LATENT_CACHE: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

def _get_voice_latents(tts: Xtts, role: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return cached conditioning latents for *role*, computing them on first use."""
    if role not in _VOICE_LATENT_CACHE:
        path = os.path.join("voices", _ROLE_VOICE_FILES[role])
        _VOICE_LATENT_CACHE[role] = tts.get_conditioning_latents(
            audio_path=path, gpt_cond_len=30, max_ref_length=60
        )
    return _VOICE_LATENT_CACHE[role]

def _denoise_pcm(pcm_data: bytes) -> bytes:
    """Best-effort noise reduction on 16-bit mono PCM; returns the input unchanged on any failure."""
    try:
        data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
        if data_s16.size == 0:
            return pcm_data
        float_data = data_s16.astype(np.float32) / 32767.0
        reduced = nr.reduce_noise(y=float_data, sr=24000)
        return (reduced * 32767).astype(np.int16).tobytes()
    except Exception:
        # Noise reduction is optional polish; never let it break synthesis.
        return pcm_data

@spaces.GPU(duration=120)  # Request GPU for 120 seconds
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
    """Generate a short story for *input_text* and narrate it in *chatbot_role*'s voice.

    Returns a list of {"text": sentence, "audio": base64-encoded WAV} dicts,
    or [] when the prompt is empty or the LLM produced no text.

    Raises:
        gr.Error: when the supplied secret token does not match SECRET_TOKEN.
    """
    if secret_token_input != SECRET_TOKEN:
        raise gr.Error("Invalid secret token provided.")
    if not input_text:
        return []

    # Load (or reuse) the per-worker TTS and LLM singletons.
    tts, llm = load_models()

    # Conditioning latents for the requested voice only; cached across requests.
    role_latents = _get_voice_latents(tts, chatbot_role)

    # Generate story text
    history: List[Tuple[str, str | None]] = [(input_text, None)]
    full_story_text = "".join(
        generate_text_stream(llm, history[-1][0], history[:-1], system_message=ROLE_PROMPTS[chatbot_role])
    ).strip()
    if not full_story_text:
        return []

    # Tokenize into shorter sentences for TTS
    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
    lang = langid.classify(sentences[0])[0] if sentences else "en"

    results: List[Dict[str, str]] = []
    for sentence in sentences:
        # Skip punctuation-only fragments; they produce no usable audio.
        if not any(c.isalnum() for c in sentence):
            continue

        audio_chunks = generate_audio_stream(tts, sentence, lang, role_latents)
        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
        final_pcm = _denoise_pcm(pcm_data)

        b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
        results.append({"text": sentence, "audio": b64_wav})

    return results
282
 
283
  # ===================================================================================
284
+ # 5. GRADIO INTERFACE LAUNCH
285
  # ===================================================================================
286
 
287
# Download voice files on startup: fetch each reference WAV used for voice
# conditioning, skipping any that are already on disk.
print("Downloading voice files...")
base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
os.makedirs("voices", exist_ok=True)
for name in ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]:
    dst = os.path.join("voices", name)
    if os.path.exists(dst):
        continue
    try:
        resp = requests.get(base_url + name, timeout=30)
        resp.raise_for_status()
        with open(dst, "wb") as f:
            f.write(resp.content)
    except Exception as e:
        # Best-effort: a missing voice only breaks that one role at latent time.
        print(f"Failed to download {name}: {e}")
302
 
303
+ # Define the Gradio Interface
304
  demo = gr.Interface(
305
  fn=generate_story_and_speech,
306
  inputs=[
 
311
  outputs=gr.JSON(label="Story and Audio Output"),
312
  title="AI Storyteller with ZeroGPU",
313
  description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
314
+ allow_flagging="never",
315
  )
316
 
317
  if __name__ == "__main__":
318
+ demo.queue().launch()