Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import struct
|
|
| 8 |
import textwrap
|
| 9 |
import requests
|
| 10 |
import atexit
|
|
|
|
| 11 |
from typing import List, Dict, Tuple, Generator
|
| 12 |
|
| 13 |
# --- Fast, safe defaults ---
|
|
@@ -38,8 +39,8 @@ import torch
|
|
| 38 |
import numpy as np
|
| 39 |
from huggingface_hub import HfApi, hf_hub_download
|
| 40 |
from llama_cpp import Llama
|
| 41 |
-
import torchaudio
|
| 42 |
-
import soundfile as sf
|
| 43 |
|
| 44 |
# --- TTS Libraries ---
|
| 45 |
from TTS.tts.configs.xtts_config import XttsConfig
|
|
@@ -57,15 +58,12 @@ import noisereduce as nr
|
|
| 57 |
# 2) GLOBALS & HELPERS
|
| 58 |
# ===================================================================================
|
| 59 |
|
| 60 |
-
# Download NLTK data (punkt) once
|
| 61 |
nltk.download("punkt", quiet=True)
|
| 62 |
|
| 63 |
-
# Cached models & latents
|
| 64 |
tts_model: Xtts | None = None
|
| 65 |
llm_model: Llama | None = None
|
| 66 |
voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 67 |
|
| 68 |
-
# Config
|
| 69 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 70 |
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
|
| 71 |
repo_id = "ruslanmv/ai-story-server"
|
|
@@ -73,7 +71,6 @@ SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
|
|
| 73 |
SENTENCE_SPLIT_LENGTH = 250
|
| 74 |
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
|
| 75 |
|
| 76 |
-
# System prompts and roles
|
| 77 |
default_system_message = (
|
| 78 |
"You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
|
| 79 |
"Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
|
|
@@ -86,7 +83,6 @@ ROLE_PROMPTS["Pirate"] = (
|
|
| 86 |
"Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
|
| 87 |
)
|
| 88 |
|
| 89 |
-
# ---------- small utils ----------
|
| 90 |
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
|
| 91 |
if pcm_data.startswith(b"RIFF"):
|
| 92 |
return pcm_data
|
|
@@ -124,7 +120,6 @@ def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], sy
|
|
| 124 |
# ===================================================================================
|
| 125 |
|
| 126 |
def precache_assets() -> None:
|
| 127 |
-
"""Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
|
| 128 |
print("Pre-caching voice files...")
|
| 129 |
file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
|
| 130 |
base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
|
|
@@ -154,7 +149,6 @@ def precache_assets() -> None:
|
|
| 154 |
print(f"Warning: GGUF pre-cache error: {e}")
|
| 155 |
|
| 156 |
def _load_xtts(device: str) -> Xtts:
|
| 157 |
-
"""Load XTTS from the local cache."""
|
| 158 |
print("Loading Coqui XTTS V2 model (CPU first)...")
|
| 159 |
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
|
| 160 |
model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
|
|
@@ -170,7 +164,6 @@ def _load_xtts(device: str) -> Xtts:
|
|
| 170 |
return model
|
| 171 |
|
| 172 |
def _load_llama() -> Llama:
|
| 173 |
-
"""Load Llama (Zephyr GGUF) on CPU so it's ready immediately."""
|
| 174 |
print("Loading LLM (Zephyr GGUF) on CPU...")
|
| 175 |
zephyr_model_path = hf_hub_download(
|
| 176 |
repo_id="TheBloke/zephyr-7B-beta-GGUF",
|
|
@@ -183,33 +176,26 @@ def _load_llama() -> Llama:
|
|
| 183 |
print("LLM loaded (CPU).")
|
| 184 |
return llm
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
"""Loads audio using soundfile, converts to a Torch tensor, and resamples if needed."""
|
| 189 |
try:
|
| 190 |
-
# Read audio file into a NumPy array
|
| 191 |
audio_np, original_sr = sf.read(path, dtype='float32')
|
| 192 |
-
|
| 193 |
-
# Ensure it's mono
|
| 194 |
if audio_np.ndim > 1:
|
| 195 |
audio_np = np.mean(audio_np, axis=1)
|
| 196 |
-
|
| 197 |
-
# Convert to a PyTorch tensor
|
| 198 |
waveform = torch.from_numpy(audio_np).float()
|
| 199 |
|
| 200 |
-
# Resample if the sample rate is not the target rate
|
| 201 |
if original_sr != target_sr:
|
| 202 |
print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
|
| 203 |
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
|
| 204 |
waveform = resampler(waveform)
|
| 205 |
-
|
| 206 |
-
return waveform.unsqueeze(0)
|
| 207 |
except Exception as e:
|
| 208 |
print(f"Error loading audio file {path}: {e}")
|
| 209 |
raise
|
| 210 |
|
| 211 |
def init_models_and_latents() -> None:
|
| 212 |
-
"""Preload
|
| 213 |
global tts_model, llm_model, voice_latents
|
| 214 |
|
| 215 |
if tts_model is None:
|
|
@@ -220,17 +206,28 @@ def init_models_and_latents() -> None:
|
|
| 220 |
|
| 221 |
if not voice_latents:
|
| 222 |
print("Computing voice conditioning latents...")
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
print("Voice latents ready.")
|
| 235 |
|
| 236 |
def _close_llm():
|
|
@@ -270,7 +267,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
|
|
| 270 |
if not input_text:
|
| 271 |
return []
|
| 272 |
|
| 273 |
-
# Models must be preloaded, this is a fallback.
|
| 274 |
if tts_model is None or llm_model is None:
|
| 275 |
raise gr.Error("Models not initialized. Please restart the Space.")
|
| 276 |
|
|
@@ -311,7 +307,6 @@ def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_
|
|
| 311 |
return results
|
| 312 |
|
| 313 |
finally:
|
| 314 |
-
# Crucial for ZeroGPU: ensure model returns to CPU to free the GPU
|
| 315 |
if tts_model is not None:
|
| 316 |
tts_model.to("cpu")
|
| 317 |
|
|
|
|
| 8 |
import textwrap
|
| 9 |
import requests
|
| 10 |
import atexit
|
| 11 |
+
import tempfile # <-- FIX: Import tempfile to manage temporary audio files
|
| 12 |
from typing import List, Dict, Tuple, Generator
|
| 13 |
|
| 14 |
# --- Fast, safe defaults ---
|
|
|
|
| 39 |
import numpy as np
|
| 40 |
from huggingface_hub import HfApi, hf_hub_download
|
| 41 |
from llama_cpp import Llama
|
| 42 |
+
import torchaudio
|
| 43 |
+
import soundfile as sf
|
| 44 |
|
| 45 |
# --- TTS Libraries ---
|
| 46 |
from TTS.tts.configs.xtts_config import XttsConfig
|
|
|
|
| 58 |
# 2) GLOBALS & HELPERS
|
| 59 |
# ===================================================================================
|
| 60 |
|
|
|
|
| 61 |
nltk.download("punkt", quiet=True)
|
| 62 |
|
|
|
|
| 63 |
tts_model: Xtts | None = None
|
| 64 |
llm_model: Llama | None = None
|
| 65 |
voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 66 |
|
|
|
|
| 67 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 68 |
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
|
| 69 |
repo_id = "ruslanmv/ai-story-server"
|
|
|
|
| 71 |
SENTENCE_SPLIT_LENGTH = 250
|
| 72 |
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
|
| 73 |
|
|
|
|
| 74 |
default_system_message = (
|
| 75 |
"You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
|
| 76 |
"Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
|
|
|
|
| 83 |
"Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
|
| 84 |
)
|
| 85 |
|
|
|
|
| 86 |
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
|
| 87 |
if pcm_data.startswith(b"RIFF"):
|
| 88 |
return pcm_data
|
|
|
|
| 120 |
# ===================================================================================
|
| 121 |
|
| 122 |
def precache_assets() -> None:
|
|
|
|
| 123 |
print("Pre-caching voice files...")
|
| 124 |
file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
|
| 125 |
base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
|
|
|
|
| 149 |
print(f"Warning: GGUF pre-cache error: {e}")
|
| 150 |
|
| 151 |
def _load_xtts(device: str) -> Xtts:
|
|
|
|
| 152 |
print("Loading Coqui XTTS V2 model (CPU first)...")
|
| 153 |
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
|
| 154 |
model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
|
|
|
|
| 164 |
return model
|
| 165 |
|
| 166 |
def _load_llama() -> Llama:
|
|
|
|
| 167 |
print("Loading LLM (Zephyr GGUF) on CPU...")
|
| 168 |
zephyr_model_path = hf_hub_download(
|
| 169 |
repo_id="TheBloke/zephyr-7B-beta-GGUF",
|
|
|
|
| 176 |
print("LLM loaded (CPU).")
|
| 177 |
return llm
|
| 178 |
|
| 179 |
+
def load_and_resample_audio(path: str, target_sr: int = 24000) -> torch.Tensor:
|
| 180 |
+
"""Loads audio, converts to a Torch tensor, and resamples if needed."""
|
|
|
|
| 181 |
try:
|
|
|
|
| 182 |
audio_np, original_sr = sf.read(path, dtype='float32')
|
|
|
|
|
|
|
| 183 |
if audio_np.ndim > 1:
|
| 184 |
audio_np = np.mean(audio_np, axis=1)
|
|
|
|
|
|
|
| 185 |
waveform = torch.from_numpy(audio_np).float()
|
| 186 |
|
|
|
|
| 187 |
if original_sr != target_sr:
|
| 188 |
print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
|
| 189 |
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
|
| 190 |
waveform = resampler(waveform)
|
| 191 |
+
|
| 192 |
+
return waveform.unsqueeze(0)
|
| 193 |
except Exception as e:
|
| 194 |
print(f"Error loading audio file {path}: {e}")
|
| 195 |
raise
|
| 196 |
|
| 197 |
def init_models_and_latents() -> None:
|
| 198 |
+
"""Preload models and compute voice latents, using temporary files for compatibility."""
|
| 199 |
global tts_model, llm_model, voice_latents
|
| 200 |
|
| 201 |
if tts_model is None:
|
|
|
|
| 206 |
|
| 207 |
if not voice_latents:
|
| 208 |
print("Computing voice conditioning latents...")
|
| 209 |
+
# --- FIX: Use a temporary directory to store resampled audio files ---
|
| 210 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 211 |
+
voice_files = {
|
| 212 |
+
"Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav",
|
| 213 |
+
"Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav",
|
| 214 |
+
}
|
| 215 |
+
for role, filename in voice_files.items():
|
| 216 |
+
original_path = os.path.join("voices", filename)
|
| 217 |
+
|
| 218 |
+
# 1. Load and resample audio into a tensor
|
| 219 |
+
resampled_waveform = load_and_resample_audio(original_path)
|
| 220 |
+
|
| 221 |
+
# 2. Save the corrected tensor to a temporary file
|
| 222 |
+
temp_path = os.path.join(temp_dir, f"resampled_{filename}")
|
| 223 |
+
torchaudio.save(temp_path, resampled_waveform.squeeze(0), 24000)
|
| 224 |
+
|
| 225 |
+
# 3. Pass the path of the clean, temporary file to the model
|
| 226 |
+
voice_latents[role] = tts_model.get_conditioning_latents(
|
| 227 |
+
audio_path=temp_path,
|
| 228 |
+
gpt_cond_len=30,
|
| 229 |
+
max_ref_length=60
|
| 230 |
+
)
|
| 231 |
print("Voice latents ready.")
|
| 232 |
|
| 233 |
def _close_llm():
|
|
|
|
| 267 |
if not input_text:
|
| 268 |
return []
|
| 269 |
|
|
|
|
| 270 |
if tts_model is None or llm_model is None:
|
| 271 |
raise gr.Error("Models not initialized. Please restart the Space.")
|
| 272 |
|
|
|
|
| 307 |
return results
|
| 308 |
|
| 309 |
finally:
|
|
|
|
| 310 |
if tts_model is not None:
|
| 311 |
tts_model.to("cpu")
|
| 312 |
|