# ===============================
# FORCE CPU ONLY (VERY TOP)
# ===============================
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_FORCE_CPU"] = "1"
import torch
# ---- HARD FORCE torch.load → CPU ----
_original_torch_load = torch.load
def cpu_only_torch_load(*args, **kwargs):
    # Override any saved device mapping so checkpoints always load on CPU
    kwargs["map_location"] = torch.device("cpu")
    return _original_torch_load(*args, **kwargs)
torch.load = cpu_only_torch_load
torch.cuda.is_available = lambda: False
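# Illustrative sanity check (not part of the app): with the patches above,
# any checkpoint load should land on CPU regardless of where it was saved.
#   torch.save(torch.ones(2), "/tmp/t.pt")
#   t = torch.load("/tmp/t.pt")  # mapped to CPU by cpu_only_torch_load
#   assert t.device.type == "cpu" and not torch.cuda.is_available()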
# ===============================
# STANDARD IMPORTS
# ===============================
from fastapi import FastAPI
from pydantic import BaseModel
import base64
import numpy as np
import io
from scipy.io.wavfile import write as write_wav
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
# ===============================
# GLOBAL MODEL CACHE
# ===============================
MODEL = None
# ===============================
# MAX QUOTA (from ENV)
# ===============================
TTS_MAX_QUOTA = int(os.getenv("TTS_MAX_QUOTA", "10"))  # default: 10 requests/day
tts_usage = 0  # simple in-memory counter for demo (resets on restart, not thread-safe)
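# Illustrative: override the quota at launch via the environment, e.g.
#   TTS_MAX_QUOTA=25 uvicorn app:app --host 0.0.0.0 --port 7860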
# ===============================
# MODEL LOADER
# ===============================
def get_or_load_model():
    global MODEL
    if MODEL is None:
        print("🔄 Loading ChatterboxMultilingualTTS (CPU ONLY)")
        MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
        print("✅ Model loaded on CPU")
    return MODEL
# ===============================
# SINGING FORMATTER
# ===============================
def format_for_singing(lyrics: str) -> str:
    lines = []
    for line in lyrics.splitlines():
        line = line.strip()
        if not line:
            continue
        # Stretch vowels lightly
        line = (
            line.replace("a", "aa")
            .replace("e", "ee")
            .replace("i", "ii")
            .replace("o", "oo")
            .replace("u", "uu")
        )
        lines.append(f"{line} ♪ ...")
    return "\n".join(lines)
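# Illustrative example of the transform above:
#   format_for_singing("hello world")  ->  "heelloo woorld ♪ ..."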
# ===============================
# FASTAPI APP + LIFESPAN
# ===============================
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Warmup on startup
    get_or_load_model()
    yield
    # No shutdown logic needed

app = FastAPI(lifespan=lifespan)
# ===============================
# HEALTH CHECK
# ===============================
@app.get("/health")
def health():
    return {
        "status": "ok",
        "device": "cpu",
        "cuda_available": torch.cuda.is_available(),
    }
# ===============================
# QUOTA INFO
# ===============================
@app.get("/quota")
def get_quota():
    return {
        "used": tts_usage,
        "limit": TTS_MAX_QUOTA,
        "remaining": max(0, TTS_MAX_QUOTA - tts_usage),
    }
# ===============================
# TTS INPUT SCHEMA
# ===============================
class TTSPayload(BaseModel):
    text: str
    language_id: str = "en"
    mode: str = "Speak 🗣️"  # or "Sing 🎵"
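# Illustrative request body matching this schema:
#   {"text": "Twinkle twinkle little star", "language_id": "en", "mode": "Sing 🎵"}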
# ===============================
# TTS ENDPOINT
# ===============================
@app.post("/tts")
def generate_tts(payload: TTSPayload):
    global tts_usage
    if tts_usage >= TTS_MAX_QUOTA:
        return {
            "error": "Quota exceeded",
            "message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow.",
        }
    model = get_or_load_model()
    # Determine final text
    if payload.mode == "Sing 🎵":
        if not payload.text.strip():
            return {"error": "Lyrics required for Sing mode."}
        final_text = format_for_singing(payload.text)
    else:
        if not payload.text.strip():
            return {"error": "Text required for Speak mode."}
        final_text = payload.text
    # CPU-safe inference (cap input length to keep CPU latency reasonable)
    with torch.no_grad():
        wav = model.generate(
            final_text[:300],
            language_id=payload.language_id,
        )
    # Convert tensor -> numpy
    wav = wav.squeeze(0).detach().cpu().numpy()
    sr = model.sr
    # Convert numpy -> WAV bytes
    buf = io.BytesIO()
    write_wav(buf, sr, wav.astype(np.float32))
    buf.seek(0)
    audio_bytes = buf.read()
    # Increment quota usage
    tts_usage += 1
    # Return audio as base64 alongside quota state
    return {
        "sr": sr,
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "quota_used": tts_usage,
        "quota_limit": TTS_MAX_QUOTA,
    }
# ===============================
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
# ===============================
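# Illustrative client call (assumes the server is running locally on port 7860
# and the `requests` package is installed):
#   import base64, requests
#   resp = requests.post(
#       "http://localhost:7860/tts",
#       json={"text": "Hello there", "language_id": "en", "mode": "Speak 🗣️"},
#   ).json()
#   with open("out.wav", "wb") as f:
#       f.write(base64.b64decode(resp["audio_base64"]))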