# ===================================================================================
# 1) SETUP & IMPORTS
# ===================================================================================
from __future__ import annotations
import os
import sys
import base64
import struct
import textwrap
import requests
import atexit
from typing import List, Dict, Tuple, Generator
# --- Fast, safe defaults ---
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("COQUI_TOS_AGREED", "1")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")
os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # prevent torchaudio/ffmpeg (torio) path
# --- .env early (HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv
load_dotenv()
# --- NumPy sanity with torch 2.2.x ---
import numpy as _np
if int(_np.__version__.split(".", 1)[0]) >= 2:
raise RuntimeError(
f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space."
)
# --- Transformers sanity for TTS streaming ---
import transformers as _transformers
if _transformers.__version__ != "4.36.2":
raise RuntimeError(
f"Detected transformers=={_transformers.__version__}. "
"Please pin transformers==4.36.2 for compatibility with Coqui TTS streaming."
)
# --- Pandas shim for Gradio on pandas<2.2 (no 'future.no_silent_downcasting' option) ---
try:
import pandas as pd
try:
with pd.option_context("future.no_silent_downcasting", True):
pass
except Exception:
from contextlib import contextmanager
_orig_option_context = pd.option_context
@contextmanager
def _patched_option_context(*args, **kwargs):
# filter out unsupported option pairs
filtered = []
i = 0
while i < len(args):
key = args[i]
val = args[i + 1] if i + 1 < len(args) else None
if key == "future.no_silent_downcasting":
i += 2
continue
filtered.extend([key, val])
i += 2
with _orig_option_context(*filtered, **kwargs):
yield
pd.option_context = _patched_option_context # type: ignore[attr-defined]
except Exception:
pd = None # noqa: N816
# --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) ---
try:
import spaces # Required for ZeroGPU on HF
except Exception:
class _SpacesShim:
def GPU(self, *args, **kwargs):
def _wrap(fn): return fn
return _wrap
spaces = _SpacesShim()
import gradio as gr
# --- Core ML & Data Libraries (after spaces import) ---
import torch
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama
# --- Audio decode via ffmpeg-python (no torchaudio.load) ---
import ffmpeg
# --- TTS Libraries ---
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
import TTS.tts.models.xtts as xtts_module # for monkey-patching load_audio
# --- Text & Audio Processing ---
import nltk
import langid
import emoji
import noisereduce as nr
# ===================================================================================
# 2) GLOBALS & HELPERS
# ===================================================================================
# Ensure NLTK resources (both 'punkt' and new 'punkt_tab' on newer NLTK)
def _ensure_nltk():
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt", quiet=True)
try:
nltk.data.find("tokenizers/punkt_tab/english")
except LookupError:
try:
nltk.download("punkt_tab", quiet=True)
except Exception:
# fallback: downloading 'punkt' already satisfies older versions
pass
_ensure_nltk()
# Cached models & latents
tts_model: Xtts | None = None
llm_model: Llama | None = None
# store as torch.Tensors (CPU at startup)
voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
# Config
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
repo_id = "ruslanmv/ai-story-server"
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
# IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it.
USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true"
# Roles & prompts
default_system_message = (
"You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
"Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
)
system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)
ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
ROLE_PROMPTS = {role: system_message for role in ROLES}
ROLE_PROMPTS["Pirate"] = (
"You are AI Beard, a pirate. Craft your response from his first-person perspective. "
"Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)
# ---------- small utils ----------
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
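    """Wrap raw 16-bit PCM bytes in a minimal 44-byte RIFF/WAVE header; pass WAV data through unchanged."""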
if pcm_data.startswith(b"RIFF"):
return pcm_data
byte_rate = sample_rate * channels * bit_depth // 8
block_align = channels * bit_depth // 8
chunk_size = 36 + len(pcm_data)
header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", chunk_size, b"WAVE", b"fmt ",
16, 1, channels, sample_rate, byte_rate, block_align, bit_depth,
b"data", len(pcm_data)
)
return header + pcm_data
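# Quick sanity check for pcm_to_wav (hedged sketch; `io` and `wave` are standard
# library, and the one-second silent buffer is illustrative only):
#
#   import io, wave
#   silent = b"\x00\x00" * 24000  # 1 s of s16 mono silence at 24 kHz
#   with wave.open(io.BytesIO(pcm_to_wav(silent)), "rb") as w:
#       assert (w.getframerate(), w.getnchannels(), w.getsampwidth()) == (24000, 1, 2)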
def split_sentences(text: str, max_len: int) -> List[str]:
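    """Sentence-tokenize text with NLTK, hard-wrapping any sentence longer than max_len characters."""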
sentences = nltk.sent_tokenize(text)
out: List[str] = []
for s in sentences:
if len(s) > max_len:
out.extend(textwrap.wrap(s, max_len, break_long_words=True))
else:
out.append(s)
return out
def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
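    """Build a Zephyr-format chat prompt from the system message, prior turns, and the new user message."""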
prompt = f"<|system|>\n{system_message}</s>"
for user_prompt, bot_response in history:
if bot_response:
prompt += f"<|user|>\n{user_prompt}</s><|assistant|>\n{bot_response}</s>"
prompt += f"<|user|>\n{message}</s><|assistant|>"
return prompt
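# With empty history, format_prompt_zephyr produces a prompt of the form:
#   <|system|>\n{system_message}</s><|user|>\n{message}</s><|assistant|>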
# ---------- robust audio decode (mono via ffmpeg) ----------
def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
"""
Return float32 waveform in [-1, 1], mono, resampled to target_sr.
Shape: (samples,)
"""
try:
out, _ = (
ffmpeg
.input(path)
.output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=target_sr)
.run(capture_stdout=True, capture_stderr=True, cmd="ffmpeg")
)
pcm = np.frombuffer(out, dtype=np.int16)
if pcm.size == 0:
raise RuntimeError("ffmpeg produced empty audio.")
        # int16 spans [-32768, 32767]; clip after scaling to honor the documented [-1, 1] range
        return np.clip(pcm.astype(np.float32) / 32767.0, -1.0, 1.0)
except ffmpeg.Error as e:
raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e
# ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ----------
def _patched_load_audio(audiopath: str, load_sr: int):
"""
Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr.
"""
wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
audio = torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU
return audio
xtts_module.load_audio = _patched_load_audio
try:
import TTS.utils.audio as _tts_audio_mod
_tts_audio_mod.load_audio = _patched_load_audio
except Exception:
pass
def _coqui_cache_dir() -> str:
# Coqui cache default on Linux
return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
# ===================================================================================
# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
# ===================================================================================
def precache_assets() -> None:
"""Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
print("Pre-caching voice files...")
file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
os.makedirs("voices", exist_ok=True)
for name in file_names:
dst = os.path.join("voices", name)
if not os.path.exists(dst):
try:
resp = requests.get(base_url + name, timeout=30)
resp.raise_for_status()
with open(dst, "wb") as f:
f.write(resp.content)
except Exception as e:
print(f"Failed to download {name}: {e}")
print("Pre-caching XTTS v2 model files...")
ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")
print("Pre-caching Zephyr GGUF...")
try:
hf_hub_download(
repo_id="TheBloke/zephyr-7B-beta-GGUF",
filename="zephyr-7b-beta.Q5_K_M.gguf",
force_download=False
)
except Exception as e:
print(f"Warning: GGUF pre-cache error: {e}")
def _load_xtts(device: str) -> Xtts:
"""Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility."""
print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
ModelManager().download_model(model_name) # idempotent
model_dir = os.path.join(_coqui_cache_dir(), model_name.replace("/", "--"))
cfg = XttsConfig()
cfg.load_json(os.path.join(model_dir, "config.json"))
model = Xtts.init_from_config(cfg)
model.load_checkpoint(
cfg,
checkpoint_dir=model_dir,
eval=True,
use_deepspeed=False,
)
model.to(device)
print("XTTS model loaded.")
return model
def _load_llama_cpu_only() -> Llama:
"""Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly)."""
print("Loading LLM (Zephyr GGUF) on CPU...")
zephyr_model_path = hf_hub_download(
repo_id="TheBloke/zephyr-7B-beta-GGUF",
filename="zephyr-7b-beta.Q5_K_M.gguf"
)
llm = Llama(
model_path=zephyr_model_path,
n_gpu_layers=0, # never touch CUDA at startup
n_ctx=4096,
n_batch=512,
verbose=False
)
print("LLM loaded (CPU).")
return llm
def init_models_and_latents() -> None:
"""
Preload TTS and LLM on CPU and compute voice latents on CPU.
This avoids any CUDA tensors outside the @spaces.GPU window.
"""
global tts_model, llm_model, voice_latents
target_device = "cpu" # FORCE CPU at startup for ZeroGPU compatibility
if tts_model is None:
tts_model = _load_xtts(device=target_device)
else:
tts_model.to("cpu")
if llm_model is None:
llm_model = _load_llama_cpu_only()
# Pre-compute latents once on CPU (uses our ffmpeg loader)
if not voice_latents:
print("Computing voice conditioning latents (CPU)...")
with torch.no_grad():
for role, filename in [
("Cloée", "cloee-1.wav"),
("Julian", "julian-bedtime-style-1.wav"),
("Pirate", "pirate_by_coqui.wav"),
("Thera", "thera-1.wav"),
]:
path = os.path.join("voices", filename)
voice_latents[role] = tts_model.get_conditioning_latents(
audio_path=path, gpt_cond_len=30, max_ref_length=60
)
print("Voice latents ready (CPU).")
# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
def _close_llm():
global llm_model
try:
if llm_model is not None:
llm_model.close()
except Exception:
pass
atexit.register(_close_llm)
# ===================================================================================
# 4) INFERENCE HELPERS
# ===================================================================================
def generate_text_stream(llm_instance: Llama, prompt: str,
history: List[Tuple[str, str | None]],
system_message_text: str) -> Generator[str, None, None]:
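    """Stream completion text from the Zephyr GGUF model, skipping role markers and lone emoji."""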
formatted = format_prompt_zephyr(prompt, history, system_message_text)
stream = llm_instance(
formatted,
temperature=0.7,
max_tokens=512,
top_p=0.95,
stop=LLM_STOP_WORDS,
stream=True
)
for resp in stream:
ch = resp["choices"][0]["text"]
try:
is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
except Exception:
is_single_emoji = False
if "<|user|>" in ch or is_single_emoji:
continue
yield ch
def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
g, s = latents
if isinstance(g, torch.Tensor):
g = g.to(device)
if isinstance(s, torch.Tensor):
s = s.to(device)
return g, s
def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]:
gpt_cond_latent, speaker_embedding = latents
try:
for chunk in tts_instance.inference_stream(
text=text,
language=language,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=0.85,
):
if chunk is None:
continue
f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32)
f32 = np.clip(f32, -1.0, 1.0)
s16 = (f32 * 32767.0).astype(np.int16)
yield s16.tobytes()
except RuntimeError as e:
print(f"Error during TTS inference: {e}")
if "device-side assert" in str(e) and api:
try:
gr.Warning("Critical GPU error. Attempting to restart the Space...")
api.restart_space(repo_id=repo_id)
except Exception:
pass
# ===================================================================================
# 5) ZERO-GPU ENTRYPOINT (also works on native GPU)
# ===================================================================================
@spaces.GPU(duration=120) # ZeroGPU allocates a GPU only for this function call
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
if secret_token_input != SECRET_TOKEN:
raise gr.Error("Invalid secret token provided.")
if not input_text:
return []
# Ensure models/latents exist (CPU)
if tts_model is None or llm_model is None or not voice_latents:
init_models_and_latents()
# If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU.
try:
if torch.cuda.is_available():
tts_model.to("cuda")
device = torch.device("cuda")
else:
tts_model.to("cpu")
device = torch.device("cpu")
except Exception:
tts_model.to("cpu")
device = torch.device("cpu")
# Generate story text (LLM on CPU)
history: List[Tuple[str, str | None]] = [(input_text, None)]
full_story_text = "".join(
generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
).strip()
if not full_story_text:
return []
# Split into TTS-friendly sentences
sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
lang = langid.classify(sentences[0])[0] if sentences else "en"
results: List[Dict[str, str]] = []
    # Move cached latents to the same device as the model once per call
    lat_dev = _latents_to_device(voice_latents[chatbot_role], device)
    for sentence in sentences:
        if not any(c.isalnum() for c in sentence):
            continue
        audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev)
pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)
# Optional noise reduction (best-effort, CPU)
try:
data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
if data_s16.size > 0:
float_data = (data_s16.astype(np.float32) / 32767.0)
reduced = nr.reduce_noise(y=float_data, sr=24000)
final_pcm = np.clip(reduced * 32767.0, -32768, 32767).astype(np.int16).tobytes()
else:
final_pcm = pcm_data
except Exception:
final_pcm = pcm_data
b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
results.append({"text": sentence, "audio": b64_wav})
# Return XTTS to CPU to release GPU instantly
try:
tts_model.to("cpu")
except Exception:
pass
return results
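# Consuming the result (hedged sketch): each entry pairs a sentence with a
# base64-encoded WAV, so a caller could persist the audio like this (the prompt
# and file names below are illustrative only):
#
#   import base64
#   for i, item in enumerate(generate_story_and_speech(SECRET_TOKEN, "a brave cat", "Cloée")):
#       with open(f"sentence_{i}.wav", "wb") as f:
#           f.write(base64.b64decode(item["audio"]))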
# ===================================================================================
# 6) STARTUP: PRECACHE & UI
# ===================================================================================
def build_ui() -> gr.Interface:
return gr.Interface(
fn=generate_story_and_speech,
inputs=[
gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
],
outputs=gr.JSON(label="Story and Audio Output"),
title="AI Storyteller with ZeroGPU",
description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
flagging_mode="never", # avoid deprecated allow_flagging path
analytics_enabled=False,
)
if __name__ == "__main__":
print("===== Startup: pre-cache assets and preload models =====")
print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)")
precache_assets() # 1) download everything to disk
init_models_and_latents() # 2) load on CPU + compute voice latents on CPU
print("Models and assets ready. Launching UI...")
demo = build_ui()
demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
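# Remote usage sketch (assumptions: the `gradio_client` package is installed, the
# Space URL is reachable, and "/predict" is the default api_name for this gr.Interface):
#
#   from gradio_client import Client
#   client = Client("https://<your-space>.hf.space")
#   story = client.predict("secret", "a story about the sea", "Pirate", api_name="/predict")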