# CDOM201's picture
# Mat1 and Mat2 fixes
# e00ded0 verified
import os
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from pydantic import BaseModel
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import functools
import uvicorn
import asyncio
# Patch torch.load for CPU if necessary (as in app.py). Left disabled; enable
# it only if checkpoint loading fails on a machine without a GPU.
# torch.load = functools.partial(torch.load, map_location='cpu')
app = FastAPI()

# 1. Determine device dynamically
device_map = "cuda" if torch.cuda.is_available() else "cpu"

# Create a lock to ensure only one generation happens at a time (important for GPU)
model_lock = asyncio.Lock()

print(f"CUDA Available: {torch.cuda.is_available()}")
# Query the CUDA device name only when CUDA is present: calling
# torch.cuda.current_device() on a CPU-only host raises during CUDA
# initialisation, which would abort startup before the CPU fallback is used.
if device_map == "cuda":
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    device_name = "CPU"
print(f"Using device: {device_map} with name: {device_name}")

print("Loading TTS model...")
# Using Multilingual model as requested
tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)

# Optimize for T4 GPU using half-precision (FP16).
# The model weights are not converted here; FP16 is applied through autocast
# at inference time (see generate_audio) for the best balance of speed and
# stability.
if device_map == "cuda":
    print("GPU optimization: FP16 Autocast enabled.")
print("Model loaded.")
class TTSRequest(BaseModel):
    # Request body accepted by the /tts, /stream and /test endpoints.
    # All five fields are required strings.
    message: str  # text handed to tts_model.generate
    language: str  # forwarded to tts_model.generate as language_id
    channelID: str  # first part of the output name: {channelID}-{username}-{messageid}.wav
    username: str  # second part of the output file name
    messageid: str  # third part of the output file name
def cleanup_file(filepath: str):
    """Delete a temporary file after it has been sent.

    Best-effort: this runs as a background task, so it never raises. A file
    that is already gone is ignored silently; any other failure is printed.

    Args:
        filepath: Path of the file to remove.
    """
    try:
        # Attempt the removal directly instead of checking existence first:
        # the file may disappear between the check and the removal.
        os.remove(filepath)
    except FileNotFoundError:
        # Nothing to delete, nothing to report.
        return
    except Exception as e:
        # Deliberately broad: a cleanup failure must not surface to the caller.
        print(f"Error deleting file {filepath}: {e}")
    else:
        print(f"Deleted temporary file: {filepath}")
# Characters that could move a file name out of the output directory: both
# path separators, the Windows drive colon, and NUL.
_UNSAFE_FILENAME_CHARS = str.maketrans({"/": "_", "\\": "_", ":": "_", "\x00": "_"})


def _safe_filename_part(value: str) -> str:
    """Return *value* with path-significant characters replaced by ``_``."""
    return value.translate(_UNSAFE_FILENAME_CHARS)


def generate_audio(req: TTSRequest) -> str:
    """Generate speech for *req*, save it as a WAV file and return its path.

    The file is written to ``outputs/{channelID}-{username}-{messageid}.wav``.
    The three identifiers come straight from the request body, so path
    separators and similar characters in them are replaced with ``_`` to keep
    the file inside ``outputs/``. Identifiers without such characters produce
    the same name as before.

    Args:
        req: The request holding the text, language and naming identifiers.

    Returns:
        The path of the WAV file that was written.

    Raises:
        HTTPException: With status 500 when generation or saving fails.
    """
    from contextlib import nullcontext

    os.makedirs("outputs", exist_ok=True)
    stem = "-".join(
        _safe_filename_part(part)
        for part in (req.channelID, req.username, req.messageid)
    )
    filename = os.path.join("outputs", f"{stem}.wav")
    try:
        # On CUDA, autocast handles float16/float32 mixing automatically,
        # which avoids the "mat1 and mat2 must have the same dtype" error.
        # On CPU the model runs without any autocast context.
        if device_map == "cuda":
            inference_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)
        else:
            inference_ctx = nullcontext()
        with inference_ctx:
            audio_tensor = tts_model.generate(req.message, language_id=req.language)
        ta.save(filename, audio_tensor, tts_model.sr)
        return filename
    except Exception as e:
        # Report any failure as a 500 while keeping the original error chained
        # for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"TTS Generation failed: {str(e)}") from e
@app.post("/tts")
async def tts_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
    """Generate speech for *req* and return it as a WAV download.

    Generation runs in a worker thread so the event loop stays responsive,
    and model_lock ensures only one generation runs at a time. The generated
    file is scheduled for deletion through a background task.
    """
    async with model_lock:
        filename = await asyncio.to_thread(generate_audio, req)
    background_tasks.add_task(cleanup_file, filename)
    # Advertise only the base name to the client; the server-side directory
    # does not belong in the Content-Disposition header.
    return FileResponse(
        path=filename,
        filename=os.path.basename(filename),
        media_type='audio/wav',
    )
@app.post("/stream")
async def stream_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
    """Generate speech for *req* and return the WAV without a download name.

    Same flow as /tts, except that no filename is passed to the response.
    """
    # Serialize model access; the blocking generation runs in a worker thread.
    async with model_lock:
        wav_path = await asyncio.to_thread(generate_audio, req)

    # FileResponse handles streaming efficiently for large files
    response = FileResponse(path=wav_path, media_type='audio/wav')
    # Schedule removal of the temporary file via the background task queue.
    background_tasks.add_task(cleanup_file, wav_path)
    return response
@app.post("/test")
async def test_endpoint(req: TTSRequest):
    """Generate speech for *req* and report where the file was written.

    Unlike /tts and /stream, the generated file is not deleted and no audio
    is returned; the response is a small JSON status object.
    """
    async with model_lock:
        saved_path = await asyncio.to_thread(generate_audio, req)

    # The file is intentionally kept on disk; only its path is reported.
    result = {"status": "ok"}
    result["filename"] = saved_path
    return result
if __name__ == "__main__":
    # Listen on the port given by the PORT environment variable, falling back
    # to 7860, and bind to all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")