ruslanmv committed on
Commit
fa37078
·
1 Parent(s): 2b38b16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -35
app.py CHANGED
@@ -8,6 +8,7 @@ import struct
8
  import textwrap
9
  import requests
10
  import atexit
 
11
  from typing import List, Dict, Tuple, Generator
12
 
13
  # --- Fast, safe defaults ---
@@ -38,8 +39,8 @@ import torch
38
  import numpy as np
39
  from huggingface_hub import HfApi, hf_hub_download
40
  from llama_cpp import Llama
41
- import torchaudio # Still needed for transforms, just not loading
42
- import soundfile as sf # <-- FIX: Import soundfile for robust audio loading
43
 
44
  # --- TTS Libraries ---
45
  from TTS.tts.configs.xtts_config import XttsConfig
@@ -57,15 +58,12 @@ import noisereduce as nr
57
  # 2) GLOBALS & HELPERS
58
  # ===================================================================================
59
 
60
- # Download NLTK data (punkt) once
61
  nltk.download("punkt", quiet=True)
62
 
63
- # Cached models & latents
64
  tts_model: Xtts | None = None
65
  llm_model: Llama | None = None
66
  voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
67
 
68
- # Config
69
  HF_TOKEN = os.environ.get("HF_TOKEN")
70
  api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
71
  repo_id = "ruslanmv/ai-story-server"
@@ -73,7 +71,6 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
73
  SENTENCE_SPLIT_LENGTH = 250
74
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
75
 
76
- # System prompts and roles
77
  default_system_message = (
78
  "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
79
  "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
@@ -86,7 +83,6 @@ ROLE_PROMPTS["Pirate"] = (
86
  "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
87
  )
88
 
89
- # ---------- small utils ----------
90
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
91
  if pcm_data.startswith(b"RIFF"):
92
  return pcm_data
@@ -124,7 +120,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
124
  # ===================================================================================
125
 
126
  def precache_assets() -> None:
127
- """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
128
  print("Pre-caching voice files...")
129
  file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
130
  base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
@@ -154,7 +149,6 @@ def precache_assets() -> None:
154
  print(f"Warning: GGUF pre-cache error: {e}")
155
 
156
  def _load_xtts(device: str) -> Xtts:
157
- """Load XTTS from the local cache."""
158
  print("Loading Coqui XTTS V2 model (CPU first)...")
159
  model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
160
  model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
@@ -170,7 +164,6 @@ def _load_xtts(device: str) -> Xtts:
170
  return model
171
 
172
  def _load_llama() -> Llama:
173
- """Load Llama (Zephyr GGUF) on CPU so it's ready immediately."""
174
  print("Loading LLM (Zephyr GGUF) on CPU...")
175
  zephyr_model_path = hf_hub_download(
176
  repo_id="TheBloke/zephyr-7B-beta-GGUF",
@@ -183,33 +176,26 @@ def _load_llama() -> Llama:
183
  print("LLM loaded (CPU).")
184
  return llm
185
 
186
- # --- FIX: Replaced torchaudio.load with soundfile.read to fix RuntimeError ---
187
- def load_audio_for_tts(path: str, target_sr: int = 24000) -> torch.Tensor:
188
- """Loads audio using soundfile, converts to a Torch tensor, and resamples if needed."""
189
  try:
190
- # Read audio file into a NumPy array
191
  audio_np, original_sr = sf.read(path, dtype='float32')
192
-
193
- # Ensure it's mono
194
  if audio_np.ndim > 1:
195
  audio_np = np.mean(audio_np, axis=1)
196
-
197
- # Convert to a PyTorch tensor
198
  waveform = torch.from_numpy(audio_np).float()
199
 
200
- # Resample if the sample rate is not the target rate
201
  if original_sr != target_sr:
202
  print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
203
  resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
204
  waveform = resampler(waveform)
205
-
206
- return waveform.unsqueeze(0) # Add batch dimension: shape (1, T)
207
  except Exception as e:
208
  print(f"Error loading audio file {path}: {e}")
209
  raise
210
 
211
  def init_models_and_latents() -> None:
212
- """Preload TTS and LLM on CPU and compute voice latents once."""
213
  global tts_model, llm_model, voice_latents
214
 
215
  if tts_model is None:
@@ -220,17 +206,28 @@ def init_models_and_latents() -> None:
220
 
221
  if not voice_latents:
222
  print("Computing voice conditioning latents...")
223
- voice_files = {
224
- "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
225
- "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
226
- }
227
- for role, filename in voice_files.items():
228
- path = os.path.join("voices", filename)
229
- # Load audio externally and pass the waveform tensor directly
230
- waveform = load_audio_for_tts(path)
231
- voice_latents[role] = tts_model.get_conditioning_latents(
232
- waveform=waveform, gpt_cond_len=30, max_ref_length=60
233
- )
 
 
 
 
 
 
 
 
 
 
 
234
  print("Voice latents ready.")
235
 
236
  def _close_llm():
@@ -270,7 +267,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
270
  if not input_text:
271
  return []
272
 
273
- # Models must be preloaded, this is a fallback.
274
  if tts_model is None or llm_model is None:
275
  raise gr.Error("Models not initialized. Please restart the Space.")
276
 
@@ -311,7 +307,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
311
  return results
312
 
313
  finally:
314
- # Crucial for ZeroGPU: ensure model returns to CPU to free the GPU
315
  if tts_model is not None:
316
  tts_model.to("cpu")
317
 
 
8
  import textwrap
9
  import requests
10
  import atexit
11
+ import tempfile # <-- FIX: Import tempfile to manage temporary audio files
12
  from typing import List, Dict, Tuple, Generator
13
 
14
  # --- Fast, safe defaults ---
 
39
  import numpy as np
40
  from huggingface_hub import HfApi, hf_hub_download
41
  from llama_cpp import Llama
42
+ import torchaudio
43
+ import soundfile as sf
44
 
45
  # --- TTS Libraries ---
46
  from TTS.tts.configs.xtts_config import XttsConfig
 
58
  # 2) GLOBALS & HELPERS
59
  # ===================================================================================
60
 
 
61
  nltk.download("punkt", quiet=True)
62
 
 
63
  tts_model: Xtts | None = None
64
  llm_model: Llama | None = None
65
  voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
66
 
 
67
  HF_TOKEN = os.environ.get("HF_TOKEN")
68
  api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
69
  repo_id = "ruslanmv/ai-story-server"
 
71
  SENTENCE_SPLIT_LENGTH = 250
72
  LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
73
 
 
74
  default_system_message = (
75
  "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
76
  "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
 
83
  "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
84
  )
85
 
 
86
  def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
87
  if pcm_data.startswith(b"RIFF"):
88
  return pcm_data
 
120
  # ===================================================================================
121
 
122
  def precache_assets() -> None:
 
123
  print("Pre-caching voice files...")
124
  file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
125
  base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
 
149
  print(f"Warning: GGUF pre-cache error: {e}")
150
 
151
  def _load_xtts(device: str) -> Xtts:
 
152
  print("Loading Coqui XTTS V2 model (CPU first)...")
153
  model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
154
  model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
 
164
  return model
165
 
166
  def _load_llama() -> Llama:
 
167
  print("Loading LLM (Zephyr GGUF) on CPU...")
168
  zephyr_model_path = hf_hub_download(
169
  repo_id="TheBloke/zephyr-7B-beta-GGUF",
 
176
  print("LLM loaded (CPU).")
177
  return llm
178
 
179
+ def load_and_resample_audio(path: str, target_sr: int = 24000) -> torch.Tensor:
180
+ """Loads audio, converts to a Torch tensor, and resamples if needed."""
 
181
  try:
 
182
  audio_np, original_sr = sf.read(path, dtype='float32')
 
 
183
  if audio_np.ndim > 1:
184
  audio_np = np.mean(audio_np, axis=1)
 
 
185
  waveform = torch.from_numpy(audio_np).float()
186
 
 
187
  if original_sr != target_sr:
188
  print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
189
  resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
190
  waveform = resampler(waveform)
191
+
192
+ return waveform.unsqueeze(0)
193
  except Exception as e:
194
  print(f"Error loading audio file {path}: {e}")
195
  raise
196
 
197
  def init_models_and_latents() -> None:
198
+ """Preload models and compute voice latents, using temporary files for compatibility."""
199
  global tts_model, llm_model, voice_latents
200
 
201
  if tts_model is None:
 
206
 
207
  if not voice_latents:
208
  print("Computing voice conditioning latents...")
209
+ # --- FIX: Use a temporary directory to store resampled audio files ---
210
+ with tempfile.TemporaryDirectory() as temp_dir:
211
+ voice_files = {
212
+ "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
213
+ "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
214
+ }
215
+ for role, filename in voice_files.items():
216
+ original_path = os.path.join("voices", filename)
217
+
218
+ # 1. Load and resample audio into a tensor
219
+ resampled_waveform = load_and_resample_audio(original_path)
220
+
221
+ # 2. Save the corrected tensor to a temporary file
222
+ temp_path = os.path.join(temp_dir, f"resampled_{filename}")
223
+ torchaudio.save(temp_path, resampled_waveform.squeeze(0), 24000)
224
+
225
+ # 3. Pass the path of the clean, temporary file to the model
226
+ voice_latents[role] = tts_model.get_conditioning_latents(
227
+ audio_path=temp_path,
228
+ gpt_cond_len=30,
229
+ max_ref_length=60
230
+ )
231
  print("Voice latents ready.")
232
 
233
  def _close_llm():
 
267
  if not input_text:
268
  return []
269
 
 
270
  if tts_model is None or llm_model is None:
271
  raise gr.Error("Models not initialized. Please restart the Space.")
272
 
 
307
  return results
308
 
309
  finally:
 
310
  if tts_model is not None:
311
  tts_model.to("cpu")
312