# =================================================================================== # 1) SETUP & IMPORTS # =================================================================================== from __future__ import annotations import os import sys import base64 import struct import textwrap import requests import atexit from typing import List, Dict, Tuple, Generator # --- Fast, safe defaults --- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("COQUI_TOS_AGREED", "1") os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false") os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0") # prevent torchaudio/ffmpeg (torio) path # --- .env early (HF_TOKEN / SECRET_TOKEN) --- from dotenv import load_dotenv load_dotenv() # --- NumPy sanity with torch 2.2.x --- import numpy as _np if int(_np.__version__.split(".", 1)[0]) >= 2: raise RuntimeError( f"Detected numpy=={_np.__version__}. Please pin numpy<2 (e.g., 1.26.4) for this Space." ) # --- Hugging Face Spaces & ZeroGPU (import BEFORE torch/diffusers) --- try: import spaces # Required for ZeroGPU on HF except Exception: class _SpacesShim: def GPU(self, *args, **kwargs): def _wrap(fn): return fn return _wrap spaces = _SpacesShim() import gradio as gr # --- Core ML & Data Libraries (after spaces import) --- import torch import numpy as np from huggingface_hub import HfApi, hf_hub_download from llama_cpp import Llama # --- Audio decode via ffmpeg-python (no torchaudio.load) --- import ffmpeg # --- TTS Libraries --- from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts from TTS.utils.manage import ModelManager import TTS.tts.models.xtts as xtts_module # for monkey-patching load_audio # --- Text & Audio Processing --- import nltk import langid import emoji import noisereduce as nr # =================================================================================== # 2) GLOBALS & HELPERS # =================================================================================== # NLTK data nltk.download("punkt", quiet=True) # Cached models & latents tts_model: Xtts | None = None llm_model: Llama | None = None voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} # Config HF_TOKEN = os.environ.get("HF_TOKEN") api = HfApi(token=HF_TOKEN) if HF_TOKEN else None repo_id = "ruslanmv/ai-story-server" SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret") SENTENCE_SPLIT_LENGTH = 250 LLM_STOP_WORDS = ["", "<|user|>", "/s>"] # IMPORTANT: With ZeroGPU, DO NOT use CUDA at startup even if torch sees it. USE_STARTUP_CUDA = os.getenv("USE_STARTUP_CUDA", "false").lower() == "true" # System prompts and roles default_system_message = ( "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. " "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')." ) system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message) ROLES = ["Cloée", "Julian", "Pirate", "Thera"] ROLE_PROMPTS = {role: system_message for role in ROLES} ROLE_PROMPTS["Pirate"] = ( "You are AI Beard, a pirate. Craft your response from his first-person perspective. " "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak." ) # ---------- small utils ---------- def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes: if pcm_data.startswith(b"RIFF"): return pcm_data byte_rate = sample_rate * channels * bit_depth // 8 block_align = channels * bit_depth // 8 chunk_size = 36 + len(pcm_data) header = struct.pack( "<4sI4s4sIHHIIHH4sI", b"RIFF", chunk_size, b"WAVE", b"fmt ", 16, 1, channels, sample_rate, byte_rate, block_align, bit_depth, b"data", len(pcm_data) ) return header + pcm_data def split_sentences(text: str, max_len: int) -> List[str]: sentences = nltk.sent_tokenize(text) out: List[str] = [] for s in sentences: if len(s) > max_len: out.extend(textwrap.wrap(s, max_len, break_long_words=True)) else: out.append(s) return out def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str: prompt = f"<|system|>\n{system_message}" for user_prompt, bot_response in history: if bot_response: prompt += f"<|user|>\n{user_prompt}<|assistant|>\n{bot_response}" prompt += f"<|user|>\n{message}<|assistant|>" return prompt # ---------- robust audio decode (mono via ffmpeg) ---------- def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray: """ Return float32 waveform in [-1, 1], mono, resampled to target_sr. Shape: (samples,) """ try: out, _ = ( ffmpeg .input(path) .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=target_sr) .run(capture_stdout=True, capture_stderr=True, cmd="ffmpeg") ) pcm = np.frombuffer(out, dtype=np.int16) if pcm.size == 0: raise RuntimeError("ffmpeg produced empty audio.") return (pcm.astype(np.float32) / 32767.0) except ffmpeg.Error as e: raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e # ---------- monkey-patch XTTS internal loader to avoid torchaudio.load() ---------- def _patched_load_audio(audiopath: str, load_sr: int): """ Expected by XTTS: return torch.FloatTensor [1, samples] normalized to [-1, 1], resampled to load_sr. """ wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr) audio = torch.from_numpy(wav).float().unsqueeze(0) # [1, N] on CPU return audio xtts_module.load_audio = _patched_load_audio try: import TTS.utils.audio as _tts_audio_mod _tts_audio_mod.load_audio = _patched_load_audio except Exception: pass def _coqui_cache_dir() -> str: # Coqui cache default on Linux return os.path.join(os.path.expanduser("~"), ".local", "share", "tts") # =================================================================================== # 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE) # =================================================================================== def precache_assets() -> None: """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference.""" print("Pre-caching voice files...") file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"] base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/" os.makedirs("voices", exist_ok=True) for name in file_names: dst = os.path.join("voices", name) if not os.path.exists(dst): try: resp = requests.get(base_url + name, timeout=30) resp.raise_for_status() with open(dst, "wb") as f: f.write(resp.content) except Exception as e: print(f"Failed to download {name}: {e}") print("Pre-caching XTTS v2 model files...") ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2") print("Pre-caching Zephyr GGUF...") try: hf_hub_download( repo_id="TheBloke/zephyr-7B-beta-GGUF", filename="zephyr-7b-beta.Q5_K_M.gguf", force_download=False ) except Exception as e: print(f"Warning: GGUF pre-cache error: {e}") def _load_xtts(device: str) -> Xtts: """Load XTTS from the local cache. Always CPU at startup for ZeroGPU compatibility.""" print(f"Loading Coqui XTTS V2 model on {device.upper()}...") model_name = "tts_models/multilingual/multi-dataset/xtts_v2" ModelManager().download_model(model_name) # idempotent model_dir = os.path.join(_coqui_cache_dir(), model_name.replace("/", "--")) cfg = XttsConfig() cfg.load_json(os.path.join(model_dir, "config.json")) model = Xtts.init_from_config(cfg) model.load_checkpoint( cfg, checkpoint_dir=model_dir, eval=True, use_deepspeed=False, ) model.to(device) print("XTTS model loaded.") return model def _load_llama_cpu_only() -> Llama: """Load Llama (Zephyr GGUF) on CPU only (ZeroGPU friendly).""" print("Loading LLM (Zephyr GGUF) on CPU...") zephyr_model_path = hf_hub_download( repo_id="TheBloke/zephyr-7B-beta-GGUF", filename="zephyr-7b-beta.Q5_K_M.gguf" ) llm = Llama( model_path=zephyr_model_path, n_gpu_layers=0, # never touch CUDA at startup n_ctx=4096, n_batch=512, verbose=False ) print("LLM loaded (CPU).") return llm def init_models_and_latents() -> None: """ Preload TTS and LLM on CPU and compute voice latents on CPU. This avoids any CUDA tensors outside the @spaces.GPU window. """ global tts_model, llm_model, voice_latents # Always CPU here (ZeroGPU rule) target_device = "cpu" if tts_model is None: tts_model = _load_xtts(device=target_device) else: tts_model.to("cpu") if llm_model is None: llm_model = _load_llama_cpu_only() # Pre-compute latents once on CPU (uses our ffmpeg loader) if not voice_latents: print("Computing voice conditioning latents (CPU)...") with torch.no_grad(): for role, filename in [ ("Cloée", "cloee-1.wav"), ("Julian", "julian-bedtime-style-1.wav"), ("Pirate", "pirate_by_coqui.wav"), ("Thera", "thera-1.wav"), ]: path = os.path.join("voices", filename) # Returns torch tensors; keep them on CPU voice_latents[role] = tts_model.get_conditioning_latents( audio_path=path, gpt_cond_len=30, max_ref_length=60 ) print("Voice latents ready (CPU).") # Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown def _close_llm(): global llm_model try: if llm_model is not None: llm_model.close() except Exception: pass atexit.register(_close_llm) # =================================================================================== # 4) INFERENCE HELPERS # =================================================================================== def generate_text_stream(llm_instance: Llama, prompt: str, history: List[Tuple[str, str | None]], system_message_text: str) -> Generator[str, None, None]: formatted = format_prompt_zephyr(prompt, history, system_message_text) stream = llm_instance( formatted, temperature=0.7, max_tokens=512, top_p=0.95, stop=LLM_STOP_WORDS, stream=True ) for resp in stream: ch = resp["choices"][0]["text"] try: is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch)) except Exception: is_single_emoji = False if "<|user|>" in ch or is_single_emoji: continue yield ch def _latents_to_device(latents: Tuple[torch.Tensor, torch.Tensor], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: g, s = latents if isinstance(g, torch.Tensor): g = g.to(device) if isinstance(s, torch.Tensor): s = s.to(device) return g, s def generate_audio_stream(tts_instance: Xtts, text: str, language: str, latents: Tuple[torch.Tensor, torch.Tensor]) -> Generator[bytes, None, None]: gpt_cond_latent, speaker_embedding = latents try: for chunk in tts_instance.inference_stream( text=text, language=language, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, temperature=0.85, ): if chunk is None: continue f32 = chunk.detach().cpu().numpy().squeeze().astype(np.float32) f32 = np.clip(f32, -1.0, 1.0) s16 = (f32 * 32767.0).astype(np.int16) yield s16.tobytes() except RuntimeError as e: print(f"Error during TTS inference: {e}") if "device-side assert" in str(e) and api: try: gr.Warning("Critical GPU error. Attempting to restart the Space...") api.restart_space(repo_id=repo_id) except Exception: pass # =================================================================================== # 5) ZERO-GPU ENTRYPOINT (also works on native GPU) # =================================================================================== @spaces.GPU(duration=120) # ZeroGPU allocates a GPU only for this function call def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]: if secret_token_input != SECRET_TOKEN: raise gr.Error("Invalid secret token provided.") if not input_text: return [] # Ensure models/latents exist (CPU) if tts_model is None or llm_model is None or not voice_latents: init_models_and_latents() # If ZeroGPU granted CUDA for this call, move XTTS to CUDA; keep LLM on CPU. try: if torch.cuda.is_available(): tts_model.to("cuda") device = torch.device("cuda") else: tts_model.to("cpu") device = torch.device("cpu") except Exception: tts_model.to("cpu") device = torch.device("cpu") # Generate story text (LLM on CPU) history: List[Tuple[str, str | None]] = [(input_text, None)] full_story_text = "".join( generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role]) ).strip() if not full_story_text: return [] # Split into TTS-friendly sentences sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH) lang = langid.classify(sentences[0])[0] if sentences else "en" results: List[Dict[str, str]] = [] for sentence in sentences: if not any(c.isalnum() for c in sentence): continue # Move cached latents to the same device as the model for this call lat_dev = _latents_to_device(voice_latents[chatbot_role], device) audio_chunks = generate_audio_stream(tts_model, sentence, lang, lat_dev) pcm_data = b"".join(chunk for chunk in audio_chunks if chunk) # Optional noise reduction (best-effort, CPU) try: data_s16 = np.frombuffer(pcm_data, dtype=np.int16) if data_s16.size > 0: float_data = (data_s16.astype(np.float32) / 32767.0) reduced = nr.reduce_noise(y=float_data, sr=24000) final_pcm = np.clip(reduced * 32767.0, -32768, 32767).astype(np.int16).tobytes() else: final_pcm = pcm_data except Exception: final_pcm = pcm_data b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8") results.append({"text": sentence, "audio": b64_wav}) # Return XTTS to CPU to release GPU instantly try: tts_model.to("cpu") except Exception: pass return results # =================================================================================== # 6) STARTUP: PRECACHE & UI # =================================================================================== def build_ui() -> gr.Interface: return gr.Interface( fn=generate_story_and_speech, inputs=[ gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN), gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"), gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"), ], outputs=gr.JSON(label="Story and Audio Output"), title="AI Storyteller with ZeroGPU", description="Enter a prompt to generate a short story with voice narration using on-demand GPU.", allow_flagging="never", analytics_enabled=False, ) if __name__ == "__main__": print("===== Startup: pre-cache assets and preload models =====") print(f"Python: {sys.version.split()[0]} | Torch CUDA visible: {torch.cuda.is_available()} (will not use at startup)") precache_assets() # 1) download everything to disk init_models_and_latents() # 2) load on CPU + compute voice latents on CPU print("Models and assets ready. Launching UI...") demo = build_ui() demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))