# ===================================================================================
# 1) SETUP & IMPORTS
# ===================================================================================
from __future__ import annotations

import os
import sys
import re
import base64
import struct
import textwrap
import requests
import atexit
import inspect
from typing import List, Dict, Tuple, Generator, Any

# --- Fast, safe defaults ---
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("COQUI_TOS_AGREED", "1")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "false")  # truly disable analytics
os.environ.setdefault("TORCHAUDIO_USE_FFMPEG", "0")  # avoid torchaudio/ffmpeg linkage quirks
# --- .env early (HF_TOKEN / SECRET_TOKEN) ---
from dotenv import load_dotenv
load_dotenv()

# --- NumPy sanity (Torch 2.2.x prefers NumPy 1.x) ---
import numpy as _np
if int(_np.__version__.split(".", 1)[0]) >= 2:
    raise RuntimeError(
        f"Detected numpy=={_np.__version__}. Please ensure numpy<2 (e.g., 1.26.4)."
    )

# --- Pandas compat shims (Gradio & mixed versions) ---
try:
    import pandas as pd
    from pandas._config.config import register_option

    # Option shim
    try:
        pd.get_option("future.no_silent_downcasting")
    except Exception:
        register_option("future.no_silent_downcasting", False, validator=None, doc="compat shim for Gradio")

    # infer_objects(copy=...) shim (older pandas lacks 'copy' kwarg)
    if hasattr(pd, "DataFrame"):
        try:
            sig = inspect.signature(pd.DataFrame.infer_objects)
            if "copy" not in sig.parameters:
                _orig_infer_objects = pd.DataFrame.infer_objects

                def _infer_objects_compat(self, *args, **kwargs):
                    kwargs.pop("copy", None)
                    return _orig_infer_objects(self, *args, **kwargs)

                pd.DataFrame.infer_objects = _infer_objects_compat
        except Exception:
            pass
except Exception:
    pd = None  # ok if pandas is unavailable

# --- Hugging Face Spaces & ZeroGPU (import BEFORE CUDA libs) ---
try:
    import spaces  # Required for ZeroGPU on HF
except Exception:
    class _SpacesShim:
        """No-op stand-in so @spaces.GPU works off-Spaces (bare or called form)."""
        def GPU(self, *args, **kwargs):
            # @spaces.GPU used without parentheses: the sole arg is the function
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            # @spaces.GPU(...) used with arguments: return a pass-through decorator
            def _wrap(fn):
                return fn
            return _wrap
    spaces = _SpacesShim()

import gradio as gr

# --- Core ML & Data Libraries (after spaces import) ---
import torch
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from llama_cpp import Llama

# --- Audio decoding (pure ffmpeg-python; no torchaudio) ---
import ffmpeg

# --- TTS Libraries ---
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
import TTS.tts.models.xtts as xtts_module  # for monkey-patching load_audio

# --- Text & Audio Processing ---
import nltk
import langid
import emoji
import noisereduce as nr

# ===================================================================================
# 2) GLOBALS & HELPERS
# ===================================================================================
def _ensure_nltk() -> None:
    # Newer NLTK splits data into 'punkt' and 'punkt_tab'
    for pkg in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{pkg}")
        except LookupError:
            nltk.download(pkg, quiet=True)

_ensure_nltk()

# Models & caches
tts_model: Xtts | None = None
llm_model: Llama | None = None
# Store latents as NumPy on CPU for portability; convert to device at inference time
voice_latents: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

# Config
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
repo_id = "ruslanmv/ai-story-server"
SECRET_TOKEN = os.getenv("SECRET_TOKEN", "secret")
SENTENCE_SPLIT_LENGTH = 250
LLM_STOP_WORDS = ["</s>", "<|user|>", "/s>"]
# System prompts and roles
default_system_message = (
    "You're a storyteller crafting a short tale for young listeners. Keep sentences short and simple. "
    "Use narrative style only, without lists or complex words. Type numbers as words (e.g., 'ten')."
)
system_message = os.environ.get("SYSTEM_MESSAGE", default_system_message)

ROLES = ["Cloée", "Julian", "Pirate", "Thera"]
ROLE_PROMPTS = {role: system_message for role in ROLES}
ROLE_PROMPTS["Pirate"] = (
    "You are AI Beard, a pirate. Craft your response from his first-person perspective. "
    "Keep answers short, as if in a real conversation. Only provide the words AI Beard would speak."
)
# ---------- tiny utilities ----------
def _model_device(m: torch.nn.Module) -> torch.device:
    try:
        return next(m.parameters()).device
    except StopIteration:
        return torch.device("cpu")

def _to_device_float_tensor(x: Any, device: torch.device) -> torch.Tensor:
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float().to(device)
    if torch.is_tensor(x):
        return x.to(device, dtype=torch.float32)
    return torch.as_tensor(x, dtype=torch.float32, device=device)

def _latents_for_device(latents: Tuple[Any, Any], device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    gpt_cond, spk = latents
    return _to_device_float_tensor(gpt_cond, device), _to_device_float_tensor(spk, device)
def pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, channels: int = 1, bit_depth: int = 16) -> bytes:
    """Wrap raw PCM bytes in a canonical 44-byte RIFF/WAVE header (no-op if already WAV)."""
    if pcm_data.startswith(b"RIFF"):
        return pcm_data
    byte_rate = sample_rate * channels * bit_depth // 8
    block_align = channels * bit_depth // 8
    chunk_size = 36 + len(pcm_data)  # 36 = header bytes remaining after the RIFF size field
    # RIFF header | "fmt " chunk (format 1 = PCM) | "data" chunk
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", chunk_size, b"WAVE", b"fmt ",
        16, 1, channels, sample_rate, byte_rate, block_align, bit_depth,
        b"data", len(pcm_data)
    )
    return header + pcm_data
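# Sanity check at import time: a canonical PCM WAV header is exactly 44 bytes,
# so wrapping an empty payload must yield a 44-byte buffer.
assert len(pcm_to_wav(b"")) == 44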
def split_sentences(text: str, max_len: int) -> List[str]:
    # Try NLTK; if it fails for any reason, fall back to a simple regex splitter.
    try:
        sentences = nltk.sent_tokenize(text)
    except Exception:
        sentences = re.split(r"(?<=[\.\!\?])\s+", text)
    chunks: List[str] = []
    for sent in sentences:
        if len(sent) > max_len:
            chunks.extend(textwrap.wrap(sent, max_len, break_long_words=True))
        elif sent:
            chunks.append(sent)
    return chunks
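# e.g. split_sentences("One. Two.", 250) -> ["One.", "Two."], while a single
# 300-character sentence would be wrapped into a 250-char chunk plus a 50-char chunk.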
def format_prompt_zephyr(message: str, history: List[Tuple[str, str | None]], system_message: str) -> str:
    prompt = f"<|system|>\n{system_message}</s>"
    for user_prompt, bot_response in history:
        if bot_response:
            prompt += f"<|user|>\n{user_prompt}</s><|assistant|>\n{bot_response}</s>"
    prompt += f"<|user|>\n{message}</s><|assistant|>"
    return prompt
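# With empty history this renders the standard Zephyr chat layout, e.g.:
#   format_prompt_zephyr("Tell a tale", [], "Be brief")
#   -> "<|system|>\nBe brief</s><|user|>\nTell a tale</s><|assistant|>"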
# ---------- robust audio decode (mono via ffmpeg) ----------
def _decode_audio_ffmpeg_to_mono(path: str, target_sr: int) -> np.ndarray:
    """
    Return float32 waveform in [-1, 1], mono, resampled to target_sr.
    Shape: (samples,)
    """
    try:
        out, _ = (
            ffmpeg
            .input(path)
            .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=target_sr)
            .run(capture_stdout=True, capture_stderr=True, cmd="ffmpeg")
        )
        pcm = np.frombuffer(out, dtype=np.int16)
        if pcm.size == 0:
            raise RuntimeError("ffmpeg produced empty audio.")
        wav = pcm.astype(np.float32) / 32767.0
        return wav
    except ffmpeg.Error as e:
        raise RuntimeError(f"ffmpeg decode failed: {e.stderr.decode(errors='ignore') if e.stderr else e}") from e

# ---------- monkey-patch XTTS internal loader to avoid torchaudio/torio ----------
def _patched_load_audio(audiopath: str, load_sr: int):
    """
    Match XTTS' expected return type:
      - returns a torch.FloatTensor shaped [1, samples], normalized to [-1, 1],
        already resampled to `load_sr`.
      - DO NOT return an (audio, sr) tuple.
    """
    wav = _decode_audio_ffmpeg_to_mono(audiopath, target_sr=load_sr)
    import torch as _torch  # local import to avoid circularities
    audio = _torch.from_numpy(wav).float().unsqueeze(0)  # [1, N] on CPU
    return audio

xtts_module.load_audio = _patched_load_audio
try:
    import TTS.utils.audio as _tts_audio_mod
    _tts_audio_mod.load_audio = _patched_load_audio
except Exception:
    pass

def _coqui_cache_dir() -> str:
    # Matches what TTS uses on Linux: ~/.local/share/tts
    return os.path.join(os.path.expanduser("~"), ".local", "share", "tts")
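# e.g. this resolves to "/home/<user>/.local/share/tts"; _load_xtts() below then
# expects the "tts_models--multilingual--multi-dataset--xtts_v2" folder inside it.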
# ===================================================================================
# 3) PRECACHE & MODEL LOADERS (CPU at startup to avoid ZeroGPU issues)
# ===================================================================================
def precache_assets() -> None:
    """Download voice WAVs, XTTS weights, and Zephyr GGUF to the local cache before any inference."""
    print("Pre-caching voice files...")
    file_names = ["cloee-1.wav", "julian-bedtime-style-1.wav", "pirate_by_coqui.wav", "thera-1.wav"]
    base_url = "https://raw.githubusercontent.com/ruslanmv/ai-story-server/main/voices/"
    os.makedirs("voices", exist_ok=True)
    for name in file_names:
        dst = os.path.join("voices", name)
        if not os.path.exists(dst):
            try:
                resp = requests.get(base_url + name, timeout=30)
                resp.raise_for_status()
                with open(dst, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                print(f"Failed to download {name}: {e}")

    print("Pre-caching XTTS v2 model files...")
    ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")

    print("Pre-caching Zephyr GGUF...")
    try:
        hf_hub_download(
            repo_id="TheBloke/zephyr-7B-beta-GGUF",
            filename="zephyr-7b-beta.Q5_K_M.gguf",
            force_download=False,
        )
    except Exception as e:
        print(f"Warning: GGUF pre-cache error: {e}")
def _load_xtts(device: str = "cpu") -> Xtts:
    """Load XTTS from the local cache. Keep CPU at startup to avoid ZeroGPU device mixups."""
    print(f"Loading Coqui XTTS V2 model on {device.upper()}...")
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    ModelManager().download_model(model_name)  # idempotent
    model_dir = os.path.join(_coqui_cache_dir(), model_name.replace("/", "--"))
    cfg = XttsConfig()
    cfg.load_json(os.path.join(model_dir, "config.json"))
    model = Xtts.init_from_config(cfg)
    model.load_checkpoint(
        cfg,
        checkpoint_dir=model_dir,
        eval=True,
        use_deepspeed=False,
    )
    model.to(device)
    print("XTTS model loaded.")
    return model
def _load_llama() -> Llama:
    """
    Load Llama (Zephyr GGUF).
    Keep it simple & robust: default to CPU (works everywhere).
    """
    print("Loading LLM (Zephyr GGUF)...")
    zephyr_model_path = hf_hub_download(
        repo_id="TheBloke/zephyr-7B-beta-GGUF",
        filename="zephyr-7b-beta.Q5_K_M.gguf",
    )
    llm = Llama(
        model_path=zephyr_model_path,
        n_gpu_layers=0,   # CPU-only for reliability across Spaces/ZeroGPU
        n_ctx=4096,
        n_batch=512,
        verbose=False,
    )
    print("LLM loaded (CPU).")
    return llm
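# Note (build assumption): with a CUDA-enabled llama-cpp-python build, setting
# n_gpu_layers=-1 above would offload all layers to the GPU; it is pinned to 0
# here deliberately so the app behaves identically on CPU-only and ZeroGPU hosts.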
def init_models_and_latents() -> None:
    """
    Preload models on CPU and compute voice latents on CPU.
    This avoids ZeroGPU's "mixed device" errors from torchaudio-based resampling.
    """
    global tts_model, llm_model, voice_latents
    if tts_model is None:
        tts_model = _load_xtts(device="cpu")  # always CPU at startup
    if llm_model is None:
        llm_model = _load_llama()
    if not voice_latents:
        print("Computing voice conditioning latents (CPU)...")
        # Ensure the TTS model is on CPU while computing latents
        orig_dev = _model_device(tts_model)
        if orig_dev.type != "cpu":
            tts_model.to("cpu")
        with torch.no_grad():
            for role, filename in [
                ("Cloée", "cloee-1.wav"),
                ("Julian", "julian-bedtime-style-1.wav"),
                ("Pirate", "pirate_by_coqui.wav"),
                ("Thera", "thera-1.wav"),
            ]:
                path = os.path.join("voices", filename)
                gpt_lat, spk_emb = tts_model.get_conditioning_latents(
                    audio_path=path, gpt_cond_len=30, max_ref_length=60
                )
                # Store as NumPy on CPU; convert to device on demand later
                voice_latents[role] = (
                    gpt_lat.detach().cpu().numpy(),
                    spk_emb.detach().cpu().numpy(),
                )
        # Return model to its original device (keep CPU at startup for safety)
        if orig_dev.type != "cpu":
            tts_model.to(orig_dev)
        print("Voice latents ready.")

# Ensure we close Llama cleanly to avoid __del__ issues at interpreter shutdown
def _close_llm():
    global llm_model
    try:
        if llm_model is not None:
            llm_model.close()
    except Exception:
        pass

atexit.register(_close_llm)
# ===================================================================================
# 4) INFERENCE HELPERS
# ===================================================================================
def generate_text_stream(llm_instance: Llama, prompt: str,
                         history: List[Tuple[str, str | None]],
                         system_message_text: str) -> Generator[str, None, None]:
    formatted_prompt = format_prompt_zephyr(prompt, history, system_message_text)
    stream = llm_instance(
        formatted_prompt,
        temperature=0.7,
        max_tokens=512,
        top_p=0.95,
        stop=LLM_STOP_WORDS,
        stream=True,
    )
    for response in stream:
        ch = response["choices"][0]["text"]
        try:
            is_single_emoji = (len(ch) == 1 and emoji.is_emoji(ch))
        except Exception:
            is_single_emoji = False
        if "<|user|>" in ch or is_single_emoji:
            continue
        yield ch
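# Usage sketch (mirrors the call in generate_story_and_speech below):
#   story = "".join(generate_text_stream(llm_model, "a brave cat", [], system_message))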
def _extract_waveform_to_numpy(wav_any: Any) -> np.ndarray:
    """
    Normalize various XTTS inference() return shapes/types to 1-D float32 numpy in [-1, 1].
    """
    if isinstance(wav_any, dict):
        for k in ("wav", "audio", "samples", "waveform"):
            if k in wav_any:
                wav_any = wav_any[k]
                break
    if torch.is_tensor(wav_any):
        arr = wav_any.detach().cpu().numpy()
    else:
        arr = np.asarray(wav_any)
    arr = np.squeeze(arr).astype(np.float32)
    # If not already normalized, attempt to scale if max > 1 (heuristic)
    maxabs = np.max(np.abs(arr)) if arr.size else 1.0
    if maxabs > 1.5:  # likely int16 or higher-amplitude float
        arr = arr / 32767.0
    arr = np.clip(arr, -1.0, 1.0)
    return arr

def synthesize_sentence_pcm16(tts_instance: Xtts, text: str, language: str,
                              latents: Tuple[np.ndarray, np.ndarray]) -> bytes:
    """
    Use non-streaming XTTS inference() to avoid the GPT2InferenceModel streaming bug.
    Returns PCM16 bytes at 24 kHz mono.
    """
    device = _model_device(tts_instance)
    gpt_cond_latent_t, speaker_embedding_t = _latents_for_device(latents, device)
    with torch.no_grad():
        out = tts_instance.inference(
            text=text,
            language=language,
            gpt_cond_latent=gpt_cond_latent_t,
            speaker_embedding=speaker_embedding_t,
            temperature=0.85,
        )
    f32 = _extract_waveform_to_numpy(out)
    s16 = (np.clip(f32, -1.0, 1.0) * 32767.0).astype(np.int16)
    return s16.tobytes()
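# Sketch of the per-sentence pipeline (after init_models_and_latents() has run):
#   pcm = synthesize_sentence_pcm16(tts_model, "Once upon a time.", "en", voice_latents["Cloée"])
#   with open("out.wav", "wb") as f:
#       f.write(pcm_to_wav(pcm))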
# ===================================================================================
# 5) ZERO-GPU ENTRYPOINT (safe on native GPU as well)
# ===================================================================================
@spaces.GPU  # GPU ops happen inside this call when on ZeroGPU
def generate_story_and_speech(secret_token_input: str, input_text: str, chatbot_role: str) -> List[Dict[str, str]]:
    if secret_token_input != SECRET_TOKEN:
        raise gr.Error("Invalid secret token provided.")
    if not input_text:
        return []
    # Ensure models/latents exist (loaded on CPU)
    if tts_model is None or llm_model is None or not voice_latents:
        init_models_and_latents()

    # During the GPU window, move XTTS to CUDA if available; otherwise stay on CPU
    try:
        if torch.cuda.is_available():
            tts_model.to("cuda")
        else:
            tts_model.to("cpu")
    except Exception:
        tts_model.to("cpu")

    # Generate story text (LLM kept on CPU for simplicity & reliability)
    history: List[Tuple[str, str | None]] = [(input_text, None)]
    full_story_text = "".join(
        generate_text_stream(llm_model, history[-1][0], history[:-1], system_message_text=ROLE_PROMPTS[chatbot_role])
    ).strip()
    if not full_story_text:
        return []

    # Split into TTS-friendly sentences
    sentences = split_sentences(full_story_text, SENTENCE_SPLIT_LENGTH)
    lang = langid.classify(sentences[0])[0] if sentences else "en"

    results: List[Dict[str, str]] = []
    for sentence in sentences:
        if not any(c.isalnum() for c in sentence):
            continue
        # Synthesize the whole sentence (non-streaming) to avoid the streaming bug
        pcm_data = synthesize_sentence_pcm16(tts_model, sentence, lang, voice_latents[chatbot_role])
        # Optional noise reduction (best-effort)
        try:
            data_s16 = np.frombuffer(pcm_data, dtype=np.int16)
            if data_s16.size > 0:
                float_data = data_s16.astype(np.float32) / 32767.0
                reduced = nr.reduce_noise(y=float_data, sr=24000)
                final_pcm = np.clip(reduced * 32767.0, -32768, 32767).astype(np.int16).tobytes()
            else:
                final_pcm = pcm_data
        except Exception:
            final_pcm = pcm_data
        b64_wav = base64.b64encode(pcm_to_wav(final_pcm, sample_rate=24000, channels=1, bit_depth=16)).decode("utf-8")
        results.append({"text": sentence, "audio": b64_wav})

    # Leave the model on CPU after the ZeroGPU window
    try:
        tts_model.to("cpu")
    except Exception:
        pass
    return results
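# Client-side sketch: each returned item carries a base64-encoded WAV, so a caller
# can reconstruct playable files from the JSON output, e.g.:
#   for i, item in enumerate(results):
#       with open(f"sentence_{i}.wav", "wb") as f:
#           f.write(base64.b64decode(item["audio"]))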
# ===================================================================================
# 6) STARTUP: PRECACHE & UI
# ===================================================================================
def build_ui() -> gr.Interface:
    return gr.Interface(
        fn=generate_story_and_speech,
        inputs=[
            gr.Textbox(label="Secret Token", type="password", value=SECRET_TOKEN),
            gr.Textbox(placeholder="What should the story be about?", label="Story Prompt"),
            gr.Dropdown(choices=ROLES, label="Select a Storyteller", value="Cloée"),
        ],
        outputs=gr.JSON(label="Story and Audio Output"),
        title="AI Storyteller with ZeroGPU",
        description="Enter a prompt to generate a short story with voice narration. Uses GPU only within the generation call when available.",
        flagging_mode="never",
    )

if __name__ == "__main__":
    print("===== Startup: pre-cache assets and preload models (CPU) =====")
    print(f"Python: {sys.version.split()[0]} | Torch CUDA available: {torch.cuda.is_available()}")
    precache_assets()           # 1) download everything to disk
    init_models_and_latents()   # 2) load models on CPU + compute voice latents on CPU
    print("Models and assets ready. Launching UI...")
    demo = build_ui()
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        ssr_mode=False,  # disable experimental SSR noise
    )