"""CPU-only FastAPI service exposing Chatterbox multilingual TTS.

Forces torch onto the CPU before it is imported, caches a single model
instance, and serves /tts with a simple in-memory request quota.
"""
# ===============================
# FORCE CPU ONLY (VERY TOP)
# ===============================
# The env vars must be set before `import torch` so CUDA is never initialized.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_FORCE_CPU"] = "1"

import torch

# ---- HARD FORCE torch.load → CPU ----
# Wrap torch.load so every checkpoint deserializes onto the CPU, even when a
# caller supplies its own map_location (the override wins on purpose).
_torch_load_unpatched = torch.load


def cpu_only_torch_load(*args, **kwargs):
    """torch.load replacement that always maps storages to the CPU."""
    merged = {**kwargs, "map_location": torch.device("cpu")}
    return _torch_load_unpatched(*args, **merged)


torch.load = cpu_only_torch_load
torch.cuda.is_available = lambda: False
# ===============================
# STANDARD IMPORTS
# ===============================
from fastapi import FastAPI
from pydantic import BaseModel
import base64
import numpy as np
import io
from scipy.io.wavfile import write as write_wav
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
# ===============================
# GLOBAL MODEL CACHE
# ===============================
# Lazily-populated process-wide singleton; see get_or_load_model().
MODEL = None
# ===============================
# MAX QUOTA (from ENV)
# ===============================
TTS_MAX_QUOTA = int(os.getenv("TTS_MAX_QUOTA", 10))  # default 10 requests/day
# NOTE(review): called "daily" in user-facing messages, but nothing in this
# file ever resets the counter — confirm a restart/scheduler handles that.
tts_usage = 0  # simple in-memory counter for demo
# ===============================
# MODEL LOADER
# ===============================
def get_or_load_model():
    """Return the cached multilingual TTS model, loading it on first call.

    Subsequent calls reuse the module-level MODEL singleton, so the slow
    from_pretrained() load happens at most once per process.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    print("🔄 Loading ChatterboxMultilingualTTS (CPU ONLY)")
    MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
    print("✅ Model loaded on CPU")
    return MODEL
# ===============================
# SINGING FORMATTER
# ===============================
def format_for_singing(lyrics: str) -> str:
    """Rewrite plain lyrics so TTS output sounds closer to singing.

    Each non-empty line has its lowercase vowels doubled (a → aa, e → ee, …)
    to stretch them, and is suffixed with " ♪ ...". Blank lines are dropped;
    surviving lines are rejoined with newlines.
    """
    # One C-level pass over the line instead of five chained .replace() calls.
    stretch = str.maketrans({"a": "aa", "e": "ee", "i": "ii", "o": "oo", "u": "uu"})
    sung = []
    for raw_line in lyrics.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue
        sung.append(f"{stripped.translate(stretch)} ♪ ...")
    return "\n".join(sung)
# ===============================
# FASTAPI APP + LIFESPAN
# ===============================
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: pre-load the TTS model at startup so the first
    /tts request does not pay the slow CPU model-load cost."""
    # Warmup on startup
    get_or_load_model()
    yield
    # No shutdown logic needed
app = FastAPI(lifespan=lifespan)
# ===============================
# HEALTH CHECK
# ===============================
@app.get("/health")
def health():
    """Liveness probe that also confirms the process is CPU-only."""
    # cuda_available is always False here: torch.cuda.is_available is
    # monkeypatched at import time in the CPU-force preamble.
    return {"status": "ok", "device": "cpu", "cuda_available": torch.cuda.is_available()}
# ===============================
# QUOTA INFO
# ===============================
@app.get("/quota")
def get_quota():
    """Report this process's in-memory TTS quota usage."""
    remaining = TTS_MAX_QUOTA - tts_usage
    return {
        "used": tts_usage,
        "limit": TTS_MAX_QUOTA,
        # Clamp at zero so over-quota states never report a negative number.
        "remaining": remaining if remaining > 0 else 0,
    }
# ===============================
# TTS INPUT SCHEMA
# ===============================
class TTSPayload(BaseModel):
    """Request body for POST /tts."""
    # Text to speak, or lyrics when mode is "Sing 🎵"; truncated to 300
    # chars by the endpoint before inference.
    text: str
    language_id: str = "en"
    mode: str = "Speak 🗣️"  # or "Sing 🎵"
# ===============================
# TTS ENDPOINT
# ===============================
@app.post("/tts")
def generate_tts(payload: TTSPayload):
    """Synthesize payload.text to audio and return it as base64-encoded WAV.

    In "Sing 🎵" mode the text is first passed through format_for_singing();
    otherwise it is spoken as-is. Input is truncated to 300 characters
    before inference.

    NOTE(review): all error cases (quota exceeded, empty text) are returned
    as plain dicts with HTTP 200 rather than HTTP error codes — clients must
    check for an "error" key.
    """
    global tts_usage
    # NOTE(review): message says "daily" but this in-memory counter is never
    # reset anywhere in this file — confirm an external restart handles it.
    if tts_usage >= TTS_MAX_QUOTA:
        return {
            "error": "Quota exceeded",
            "message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow."
        }
    model = get_or_load_model()
    # Determine final text
    if payload.mode == "Sing 🎵":
        if not payload.text.strip():
            return {"error": "Lyrics required for Sing mode."}
        final_text = format_for_singing(payload.text)
    else:
        if not payload.text.strip():
            return {"error": "Text required for Speak mode."}
        final_text = payload.text
    # CPU-safe inference (torch.load / cuda were already patched at import)
    with torch.no_grad():
        wav = model.generate(
            final_text[:300],
            language_id=payload.language_id,
        )
    # convert tensor → numpy; assumes generate() returns a batch-first
    # (1, samples) tensor — TODO confirm against ChatterboxMultilingualTTS
    wav = wav.squeeze(0).detach().cpu().numpy()
    sr = model.sr
    # Convert numpy -> WAV bytes (float32 PCM via scipy.io.wavfile.write)
    buf = io.BytesIO()
    write_wav(buf, sr, wav.astype(np.float32))
    buf.seek(0)
    audio_bytes = buf.read()
    # Increment quota usage only after a successful synthesis
    tts_usage += 1
    # Return as base64
    return {
        "sr": sr,
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "quota_used": tts_usage,
        "quota_limit": TTS_MAX_QUOTA
    }
# ===============================
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
# ===============================