# ===================================================================================
# 1. SETUP AND IMPORTS
# ===================================================================================
from __future__ import annotations

import os
import requests
import base64
import struct
import textwrap
from typing import List, Dict, Tuple, Generator

# Disable Gradio analytics before gradio is imported (so we don't need pandas 2.x).
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")

# --- Load .env early (for HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv

load_dotenv()

# --- Hugging Face Spaces & ZeroGPU ---
try:
    import spaces  # Required for ZeroGPU on HF
except Exception:
    # Allow local runs without the spaces package: a no-op stand-in for @spaces.GPU.
    class _SpacesShim:
        def GPU(self, *args, **kwargs):
            def _wrap(fn):
                return fn
            return _wrap

    spaces = _SpacesShim()

import gradio as gr

# --- Core ML & Data Libraries ---
import torch
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama

# --- TTS Libraries ---
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir

# --- Text & Audio Processing ---
import nltk
import langid
import emoji
import noisereduce as nr

# ===================================================================================
# 2. GLOBAL CONFIGURATION & HELPER FUNCTIONS
# ===================================================================================

# Download NLTK sentence-tokenizer data (punkt).
nltk.download("punkt", quiet=True)
os.environ["COQUI_TOS_AGREED"] = "1"

# Cached models (populated lazily by load_models()).
tts_model: Xtts | None = None
llm_model: Llama | None = None

# Configuration
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
repo_id = "ruslanmv/ai-story-server"
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]  # Zephyr EOS token plus chat-control artefacts

# System prompts and roles
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
)
system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)

ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
ROLE_PROMPTS = {role: system_message for role in ROLES}
ROLE_PROMPTS["Pirate"] = (
    "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)


# --- Audio helpers ---
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    """Wrap raw little-endian PCM data in a minimal 44-byte WAV header."""
    if pcm_data.startswith(b"RIFF"):  # Already a WAV container
        return pcm_data
    chunk_size = 36 + len(pcm_data)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", chunk_size, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate,
        sample_rate * channels * bit_depth // 8,  # byte rate
        channels * bit_depth // 8,                # block align
        bit_depth,
        b"data", len(pcm_data),
    )
    return header + pcm_data
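
# A lightweight self-check for pcm_to_wav (an editor's illustrative addition, not
# part of the original app): the wrapper should emit exactly a 44-byte RIFF header
# and leave the PCM payload untouched.
_probe = pcm_to_wav(b"\x00\x00" * 240)  # 10 ms of 16-bit mono silence at 24 kHz
assert _probe[:4] == b"RIFF" and _probe[8:12] == b"WAVE"
assert len(_probe) == 44 + 480  # header + payload
del _probe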

def split_sentences(text: str, max_len: int) -> List[str]:
    """Split text into sentences, wrapping any sentence longer than max_len."""
    sentences = nltk.sent_tokenize(text)
    chunks: List[str] = []
    for sent in sentences:
        if len(sent) > max_len:
            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
        else:
            chunks.append(sent)
    return chunks


def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    """Build a Zephyr-style chat prompt from the system message and turn history."""
    prompt = f"<|system|>\n{system_message}</s>\n"
    for user_prompt, bot_response in history:
        if bot_response:
            prompt += f"<|user|>\n{user_prompt}</s>\n<|assistant|>\n{bot_response}</s>\n"
    prompt += f"<|user|>\n{message}</s>\n<|assistant|>\n"
    return prompt
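
# Editor's illustration (not in the original app): with an empty history the
# prompt follows Zephyr-7B's chat template, i.e.
#   <|system|>\n{system_message}</s>\n<|user|>\n{message}</s>\n<|assistant|>\n
_example = format_prompt_zephyr("Tell me a story.", [], system_message)
assert _example.startswith("<|system|>\n") and _example.endswith("<|assistant|>\n")
del _example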

# ===================================================================================
# 3. CORE AI FUNCTIONS (Model Loading & Inference)
# ===================================================================================
def _load_xtts(device: str) -> Xtts:
    print("Loading Coqui XTTS V2 model (first run)...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))

    config = XttsConfig()
    config.load_json(os.path.join(model_path, "config.json"))
    model = Xtts.init_from_config(config)
    # NOTE: deepspeed is not installed; keep use_deepspeed=False for Spaces.
    model.load_checkpoint(
        config,
        checkpoint_path=os.path.join(model_path, "model.pth"),
        vocab_path=os.path.join(model_path, "vocab.json"),
        eval=True,
        use_deepspeed=False,
    )
    model.to(device)
    print("XTTS model loaded.")
    return model


def _load_llama() -> Llama:
    print("Loading LLM (Zephyr) (first run)...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf",
    )
    # Try full GPU offload first (-1 = all layers), then fall back to CPU (0).
    for n_gpu_layers in (-1, 0):
        try:
            llm = Llama(
                model_path=zephyr_model_path,
                n_gpu_layers=n_gpu_layers,
                n_ctx=4096,
                n_batch=512,
                verbose=False,
            )
            if n_gpu_layers == -1:
                print("LLM loaded with GPU offload.")
            else:
                print("LLM loaded (CPU).")
            return llm
        except Exception as e:
            print(f"LLM init with n_gpu_layers={n_gpu_layers} failed: {e}")
    raise RuntimeError("Failed to initialize Llama model.")


def load_models() -> Tuple[Xtts, Llama]:
    """Load both models on first use and cache them at module level."""
    global tts_model, llm_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if tts_model is None:
        tts_model = _load_xtts(device)
    if llm_model is None:
        llm_model = _load_llama()
    return tts_model, llm_model


def generate_text_stream(llm_instance: Llama, prompt: str, history: List[Tuple[str, str | None]],
                         system_message: str) -> Generator[str, None, None]:
    formatted_prompt = format_prompt_zephyr(prompt, history, system_message)
    stream = llm_instance(
        formatted_prompt,
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
        stop=LLM_STOP_WORDS,
        stream=True,
    )
    for response in stream:
        ch = response["choices"][0]["text"]
        # Guard against control tokens & isolated emoji artefacts.
        try:
            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))  # emoji>=2.x
        except Exception:
            is_single_emoji = False
        if "<|user|>" in ch or is_single_emoji:
            continue
        yield ch


def generate_audio_stream(tts_instance: Xtts, text: str, language: str,
                          latents: Tuple[np.ndarray, np.ndarray]) -> Generator[bytes, None, None]:
    gpt_cond_latent, speaker_embedding = latents
    try:
        for chunk in tts_instance.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=0.85,
        ):
            if chunk is not None:
                yield chunk.detach().cpu().numpy().squeeze().tobytes()
    except RuntimeError as e:
        print(f"Error during TTS inference: {e}")
        # Soft-restart if the GPU went bad and we can talk to the HF API.
        if "device-side assert" in str(e) and api:
            gr.Warning("Critical GPU error. Attempting to restart the Space...")
            try:
                api.restart_space(repo_id=repo_id)
            except Exception:
                pass


# ===================================================================================
# 4. MAIN GRADIO FUNCTION (Decorated for ZeroGPU)
# ===================================================================================
@spaces.GPU(duration=120)  # Request a GPU for up to 120 seconds per call
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
    if secret_token_input != SECRET_TOKEN:
        raise gr.Error("Invalid secret token provided.")
    if not input_text:
        return []

    # Load models (cached after the first call).
    tts, llm = load_models()

    # Pre-compute voice latents; only the selected role's entry is used below.
    latent_map: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
    for role, filename in [
        ("Cloée", "cloee-1.wav"),
        ("Julian", "julian-bedtime-style-1.wav"),
        ("Pirate", "pirate_by_coqui.wav"),
        ("Thera", "thera-1.wav"),
    ]:
        path = os.path.join("voices", filename)
        latent_map[role] = tts.get_conditioning_latents(
            audio_path=path, gpt_cond_len=30, max_ref_length=60
        )

    # Generate the story text.
    history: List[Tuple[str, str | None]] = [(input_text, None)]
    full_story_text = "".join(
        generate_text_stream(llm, history[-1][0], history[:-1], system_message=ROLE_PROMPTS[chatbot_role])
    ).strip()
    if not full_story_text:
        return []

    # Split into shorter sentences for TTS.
    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
    lang = langid.classify(sentences[0])[0] if sentences else "en"

    results: List[Dict[str, str]] = []
    for sentence in sentences:
        if not any(c.isalnum() for c in sentence):  # Skip punctuation-only chunks
            continue
        audio_chunks = generate_audio_stream(tts, sentence, lang, latent_map[chatbot_role])
        pcm_data = b"".join(chunk for chunk in audio_chunks if chunk)

        # Optional noise reduction (best-effort).
        try:
            data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
            if data_s16.size > 0:
                float_data = data_s16.astype(np.float32) / 32767.0
                reduced = nr.reduce_noise(y=float_data, sr=24000)
                final_pcm = (reduced * 32767).astype(np.int16).tobytes()
            else:
                final_pcm = pcm_data
        except Exception:
            final_pcm = pcm_data

        b64_wav = base64.b64encode(pcm_to_wav(final_pcm)).decode("utf-8")
        results.append({"text": sentence, "audio": b64_wav})

    return results
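
# Editor's sketch (an addition, not in the original app): generate_story_and_speech
# recomputes conditioning latents for all four roles on every request. A module-level
# memo keyed by role would compute them once per process; to use it, replace the
# inline latent loop above with get_latents(tts, role, filename).
_latent_cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}


def get_latents(tts: Xtts, role: str, filename: str) -> Tuple[np.ndarray, np.ndarray]:
    """Compute XTTS conditioning latents once per role, then serve from the cache."""
    if role not in _latent_cache:
        _latent_cache[role] = tts.get_conditioning_latents(
            audio_path=os.path.join("voices", filename),
            gpt_cond_len=30,
            max_ref_length=60,
        )
    return _latent_cache[role]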

# ===================================================================================
# 5. GRADIO INTERFACE LAUNCH
# ===================================================================================

# Download the reference voice files on startup.
print("Downloading voice files...")
file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
os.makedirs("voices", exist_ok=True)
for name in file_names:
    dst = os.path.join("voices", name)
    if not os.path.exists(dst):
        try:
            resp = requests.get(base_url + name, timeout=30)
            resp.raise_for_status()
            with open(dst, "wb") as f:
                f.write(resp.content)
        except Exception as e:
            print(f"Failed to download {name}: {e}")

# Define the Gradio interface.
demo = gr.Interface(
    fn=generate_story_and_speech,
    inputs=[
        gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
        gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
        gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
    ],
    outputs=gr.JSON(label="Story and Audio Output"),
    title="AI Storyteller with ZeroGPU",
    description="Enter a prompt to generate a short story with voice narration using on-demand GPU.",
    allow_flagging="never",
    analytics_enabled=False,  # <- keep analytics off to avoid the pandas 2.x requirement
)

if __name__ == "__main__":
    # Sensible defaults for Spaces or Docker; adjust as needed.
    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
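
# Example client call (editor's illustration; the prompt text is hypothetical, and
# /predict is Gradio's default endpoint name for a single-function gr.Interface):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   story = client.predict(
#       "secret",                       # Secret Token
#       "A dragon who learns to bake",  # Story Prompt
#       "Julian",                       # Storyteller
#       api_name="/predict",
#   )
#   # -> list of {"text": ..., "audio": <base64 WAV>} items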