#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "fastapi>=0.100.0",
# "uvicorn[standard]>=0.20.0",
# "pydantic>=2.0.0",
# "httpx>=0.25.0",
# "typer>=0.9.0",
# "numpy>=1.24,<1.26",
# "torch",
# "torchaudio",
# "peft",
# "chatterbox-tts @ git+https://github.com/abidlabs/chatterbox",
# ]
# ///
"""
Chatterbox TTS Model Server
Compatible with HuggingFace InferenceClient text_to_speech API
"""
import argparse
import os
import tempfile
from typing import Optional, Dict, Any
import uvicorn
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from chatterbox.tts import ChatterboxTTS
class TTSRequest(BaseModel):
inputs: str # text to synthesize
parameters: Optional[Dict[str, Any]] = None
extra_body: Optional[Dict[str, Any]] = None
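# Example request body matching TTSRequest (values are illustrative; the
# "parameters" field mirrors the HuggingFace payload shape but is not read by
# this server's handler):
#
#   {
#       "inputs": "Text to speak out loud.",
#       "parameters": {},
#       "extra_body": {"audio_url": "/path/to/voice_sample.wav"}
#   }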
app = FastAPI(title="Chatterbox TTS Server", version="1.0.0")
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global model instance
model = None
@app.get("/")
async def health_check():
return {"status": "ok", "model": "chatterbox"}
@app.post("/api/v1/")
async def text_to_speech(request: TTSRequest):
"""
Text-to-speech endpoint compatible with HuggingFace InferenceClient
Generates speech using Chatterbox TTS model
"""
global model
if model is None:
raise HTTPException(
status_code=503,
detail="Model not loaded. Please wait for server to initialize.",
)
text = request.inputs
extra_body = request.extra_body or {}
print(f"TTS Request - Text: '{text[:50]}...' Extra body: {extra_body}")
try:
# Get audio prompt from extra_body if provided
audio_prompt_path = extra_body.get("audio_url")
# Generate speech
if audio_prompt_path:
# Use voice cloning with audio prompt
wav = model.generate(text, audio_prompt_path=audio_prompt_path)
else:
# Use default voice
wav = model.generate(text)
# Convert tensor to bytes
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
ta.save(tmp_file.name, wav, model.sr)
tmp_file.flush()
# Read the saved audio file as bytes
with open(tmp_file.name, "rb") as f:
audio_data = f.read()
# Clean up temp file
os.unlink(tmp_file.name)
return Response(content=audio_data, media_type="audio/wav")
except Exception as e:
print(f"Error generating TTS: {str(e)}")
raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
def auto_detect_device():
"""Auto-detect the best available device"""
if torch.cuda.is_available():
print("πŸš€ CUDA detected, using GPU acceleration")
return "cuda"
else:
print("πŸ’» CUDA not available, using CPU")
return "cpu"
def load_model(device=None):
"""Load the Chatterbox TTS model with automatic device detection"""
global model
if device is None:
device = auto_detect_device()
model = ChatterboxTTS.from_pretrained(device=device)
def main():
global model
parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
parser.add_argument(
"--port", "-p", type=int, default=7861, help="Port to run server on"
)
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
args = parser.parse_args()
print(f"πŸŽ™οΈ Starting Chatterbox TTS Server on {args.host}:{args.port}")
print(f"🌐 Health check: http://localhost:{args.port}/")
print(f"πŸ”Š TTS endpoint: http://localhost:{args.port}/api/v1/")
load_model()
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
if __name__ == "__main__":
main()