# ===================================================================================
# 1) SETUP & IMPORTS
# ===================================================================================
from __future__ import annotations

import os
import base64
import struct
import textwrap
import requests
import atexit
from typing import List, Dict, Tuple, Generator

# --- Fast, safe defaults ---
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("COQUI_TOS_AGREED", "1")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")

# --- Load .env early (HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv

load_dotenv()

# --- Hugging Face Spaces & ZeroGPU ---
try:
    import spaces
except ImportError:
    # Local fallback: a no-op decorator shim so the script also runs outside Spaces.
    class _SpacesShim:
        def GPU(self, *args, **kwargs):
            def _wrap(fn):
                return fn
            return _wrap

    spaces = _SpacesShim()

import gradio as gr

# --- Core ML & Data Libraries ---
import torch
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama
import torchaudio  # Still needed for transforms, just not loading
import soundfile as sf  # <-- FIX: Import soundfile for robust audio loading

# --- TTS Libraries ---
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir

# --- Text & Audio Processing ---
import nltk
import langid
import emoji
import noisereduce as nr

# ===================================================================================
# 2) GLOBALS & HELPERS
# ===================================================================================
# Download NLTK sentence tokenizer data once
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)  # newer NLTK releases look for punkt_tab

# Cached models & latents
tts_model: Xtts | None = None
llm_model: Llama | None = None
voice_latents: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

# Config
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
repo_id = "ruslanmv/ai-story-server"
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]  # "</s>" is Zephyr's end-of-turn token

# System prompts and roles
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
)
system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)

ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
ROLE_PROMPTS = {role: system_message for role in ROLES}
ROLE_PROMPTS["Pirate"] = (
    "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)
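
# NOTE (illustrative only): format_prompt_zephyr() below assembles the Zephyr-7B
# chat template. For a single prompt with no prior history it looks roughly like:
#
#   <|system|>
#   {system_message}</s>
#   <|user|>
#   {message}</s>
#   <|assistant|>
#
# Earlier (user, assistant) pairs from `history` are interleaved in the same way.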

# ---------- small utils ----------
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    """Wrap raw 16-bit PCM bytes in a minimal RIFF/WAVE header."""
    if pcm_data.startswith(b"RIFF"):  # already a WAV container
        return pcm_data
    chunk_size = 36 + len(pcm_data)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", chunk_size, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate,
        sample_rate * channels * bit_depth // 8,  # byte rate
        channels * bit_depth // 8,                # block align
        bit_depth,
        b"data", len(pcm_data),
    )
    return header + pcm_data


def split_sentences(text: str, max_len: int) -> List[str]:
    sentences = nltk.sent_tokenize(text)
    chunks: List[str] = []
    for sent in sentences:
        if len(sent) > max_len:
            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
        else:
            chunks.append(sent)
    return chunks


def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    # Zephyr chat template: each turn is closed with the </s> end-of-turn token.
    prompt = f"<|system|>\n{system_message}</s>\n"
    for user_prompt, bot_response in history:
        if bot_response:
            prompt += f"<|user|>\n{user_prompt}</s>\n<|assistant|>\n{bot_response}</s>\n"
    prompt += f"<|user|>\n{message}</s>\n<|assistant|>\n"
    return prompt


# ===================================================================================
# 3) PRECACHE & MODEL LOADERS (RUN BEFORE FIRST INFERENCE)
# ===================================================================================
def precache_assets() -> None:
    """Download voice WAVs, XTTS weights, and Zephyr GGUF to local cache before any inference."""
    print("Pre-caching voice files...")
    file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
    base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
    os.makedirs("voices", exist_ok=True)
    for name in file_names:
        dst = os.path.join("voices", name)
        if not os.path.exists(dst):
            try:
                resp = requests.get(base_url + name, timeout=30)
                resp.raise_for_status()
                with open(dst, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                print(f"Failed to download {name}: {e}")

    print("Pre-caching XTTS v2 model files...")
    ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")

    print("Pre-caching Zephyr GGUF...")
    try:
        hf_hub_download(
            repo_id="TheBloke/zephyr-7B-beta-GGUF",
            filename="zephyr-7b-beta.Q5_K_M.gguf",
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        print(f"Warning: GGUF pre-cache error: {e}")


def _load_xtts(device: str) -> Xtts:
    """Load XTTS from the local cache."""
    print("Loading Coqui XTTS V2 model (CPU first)...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    model_dir = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    if not os.path.exists(model_dir):
        ModelManager().download_model(model_name)
    cfg = XttsConfig()
    cfg.load_json(os.path.join(model_dir, "config.json"))
    model = Xtts.init_from_config(cfg)
    model.load_checkpoint(cfg, checkpoint_dir=model_dir, eval=True, use_deepspeed=False)
    model.to(device)
    print("XTTS model loaded.")
    return model


def _load_llama() -> Llama:
    """Load Llama (Zephyr GGUF) on CPU so it's ready immediately."""
    print("Loading LLM (Zephyr GGUF) on CPU...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf",
    )
    llm = Llama(
        model_path=zephyr_model_path,
        n_gpu_layers=0,
        n_ctx=4096,
        n_batch=512,
        verbose=False,
    )
    print("LLM loaded (CPU).")
    return llm


# --- FIX: Replaced torchaudio.load with soundfile.read to fix RuntimeError ---
def load_audio_for_tts(path: str, target_sr: int = 24000) -> torch.Tensor:
    """Loads audio using soundfile, converts to a Torch tensor, and resamples if needed."""
    try:
        # Read audio file into a NumPy array
        audio_np, original_sr = sf.read(path, dtype="float32")
        # Ensure it's mono
        if audio_np.ndim > 1:
            audio_np = np.mean(audio_np, axis=1)
        # Convert to a PyTorch tensor
        waveform = torch.from_numpy(audio_np).float()
        # Resample if the sample rate is not the target rate
        if original_sr != target_sr:
            print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.")
            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
            waveform = resampler(waveform)
        return waveform.unsqueeze(0)  # Add batch dimension: shape (1, T)
    except Exception as e:
        print(f"Error loading audio file {path}: {e}")
        raise
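
# Usage sketch (assumes the reference clip voices/cloee-1.wav exists after
# precache_assets() has run; names are illustrative only):
#
#   waveform = load_audio_for_tts("voices/cloee-1.wav")    # torch.FloatTensor at 24 kHz
#   assert waveform.dim() == 2 and waveform.size(0) == 1   # mono, shape (1, T)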
needed.""" try: # Read audio file into a NumPy array audio_np, original_sr = sf.read(path, dtype='float32') # Ensure it's mono if audio_np.ndim > 1: audio_np = np.mean(audio_np, axis=1) # Convert to a PyTorch tensor waveform = torch.from_numpy(audio_np).float() # Resample if the sample rate is not the target rate if original_sr != target_sr: print(f"Resampling audio from {original_sr}Hz to {target_sr}Hz.") resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr) waveform = resampler(waveform) return waveform.unsqueeze(0) # Add batch dimension: shape (1, T) except Exception as e: print(f"Error loading audio file {path}: {e}") raise def init_models_and_latents() -> None: """Preload TTS and LLM on CPU and compute voice latents once.""" global tts_model, llm_model, voice_latents if tts_model is None: tts_model = _load_xtts(device="cpu") if llm_model is None: llm_model = _load_llama() if not voice_latents: print("Computing voice conditioning latents...") voice_files = { "Cloée": "cloee-1.wav", "Julian": "julian-bedtime-style-1.wav", "Pirate": "pirate_by_coqui.wav", "Thera": "thera-1.wav", } for role, filename in voice_files.items(): path = os.path.join("voices", filename) # Load audio externally and pass the waveform tensor directly waveform = load_audio_for_tts(path) voice_latents[role] = tts_model.get_conditioning_latents( waveform=waveform, gpt_cond_len=30, max_ref_length=60 ) print("Voice latents ready.") def _close_llm(): global llm_model if llm_model is not None: del llm_model atexit.register(_close_llm) # =================================================================================== # 4) INFERENCE HELPERS # =================================================================================== def generate_text_stream(llm_instance: Llama, prompt: str, history: List, sys_prompt: str) -> Generator[str, None, None]: formatted_prompt = format_prompt_zephyr(prompt, history, sys_prompt) stream = llm_instance( formatted_prompt, temperature=0.7, max_tokens=512, top_p=0.95, stop=LLM_STOP_WORDS, stream=True ) for response in stream: yield response["choices"][0]["text"] def generate_audio_stream(tts_instance: Xtts, text: str, lang: str, latents: Tuple) -> Generator[bytes, None, None]: gpt_cond_latent, speaker_embedding = latents for chunk in tts_instance.inference_stream( text, lang, gpt_cond_latent, speaker_embedding, temperature=0.85, ): if chunk is not None: yield chunk.detach().cpu().numpy().squeeze().tobytes() # =================================================================================== # 5) ZERO-GPU ENTRYPOINT # =================================================================================== @spaces.GPU(duration=120) def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]: if secret_token_input != SECRET_TOKEN: raise gr.Error("Invalid secret token provided.") if not input_text: return [] # Models must be preloaded, this is a fallback. if tts_model is None or llm_model is None: raise gr.Error("Models not initialized. 
Please restart the Space.") try: if torch.cuda.is_available(): tts_model.to("cuda") history: List[Tuple[str, str | None]] = [(input_text, None)] full_story_text = "".join( generate_text_stream(llm_model, history[-1][0], history[:-1], ROLE_PROMPTS[chatbot_role]) ).strip() if not full_story_text: return [] sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH) lang = langid.classify(sentences[0])[0] if sentences else "en" results: List[Dict[str, str]] = [] for sentence in sentences: if not any(c.isalnum() for c in sentence): continue audio_chunks = generate_audio_stream(tts_model, sentence, lang, voice_latents[chatbot_role]) pcm_data = b"".join(chunk for chunk in audio_chunks if chunk) data_s16 = np.frombuffer(pcm_data, dtype=np.int16) if data_s16.size > 0: float_data = data_s16.astype(np.float32) / 32767.0 reduced = nr.reduce_noise(y=float_data, sr=24000) final_pcm = (reduced * 32767).astype(np.int16).tobytes() else: final_pcm = pcm_data b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8") results.append({"text": sentence, "audio": b64_wav}) return results finally: # Crucial for ZeroGPU: ensure model returns to CPU to free the GPU if tts_model is not None: tts_model.to("cpu") # =================================================================================== # 6) STARTUP: PRECACHE & UI # =================================================================================== def build_ui() -> gr.Blocks: with gr.Blocks() as demo: gr.Markdown("# AI Storyteller with ZeroGPU") gr.Markdown("Enter a prompt to generate a short story with voice narration using on-demand GPU.") with gr.Row(): secret_token = gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN) storyteller = gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée") prompt = gr.Textbox(placeholder="What should the story be about?", label="Story Prompt") output = gr.JSON(label="Story and Audio Output") prompt.submit( fn=generate_story_and_speech, inputs=[secret_token, prompt, storyteller], outputs=output, ) return demo if __name__ == "__main__": print("===== Startup: pre-cache assets and preload models =====") precache_assets() init_models_and_latents() print("Models and assets ready. Launching UI...") demo = build_ui() demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))