#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "fastapi>=0.100.0",
#     "uvicorn[standard]>=0.20.0",
#     "pydantic>=2.0.0",
#     "httpx>=0.25.0",
#     "typer>=0.9.0",
#     "numpy>=1.24,<1.26",
#     "torch",
#     "torchaudio",
#     "peft",
#     "chatterbox-tts @ git+https://github.com/abidlabs/chatterbox",
# ]
# ///
"""
Chatterbox TTS Model Server
Compatible with the HuggingFace InferenceClient text_to_speech API
"""

import argparse
import os
import tempfile
from typing import Optional, Dict, Any

import uvicorn
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from chatterbox.tts import ChatterboxTTS


class TTSRequest(BaseModel):
    inputs: str  # text to synthesize
    parameters: Optional[Dict[str, Any]] = None
    extra_body: Optional[Dict[str, Any]] = None


app = FastAPI(title="Chatterbox TTS Server", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance
model = None


@app.get("/")
async def health_check():
    return {"status": "ok", "model": "chatterbox"}


@app.post("/api/v1/")
async def text_to_speech(request: TTSRequest):
    """
    Text-to-speech endpoint compatible with the HuggingFace InferenceClient.
    Generates speech using the Chatterbox TTS model.
    """
    global model

    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for server to initialize.",
        )

    text = request.inputs
    extra_body = request.extra_body or {}

    print(f"TTS Request - Text: '{text[:50]}...' Extra body: {extra_body}")

    try:
        # Get audio prompt from extra_body if provided
        audio_prompt_path = extra_body.get("audio_url")

        # Generate speech
        if audio_prompt_path:
            # Use voice cloning with audio prompt
            wav = model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            # Use default voice
            wav = model.generate(text)

        # Convert tensor to bytes
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            ta.save(tmp_file.name, wav, model.sr)
            tmp_file.flush()

            # Read the saved audio file as bytes
            with open(tmp_file.name, "rb") as f:
                audio_data = f.read()

        # Clean up temp file
        os.unlink(tmp_file.name)

        return Response(content=audio_data, media_type="audio/wav")

    except Exception as e:
        print(f"Error generating TTS: {str(e)}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")


def auto_detect_device():
    """Auto-detect the best available device."""
    if torch.cuda.is_available():
        print("🚀 CUDA detected, using GPU acceleration")
        return "cuda"
    else:
        print("💻 CUDA not available, using CPU")
        return "cpu"


def load_model(device=None):
    """Load the Chatterbox TTS model with automatic device detection."""
    global model
    if device is None:
        device = auto_detect_device()
    model = ChatterboxTTS.from_pretrained(device=device)


def main():
    parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
    parser.add_argument(
        "--port", "-p", type=int, default=7861, help="Port to run server on"
    )
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()

    print(f"🎙️ Starting Chatterbox TTS Server on {args.host}:{args.port}")
    print(f"🌐 Health check: http://localhost:{args.port}/")
    print(f"🔊 TTS endpoint: http://localhost:{args.port}/api/v1/")

    load_model()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()
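
# Example client call (a minimal sketch, not part of the server): it assumes the
# server is already running on the default port 7861, and "reference_voice.wav"
# is a hypothetical local audio prompt used for voice cloning. The /api/v1/
# endpoint returns raw WAV bytes, so the response body can be written straight
# to disk. httpx is already listed in the script dependencies above.
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:7861/api/v1/",
#       json={
#           "inputs": "Hello from Chatterbox!",
#           # optional; omit "extra_body" entirely to use the default voice
#           "extra_body": {"audio_url": "reference_voice.wav"},
#       },
#       timeout=120.0,  # generation on CPU can take a while
#   )
#   resp.raise_for_status()
#   with open("output.wav", "wb") as f:
#       f.write(resp.content)  # WAV audio produced by the server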