|
|
|
|
|
|
|
|
|
|
|
# --- Force CPU-only execution --------------------------------------------
# These environment variables must be set BEFORE torch is imported:
# hiding every CUDA device guarantees torch initializes without GPU support.
import os


os.environ["CUDA_VISIBLE_DEVICES"] = ""


# NOTE(review): TORCH_FORCE_CPU is not a standard torch env var — presumably
# read by a wrapper/launcher; confirm it is actually consumed somewhere.
os.environ["TORCH_FORCE_CPU"] = "1"


import torch


# Keep a reference to the real torch.load before monkey-patching it below.
_original_torch_load = torch.load


def cpu_only_torch_load(*args, **kwargs):
    # Force every checkpoint load onto the CPU, deliberately overriding any
    # map_location the caller supplied: this process must never touch CUDA.
    kwargs["map_location"] = torch.device("cpu")
    return _original_torch_load(*args, **kwargs)


# Monkey-patch torch so checkpoint loading and device probes are CPU-only.
torch.load = cpu_only_torch_load
torch.cuda.is_available = lambda: False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
import base64 |
|
|
import numpy as np |
|
|
import io |
|
|
from scipy.io.wavfile import write as write_wav |
|
|
|
|
|
from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily-initialized singleton TTS model; populated by get_or_load_model().
MODEL = None


# Maximum number of TTS requests served per process; configurable via env.
TTS_MAX_QUOTA = int(os.getenv("TTS_MAX_QUOTA", 10))


# Requests served so far (in-memory counter; resets on process restart).
# NOTE(review): mutated without a lock by the /tts handler — fine only for a
# single-worker deployment; confirm before scaling out.
tts_usage = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_or_load_model():
    """Return the ChatterboxMultilingualTTS singleton, loading it on first use.

    The model is cached in the module-level ``MODEL`` global so the expensive
    ``from_pretrained`` call happens at most once per process. Always loads on
    CPU — this service is CPU-only by construction (see the import-time
    monkey-patches at the top of the file).
    """
    global MODEL
    if MODEL is None:
        print("π Loading ChatterboxMultilingualTTS (CPU ONLY)")
        MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
        # Fix: the original success message was a string literal broken across
        # two lines by character mojibake, which is a syntax error.
        print("Model loaded on CPU")
    return MODEL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_for_singing(lyrics: str) -> str:
    """Rewrite *lyrics* so the TTS model renders them in a more sung style.

    Every non-blank line gets its lowercase vowels doubled (stretching the
    sounds) and a trailing musical marker; blank lines are dropped entirely.
    """
    # Single-pass vowel doubling — equivalent to chaining str.replace calls,
    # since a doubled vowel never contains a different vowel to re-replace.
    stretch = str.maketrans({vowel: vowel * 2 for vowel in "aeiou"})
    stripped = (raw.strip() for raw in lyrics.splitlines())
    return "\n".join(
        f"{text.translate(stretch)} βͺ ..." for text in stripped if text
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the TTS model eagerly at startup.

    Loading here (rather than lazily on first request) means the first
    client does not pay the model initialization cost.
    """
    get_or_load_model()
    yield
    # No shutdown cleanup needed: the model lives for the process lifetime.


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: report service status and confirm CPU-only mode."""
    # Should always report False — torch.cuda.is_available is monkey-patched
    # at import time to force CPU execution.
    cuda_flag = torch.cuda.is_available()
    return dict(status="ok", device="cpu", cuda_available=cuda_flag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/quota")
def get_quota():
    """Report how much of the per-process TTS quota has been consumed."""
    remaining = TTS_MAX_QUOTA - tts_usage
    return {
        "used": tts_usage,
        "limit": TTS_MAX_QUOTA,
        # Clamp at zero so the remaining count never goes negative.
        "remaining": remaining if remaining > 0 else 0,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TTSPayload(BaseModel):
    """Request body for POST /tts."""

    # Text to speak, or lyrics when mode is Sing; truncated downstream.
    text: str
    # Language code passed straight through to the model (e.g. "en").
    language_id: str = "en"
    # Mode label, compared verbatim in the /tts handler.
    # NOTE(review): the trailing characters look like mojibake of an emoji —
    # kept byte-identical because the handler matches on the exact string.
    mode: str = "Speak π£οΈ"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hard cap on synthesized characters per request (keeps CPU latency bounded).
_MAX_TTS_CHARS = 300


def _resolve_text(payload: TTSPayload):
    """Validate the payload and return ``(final_text, error_dict)``.

    Exactly one of the two is non-None. Sing mode reformats the lyrics via
    format_for_singing; Speak mode passes the text through unchanged.
    """
    # NOTE(review): the mode labels contain mojibake of emoji characters;
    # they are matched byte-for-byte against TTSPayload.mode defaults.
    if payload.mode == "Sing π΅":
        if not payload.text.strip():
            return None, {"error": "Lyrics required for Sing mode."}
        return format_for_singing(payload.text), None
    if not payload.text.strip():
        return None, {"error": "Text required for Speak mode."}
    return payload.text, None


def _wav_to_bytes(wav, sr) -> bytes:
    """Encode a 1-D float waveform as an in-memory WAV file and return it."""
    buf = io.BytesIO()
    write_wav(buf, sr, wav.astype(np.float32))
    return buf.getvalue()


@app.post("/tts")
def generate_tts(payload: TTSPayload):
    """Synthesize speech (or "singing") for the given text.

    Returns a dict with the sample rate, a base64-encoded WAV, and quota
    bookkeeping — or an error dict when the quota is exhausted or the text
    is empty. Errors are returned as 200 bodies (existing API contract).
    """
    global tts_usage
    # NOTE(review): quota check + increment is not thread-safe; acceptable
    # for a single-worker deployment — confirm before scaling out.
    if tts_usage >= TTS_MAX_QUOTA:
        return {
            "error": "Quota exceeded",
            "message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow."
        }

    model = get_or_load_model()

    final_text, error = _resolve_text(payload)
    if error is not None:
        return error

    # Inference only — no gradients needed.
    with torch.no_grad():
        wav = model.generate(
            final_text[:_MAX_TTS_CHARS],
            language_id=payload.language_id,
        )

    # (1, samples) tensor -> 1-D numpy array on the CPU.
    wav = wav.squeeze(0).detach().cpu().numpy()
    sr = model.sr

    audio_bytes = _wav_to_bytes(wav, sr)

    # Count the request only after successful synthesis.
    tts_usage += 1

    return {
        "sr": sr,
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "quota_used": tts_usage,
        "quota_limit": TTS_MAX_QUOTA,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|