import asyncio
import functools
import os
import re

import torch
import torchaudio as ta
import uvicorn
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel

# Patch torch.load for CPU if necessary (as in app.py)
# torch.load = functools.partial(torch.load, map_location='cpu')

app = FastAPI()

# 1. Determine device dynamically
device_map = "cuda" if torch.cuda.is_available() else "cpu"

# Create a lock to ensure only one generation happens at a time (important for GPU)
model_lock = asyncio.Lock()

# Directory (relative to the working directory) where generated .wav files are written.
OUTPUT_DIR = "outputs"

# Any character outside word characters, '.' and '-' is replaced in file-name
# components taken from the request, so path separators can never get through.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^\w.-]")

print(f"CUDA Available: {torch.cuda.is_available()}")
# Only query the CUDA device name when CUDA is actually available:
# torch.cuda.current_device() raises on CPU-only hosts, which would crash the
# server at startup even though the CPU code path is otherwise supported.
if device_map == "cuda":
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    device_name = "CPU"
print(f"Using device: {device_map} with name: {device_name}")

print("Loading TTS model...")
# Using Multilingual model as requested
tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)

# Optimize for T4 GPU using half-precision (FP16)
# We use autocast during inference for the best balance of speed and stability
if device_map == "cuda":
    print("GPU optimization: FP16 Autocast enabled.")

print("Model loaded.")


class TTSRequest(BaseModel):
    """Request body shared by the /tts, /stream and /test endpoints."""

    message: str  # text to synthesize
    language: str  # forwarded to tts_model.generate as language_id
    # The three fields below only name the output file.
    channelID: str
    username: str
    messageid: str


def cleanup_file(filepath: str):
    """Deletes the file after it has been sent.

    Best-effort: errors are printed rather than raised, since this runs as a
    background task after the response.
    """
    try:
        if os.path.exists(filepath):
            os.remove(filepath)
            print(f"Deleted temporary file: {filepath}")
    except Exception as e:
        print(f"Error deleting file {filepath}: {e}")


def _safe_filename_component(value: str) -> str:
    """Return *value* with every character other than word chars, '.' and '-' replaced by '_'.

    Prevents request fields from injecting directory separators or absolute
    paths (os.path.join discards earlier parts when a later part is absolute).
    """
    return _UNSAFE_FILENAME_CHARS.sub("_", value)


def generate_audio(req: TTSRequest) -> str:
    """Generates audio for req.message and returns the written .wav filename.

    The file is written to OUTPUT_DIR as "<channelID>-<username>-<messageid>.wav",
    with each component sanitized by _safe_filename_component because they
    come from the request body.

    Raises:
        HTTPException: status 500 if generation or saving fails.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    stem = "-".join(
        _safe_filename_component(part)
        for part in (req.channelID, req.username, req.messageid)
    )
    filename = os.path.join(OUTPUT_DIR, f"{stem}.wav")

    try:
        # Use autocast to automatically handle float16/float32 mixing
        # This prevents the "mat1 and mat2 must have the same dtype" error
        if device_map == "cuda":
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                audio_tensor = tts_model.generate(req.message, language_id=req.language)
        else:
            audio_tensor = tts_model.generate(req.message, language_id=req.language)
        ta.save(filename, audio_tensor, tts_model.sr)
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"TTS Generation failed: {str(e)}") from e
    return filename


async def _generate_serialized(req: TTSRequest) -> str:
    """Run generate_audio in a worker thread while holding model_lock.

    The lock keeps generations one-at-a-time; the thread keeps the blocking
    model call off the event loop.
    """
    async with model_lock:
        return await asyncio.to_thread(generate_audio, req)


@app.post("/tts")
async def tts_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
    """Generate audio and return it as a named .wav file; the file is deleted afterwards."""
    filename = await _generate_serialized(req)
    background_tasks.add_task(cleanup_file, filename)
    # Advertise only the base name: the on-disk path includes OUTPUT_DIR.
    return FileResponse(
        path=filename,
        filename=os.path.basename(filename),
        media_type="audio/wav",
    )


@app.post("/stream")
async def stream_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
    """Generate audio and return the .wav body (no download name); the file is deleted afterwards."""
    filename = await _generate_serialized(req)
    background_tasks.add_task(cleanup_file, filename)
    # FileResponse handles streaming efficiently for large files
    return FileResponse(path=filename, media_type="audio/wav")


@app.post("/test")
async def test_endpoint(req: TTSRequest):
    """Generate audio, keep the file on disk, and report its path."""
    filename = await _generate_serialized(req)
    # For /test, we don't delete the file and just return "ok"
    return {"status": "ok", "filename": filename}


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)