ruslanmv committed
Commit 96b9f29 · 1 Parent(s): cd32542

First commit

Files changed (2)
  1. app.py +191 -164
  2. requirements.txt +6 -5
app.py CHANGED
@@ -1,168 +1,210 @@
 # ===================================================================================
-# 1. SETUP AND IMPORTS
+# 1) SETUP & IMPORTS
 # ===================================================================================
 from __future__ import annotations
-import os
+import os, base64, struct, textwrap, re
 import requests
-import base64
-import struct
-import re
-import textwrap
-import uuid
-from typing import List, Dict, Tuple, Generator
-
-# --- Load .env early (for HF_TOKEN / SECRET_TOKEN) ---
+from typing import List, Tuple, Dict, Generator
+
+# Load .env early (HF_TOKEN / SECRET_TOKEN)
 from dotenv import load_dotenv
 load_dotenv()
 
-# --- Hugging Face Spaces & ZeroGPU ---
+# Make downloads fast & quiet
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("COQUI_TOS_AGREED", "1")
+# Avoid Gradio analytics pandas edge-cases
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
+
+# HF Spaces / Gradio
 try:
-    import spaces  # Required for ZeroGPU on HF
+    import spaces  # ZeroGPU decorator
 except Exception:
-    # Allow local runs without the spaces package
     class _SpacesShim:
-        def GPU(self, *args, **kwargs):
-            def _wrap(fn):
-                return fn
+        def GPU(self, *a, **k):
+            def _wrap(fn): return fn
             return _wrap
     spaces = _SpacesShim()
 
 import gradio as gr
 
-# --- Core ML & Data Libraries ---
+# Core ML
 import torch
 import numpy as np
 from huggingface_hub import HfApi, hf_hub_download
 from llama_cpp import Llama
 
-# --- TTS Libraries ---
+# Coqui TTS (XTTS v2)
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from TTS.utils.manage import ModelManager
 from TTS.utils.generic_utils import get_user_data_dir
 
-# --- Text & Audio Processing ---
-import nltk
-import langid
-import emoji
-import noisereduce as nr
-
-# ===================================================================================
-# 2. GLOBAL CONFIGURATION & HELPER FUNCTIONS
-# ===================================================================================
-
-# Download NLTK data (punkt)
+# Text / audio processing
+import nltk, langid, emoji, noisereduce as nr
+
+# Download NLTK data once
 nltk.download("punkt", quiet=True)
 
-os.environ["COQUI_TOS_AGREED"] = "1"
-
-# Cached models
-tts_model: Xtts | None = None
-llm_model: Llama | None = None
-
-# Configuration
-HF_TOKEN = os.environ.get("HF_TOKEN")
+# ===================================================================================
+# 2) GLOBALS & HELPERS
+# ===================================================================================
+HF_TOKEN = os.getenv("HF_TOKEN")
 api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
 repo_id = "ruslanmv/ai-story-server"
+
 SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
 SENTENCE_SPLIT_LENGTH = 250
 LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
 
-# System prompts and roles
+# Cached models & latents
+tts_model: Xtts | None = None
+llm_model: Llama | None = None
+voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
+
+ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
 default_system_message = (
     "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
     "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
 )
 system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)
-ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
-ROLE_PROMPTS = {role: system_message for role in ROLES}
+ROLE_PROMPTS = {r: system_message for r in ROLES}
 ROLE_PROMPTS["Pirate"] = (
     "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
     "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
 )
 
-# --- Audio helpers ---
-def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
-    if pcm_data.startswith(b"RIFF"):
-        return pcm_data
-    chunk_size = 36 + len(pcm_data)
-    header = struct.pack(
-        "<4sI4s4sIHHIIHH4sI",
-        b"RIFF", chunk_size, b"WAVE", b"fmt ",
-        16, 1, channels, sample_rate,
-        sample_rate * channels * bit_depth // 8,
-        channels * bit_depth // 8, bit_depth,
-        b"data", len(pcm_data)
+def pcm_to_wav(pcm: bytes, sr: int = 24000, ch: int = 1, bit: int = 16) -> bytes:
+    if pcm.startswith(b"RIFF"):  # already WAV
+        return pcm
+    chunk = 36 + len(pcm)
+    hdr = struct.pack("<4sI4s4sIHHIIHH4sI",
+        b"RIFF", chunk, b"WAVE", b"fmt ", 16, 1, ch, sr,
+        sr * ch * bit // 8, ch * bit // 8, bit, b"data", len(pcm)
     )
-    return header + pcm_data
+    return hdr + pcm
 
 def split_sentences(text: str, max_len: int) -> List[str]:
-    sentences = nltk.sent_tokenize(text)
-    chunks: List[str] = []
-    for sent in sentences:
+    out: List[str] = []
+    for sent in nltk.sent_tokenize(text):
         if len(sent) > max_len:
-            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
+            out.extend(textwrap.wrap(sent, max_len, break_long_words=True))
         else:
-            chunks.append(sent)
-    return chunks
-
-def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
-    prompt = f"<|system|>\n{system_message}</s>"
-    for user_prompt, bot_response in history:
-        if bot_response:
-            prompt += f"<|user|>\n{user_prompt}</s><|assistant|>\n{bot_response}</s>"
+            out.append(sent)
+    return out
+
+def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sys_msg: str) -> str:
+    prompt = f"<|system|>\n{sys_msg}</s>"
+    for u, a in history:
+        if a:
+            prompt += f"<|user|>\n{u}</s><|assistant|>\n{a}</s>"
     prompt += f"<|user|>\n{message}</s><|assistant|>"
     return prompt
 
 # ===================================================================================
-# 3. CORE AI FUNCTIONS (Model Loading & Inference)
+# 3) PRE-CACHE (FIRST-RUN DOWNLOADS ONLY)
+# ===================================================================================
+
+def _xtts_paths() -> Tuple[str, str, str, str]:
+    """
+    Returns (model_dir, model_pth, vocab_json, speakers_pth) for XTTS v2.
+    Ensures the model is downloaded.
+    """
+    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+    ModelManager().download_model(model_name)  # idempotent
+    model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    return (
+        model_dir,
+        os.path.join(model_dir, "model.pth"),
+        os.path.join(model_dir, "vocab.json"),
+        os.path.join(model_dir, "speakers_xtts.pth"),
+    )
+
+def precache_assets() -> None:
+    """Download all large artifacts so the first inference is fast."""
+    # Voices
+    print("Pre-caching voice files...")
+    base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
+    os.makedirs("voices", exist_ok=True)
+    for name in ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]:
+        dst = os.path.join("voices", name)
+        if not os.path.exists(dst):
+            try:
+                r = requests.get(base_url + name, timeout=30)
+                r.raise_for_status()
+                with open(dst, "wb") as f:
+                    f.write(r.content)
+            except Exception as e:
+                print(f"Warning: failed to fetch {name}: {e}")
+
+    # XTTS weights (CPU-safe: just files)
+    print("Pre-caching XTTS model files...")
+    model_dir, model_pth, vocab_json, speakers_pth = _xtts_paths()
+    for p in [model_pth, vocab_json, speakers_pth, os.path.join(model_dir, "config.json")]:
+        if not os.path.exists(p):
+            print(f"Warning: missing expected XTTS file: {p}")
+
+    # Llama GGUF
+    print("Pre-caching LLM (Zephyr GGUF)...")
+    try:
+        hf_hub_download(
+            repo_id="TheBloke/zephyr-7B-beta-GGUF",
+            filename="zephyr-7b-beta.Q5_K_M.gguf",
+            force_download=False
+        )
+    except Exception as e:
+        print(f"Warning: GGUF download error: {e}")
+
+# Run pre-cache at import time (downloads only; no GPU needed)
+precache_assets()
+
+# ===================================================================================
+# 4) MODEL LOADERS
 # ===================================================================================
 
 def _load_xtts(device: str) -> Xtts:
     print("Loading Coqui XTTS V2 model (first run)...")
-    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-    ModelManager().download_model(model_name)
-    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    model_dir, model_pth, vocab_json, speakers_pth = _xtts_paths()
 
-    config = XttsConfig()
-    config.load_json(os.path.join(model_path, "config.json"))
-    model = Xtts.init_from_config(config)
-    # NOTE: deepspeed not installed; keep False for Spaces
+    cfg = XttsConfig()
+    cfg.load_json(os.path.join(model_dir, "config.json"))
+
+    model = Xtts.init_from_config(cfg)
+    # IMPORTANT: pass speaker_file_path to avoid NoneType join inside library
     model.load_checkpoint(
-        config,
-        checkpoint_path=os.path.join(model_path, "model.pth"),
-        vocab_path=os.path.join(model_path, "vocab.json"),
+        cfg,
+        checkpoint_path=model_pth,
+        vocab_path=vocab_json,
+        speaker_file_path=speakers_pth,  # <-- fixes TypeError
         eval=True,
-        use_deepspeed=False,
+        use_deepspeed=False,  # deepspeed not installed
     )
     model.to(device)
-    print("XTTS model loaded.")
+    print("XTTS model ready.")
     return model
 
 def _load_llama() -> Llama:
     print("Loading LLM (Zephyr) (first run)...")
-    zephyr_model_path = hf_hub_download(
+    gguf = hf_hub_download(
         repo_id="TheBloke/zephyr-7B-beta-GGUF",
         filename="zephyr-7b-beta.Q5_K_M.gguf"
     )
-    # Try GPU offload if available, else CPU
+    # Try GPU offload then CPU
     for n_gpu_layers in (-1, 0):
         try:
             llm = Llama(
-                model_path=zephyr_model_path,
+                model_path=gguf,
                 n_gpu_layers=n_gpu_layers,
                 n_ctx=4096,
                 n_batch=512,
                 verbose=False
             )
-            if n_gpu_layers == -1:
-                print("LLM loaded with GPU offload.")
-            else:
-                print("LLM loaded (CPU).")
+            print("LLM loaded with " + ("GPU offload" if n_gpu_layers == -1 else "CPU"))
             return llm
         except Exception as e:
-            print(f"LLM init with n_gpu_layers={n_gpu_layers} failed: {e}")
-    raise RuntimeError("Failed to initialize Llama model.")
+            print(f"LLM init failed (n_gpu_layers={n_gpu_layers}): {e}")
+    raise RuntimeError("Failed to initialize Llama.")
 
 def load_models() -> Tuple[Xtts, Llama]:
     global tts_model, llm_model
@@ -173,134 +215,119 @@ def load_models() -> Tuple[Xtts, Llama]:
         llm_model = _load_llama()
     return tts_model, llm_model
 
-def generate_text_stream(llm_instance: Llama, prompt: str,
+# ===================================================================================
+# 5) GENERATION
+# ===================================================================================
+
+def generate_text_stream(llm: Llama, prompt: str,
                          history: List[Tuple[str, str | None]],
-                         system_message: str) -> Generator[str, None, None]:
-    formatted_prompt = format_prompt_zephyr(prompt, history, system_message)
-    stream = llm_instance(
-        formatted_prompt,
+                         sys_msg: str) -> Generator[str, None, None]:
+    formatted = format_prompt_zephyr(prompt, history, sys_msg)
+    stream = llm(
+        formatted,
         temperature=0.7,
         max_tokens=512,
         top_p=0.95,
         stop=LLM_STOP_WORDS,
         stream=True
    )
-    for response in stream:
-        ch = response["choices"][0]["text"]
-        # Guard against control tokens & isolated emoji artefacts
+    for resp in stream:
+        ch = resp["choices"][0]["text"]
        if "<|user|>" in ch or (len(ch) == 1 and emoji.is_emoji(ch)):
            continue
        yield ch
 
-def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
+def generate_audio_stream(tts: Xtts, text: str, lang: str,
                           latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
-    gpt_cond_latent, speaker_embedding = latents
+    gpt_lat, spk_emb = latents
     try:
-        for chunk in tts_instance.inference_stream(
+        for chunk in tts.inference_stream(
             text=text,
-            language=language,
-            gpt_cond_latent=gpt_cond_latent,
-            speaker_embedding=speaker_embedding,
+            language=lang,
+            gpt_cond_latent=gpt_lat,
+            speaker_embedding=spk_emb,
             temperature=0.85,
         ):
             if chunk is not None:
                 yield chunk.detach().cpu().numpy().squeeze().tobytes()
     except RuntimeError as e:
-        print(f"Error during TTS inference: {e}")
-        # Soft-restart if GPU went bad and we can talk to the HF API
+        print(f"TTS inference error: {e}")
         if "device-side assert" in str(e) and api:
-            gr.Warning("Critical GPU error. Attempting to restart the Space...")
             try:
+                gr.Warning("Critical GPU error. Attempting to restart the Space...")
                 api.restart_space(repo_id=repo_id)
-            except Exception as _:
+            except Exception:
                 pass
 
 # ===================================================================================
-# 4. MAIN GRADIO FUNCTION (Decorated for ZeroGPU)
+# 6) ZERO-GPU MAIN FUNCTION
 # ===================================================================================
 
-@spaces.GPU(duration=120)  # Request GPU for 120 seconds
-def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
+@spaces.GPU(duration=120)
+def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str):
     if secret_token_input != SECRET_TOKEN:
         raise gr.Error("Invalid secret token provided.")
     if not input_text:
         return []
 
-    # Load models
     tts, llm = load_models()
 
-    # Pre-compute voice latents
-    latent_map: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
-    for role, filename in [
-        ("Cloée", "cloee-1.wav"),
-        ("Julian", "julian-bedtime-style-1.wav"),
-        ("Pirate", "pirate_by_coqui.wav"),
-        ("Thera", "thera-1.wav"),
-    ]:
-        path = os.path.join("voices", filename)
-        latent_map[role] = tts.get_conditioning_latents(
-            audio_path=path, gpt_cond_len=30, max_ref_length=60
-        )
-
-    # Generate story text
-    history: List[Tuple[str, str | None]] = [(input_text, None)]
-    full_story_text = "".join(
-        generate_text_stream(llm, history[-1][0], history[:-1], system_message=ROLE_PROMPTS[chatbot_role])
-    ).strip()
-
-    if not full_story_text:
+    # Pre-compute & cache voice latents once per session
+    global voice_latents
+    if not voice_latents:
+        for role, fname in [
+            ("Cloée", "cloee-1.wav"),
+            ("Julian", "julian-bedtime-style-1.wav"),
+            ("Pirate", "pirate_by_coqui.wav"),
+            ("Thera", "thera-1.wav"),
+        ]:
+            path = os.path.join("voices", fname)
+            voice_latents[role] = tts.get_conditioning_latents(
+                audio_path=path, gpt_cond_len=30, max_ref_length=60
+            )
+
+    # Generate story
+    history = [(input_text, None)]
+    story = "".join(generate_text_stream(llm, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role])).strip()
+    if not story:
         return []
 
-    # Tokenize into shorter sentences for TTS
-    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
+    # Clean & split
+    story = re.sub(r"([^\x00-\x7F]|\w)([.?!]+)", r"\1 \2", story)
+    sentences = split_sentences(story, SENTENCE_SPLIT_LENGTH)
     lang = langid.classify(sentences[0])[0] if sentences else "en"
 
-    results: List[Dict[str, str]] = []
-    for sentence in sentences:
-        if not any(c.isalnum() for c in sentence):
+    results = []
+    for s in sentences:
+        if not any(c.isalnum() for c in s):
             continue
 
-        audio_chunks = generate_audio_stream(tts, sentence, lang, latent_map[chatbot_role])
-        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
+        pcm_chunks = generate_audio_stream(tts, s, lang, voice_latents[chatbot_role])
+        pcm = b"".join(ch for ch in pcm_chunks if ch)
 
-        # Optional noise reduction (best-effort)
+        # Best-effort noise reduction
         try:
-            data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
-            if data_s16.size > 0:
-                float_data = data_s16.astype(np.float32) / 32767.0
-                reduced = nr.reduce_noise(y=float_data, sr=24000)
-                final_pcm = (reduced * 32767).astype(np.int16).tobytes()
-            else:
-                final_pcm = pcm_data
+            arr = np.frombuffer(pcm, dtype=np.int16)
+            if arr.size:
+                wav_f32 = arr.astype(np.float32) / 32767.0
+                denoised = nr.reduce_noise(y=wav_f32, sr=24000)
+                pcm = (denoised * 32767).astype(np.int16).tobytes()
         except Exception:
-            final_pcm = pcm_data
+            pass
 
-        b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
-        results.append({"text": sentence, "audio": b64_wav})
+        b64 = base64.b64encode(pcm_to_wav(pcm)).decode("utf-8")
+        results.append({"text": s, "audio": b64})
 
     return results
 
 # ===================================================================================
-# 5. GRADIO INTERFACE LAUNCH
+# 7) UI
 # ===================================================================================
 
-# Download voice files on startup
-print("Downloading voice files...")
-file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
-base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
-os.makedirs("voices", exist_ok=True)
-for name in file_names:
-    dst = os.path.join("voices", name)
-    if not os.path.exists(dst):
-        try:
-            resp = requests.get(base_url + name, timeout=30)
-            resp.raise_for_status()
-            with open(dst, "wb") as f:
-                f.write(resp.content)
-        except Exception as e:
-            print(f"Failed to download {name}: {e}")
+print("Downloading voice files (idempotent)...")
+# Already handled in precache, but keep for local dev logs
+# (No-op if files exist)
 
-# Define the Gradio Interface
 demo = gr.Interface(
     fn=generate_story_and_speech,
     inputs=[
@@ -311,8 +338,8 @@ demo = gr.Interface(
     outputs=gr.JSON(label="Story and Audio Output"),
     title="AI Storyteller with ZeroGPU",
     description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
-    allow_flagging="never",
+    flagging_mode="never",  # replaces deprecated allow_flagging
 )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch()  # you can add ssr_mode=False if you prefer
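The 44-byte RIFF header that `pcm_to_wav` packs can be sanity-checked with Python's standard `wave` module. Below is a minimal, self-contained sketch (not part of the commit) that restates the helper with its 24 kHz mono 16-bit defaults and verifies that the header round-trips:

```python
import io
import struct
import wave

def pcm_to_wav(pcm: bytes, sr: int = 24000, ch: int = 1, bit: int = 16) -> bytes:
    # Same header layout as the committed helper: RIFF chunk, 16-byte
    # PCM fmt chunk (format tag 1), then the raw data chunk.
    if pcm.startswith(b"RIFF"):  # already WAV
        return pcm
    chunk = 36 + len(pcm)
    hdr = struct.pack("<4sI4s4sIHHIIHH4sI",
        b"RIFF", chunk, b"WAVE", b"fmt ", 16, 1, ch, sr,
        sr * ch * bit // 8,   # byte rate
        ch * bit // 8,        # block align
        bit, b"data", len(pcm)
    )
    return hdr + pcm

silence = b"\x00\x00" * 24000  # one second of 16-bit mono silence
with wave.open(io.BytesIO(pcm_to_wav(silence))) as w:
    assert w.getframerate() == 24000
    assert w.getnchannels() == 1
    assert w.getsampwidth() == 2   # 16-bit
    print(w.getnframes())          # 24000 frames = 1 second
```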
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-# ZeroGPU and Core
+# Core
 torch==2.2.2
 torchaudio==2.2.2
 gradio==5.47.2
@@ -7,15 +7,16 @@ python-dotenv
 spaces
 requests
 numpy
+pandas>=2.2.2,<3  # Fixes Gradio analytics OptionError
 
-# TTS Dependencies
+# TTS
 TTS @ git+https://github.com/coqui-ai/TTS@v0.22.0
 pydantic==2.5.3
 
-# LLM Dependencies
+# LLM
 llama-cpp-python==0.2.79
 
-# Audio & Text Processing
+# Audio & Text
 noisereduce==3.0.3
 pydub
 langid
@@ -23,6 +24,6 @@ nltk
 emoji
 ffmpeg-python
 
-# Japanese Text (if needed by TTS)
+# Japanese Text (optional)
 mecab-python3==1.0.9
 unidic-lite==1.0.8
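Once the Space is running, the endpoint can be exercised remotely. A minimal client sketch, with assumptions labeled: it assumes the default `/predict` route that `gr.Interface` exposes, and that `SECRET_TOKEN` is left at its `"secret"` default; the output schema (a list of `{"text", "audio"}` items with base64-encoded WAV) comes from `generate_story_and_speech` in app.py:

```python
import base64
from gradio_client import Client

client = Client("ruslanmv/ai-story-server")  # the repo_id used in app.py
# Positional args follow the Interface inputs: secret token, prompt, role.
results = client.predict(
    "secret",                                   # SECRET_TOKEN default (assumption)
    "Tell a short story about a brave turtle.",
    "Julian",
    api_name="/predict",                        # gr.Interface default route (assumption)
)
for i, item in enumerate(results):
    print(item["text"])
    # Each "audio" field is a base64-encoded WAV produced by pcm_to_wav.
    with open(f"sentence_{i:02d}.wav", "wb") as f:
        f.write(base64.b64decode(item["audio"]))
```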