# =============================================================================
# FORCE CPU ONLY (must run before torch initializes CUDA anywhere)
# =============================================================================
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_FORCE_CPU"] = "1"

import torch

# ---- HARD FORCE torch.load -> CPU -------------------------------------------
# Checkpoints saved on GPU would otherwise try to deserialize onto CUDA;
# patching torch.load guarantees CPU placement no matter what callers pass.
_original_torch_load = torch.load


def cpu_only_torch_load(*args, **kwargs):
    """Wrapper around torch.load that always maps tensors to the CPU.

    Deliberately overrides any ``map_location`` the caller supplied.
    """
    kwargs["map_location"] = torch.device("cpu")
    return _original_torch_load(*args, **kwargs)


torch.load = cpu_only_torch_load
torch.cuda.is_available = lambda: False  # downstream device checks see "no CUDA"

# =============================================================================
# STANDARD IMPORTS
# =============================================================================
import base64
import io
import threading
from contextlib import asynccontextmanager
from datetime import date

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from scipy.io.wavfile import write as write_wav

from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS

# =============================================================================
# GLOBAL MODEL CACHE
# =============================================================================
MODEL = None  # lazily populated singleton; see get_or_load_model()

# =============================================================================
# MAX QUOTA (from ENV)
# =============================================================================
TTS_MAX_QUOTA = int(os.getenv("TTS_MAX_QUOTA", 10))  # default 10 requests/day
tts_usage = 0  # simple in-memory counter for demo (lost on restart)
# BUGFIX: the quota message says "Try again tomorrow", but the counter never
# reset. Track which calendar day the counter belongs to and zero it on
# rollover. Sync FastAPI endpoints run in a threadpool, so guard the
# read-modify-write with a lock.
_usage_day = date.today()
_usage_lock = threading.Lock()


def _reset_quota_if_new_day() -> None:
    """Zero the usage counter when the calendar day rolls over.

    Must be called with ``_usage_lock`` held.
    """
    global tts_usage, _usage_day
    today = date.today()
    if today != _usage_day:
        _usage_day = today
        tts_usage = 0


# =============================================================================
# MODEL LOADER
# =============================================================================
def get_or_load_model():
    """Return the cached TTS model, loading it on first use (CPU only)."""
    global MODEL
    if MODEL is None:
        print("🔄 Loading ChatterboxMultilingualTTS (CPU ONLY)")
        MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
        print("✅ Model loaded on CPU")
    return MODEL


# =============================================================================
# SINGING FORMATTER
# =============================================================================
# Single-pass vowel doubling ("a" -> "aa", ...) — equivalent to the former
# chain of five .replace() calls, but done in one C-level pass.
_VOWEL_STRETCH = str.maketrans({v: v * 2 for v in "aeiou"})


def format_for_singing(lyrics: str) -> str:
    """Reformat plain lyrics into a 'singable' prompt.

    Each non-empty line has its lowercase vowels doubled (light stretching)
    and is suffixed with a musical-note marker. Blank lines are dropped.
    """
    lines = []
    for line in lyrics.splitlines():
        line = line.strip()
        if not line:
            continue
        lines.append(f"{line.translate(_VOWEL_STRETCH)} ♪ ...")
    return "\n".join(lines)


# =============================================================================
# FASTAPI APP + LIFESPAN
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Warmup on startup so the first request doesn't pay the load cost.
    get_or_load_model()
    yield
    # No shutdown logic needed


app = FastAPI(lifespan=lifespan)


# =============================================================================
# HEALTH CHECK
# =============================================================================
@app.get("/health")
def health():
    """Liveness probe; also confirms the CPU-only patch is in effect."""
    return {
        "status": "ok",
        "device": "cpu",
        "cuda_available": torch.cuda.is_available(),
    }


# =============================================================================
# QUOTA INFO
# =============================================================================
@app.get("/quota")
def get_quota():
    """Report today's usage against the configured daily limit."""
    with _usage_lock:
        _reset_quota_if_new_day()
        used = tts_usage
    return {
        "used": used,
        "limit": TTS_MAX_QUOTA,
        "remaining": max(0, TTS_MAX_QUOTA - used),
    }


# =============================================================================
# TTS INPUT SCHEMA
# =============================================================================
class TTSPayload(BaseModel):
    # Input text (Speak mode) or lyrics (Sing mode); truncated to 300 chars.
    text: str
    language_id: str = "en"
    mode: str = "Speak 🗣️"  # or "Sing 🎵"


# =============================================================================
# TTS ENDPOINT
# =============================================================================
@app.post("/tts")
def generate_tts(payload: TTSPayload):
    """Synthesize speech (or 'singing') and return base64-encoded WAV audio.

    Enforces a per-day request quota; the counter resets at local midnight.
    Returns an error dict (not an HTTP error) when the quota is exhausted or
    the input text is empty, matching the original API contract.
    """
    global tts_usage

    with _usage_lock:
        _reset_quota_if_new_day()
        if tts_usage >= TTS_MAX_QUOTA:
            return {
                "error": "Quota exceeded",
                "message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow."
            }

    model = get_or_load_model()

    # Determine final text
    if payload.mode == "Sing 🎵":
        if not payload.text.strip():
            return {"error": "Lyrics required for Sing mode."}
        final_text = format_for_singing(payload.text)
    else:
        if not payload.text.strip():
            return {"error": "Text required for Speak mode."}
        final_text = payload.text

    # CPU-safe inference; cap prompt length to keep latency bounded.
    with torch.no_grad():
        wav = model.generate(
            final_text[:300],
            language_id=payload.language_id,
        )

    # Tensor -> numpy. Assumes model.generate returns a (1, T) tensor —
    # TODO(review): confirm against ChatterboxMultilingualTTS.
    wav = wav.squeeze(0).detach().cpu().numpy()
    sr = model.sr

    # numpy -> in-memory WAV (IEEE float32 samples)
    buf = io.BytesIO()
    write_wav(buf, sr, wav.astype(np.float32))
    buf.seek(0)
    audio_bytes = buf.read()

    # Count only successful generations against the quota. (A concurrent
    # request may slip past the check above before this increment — acceptable
    # slack for an in-memory demo counter.)
    with _usage_lock:
        tts_usage += 1
        used = tts_usage

    return {
        "sr": sr,
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "quota_used": used,
        "quota_limit": TTS_MAX_QUOTA,
    }

# =============================================================================
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
# =============================================================================