File size: 4,099 Bytes
5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 5a78f5f 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 5e29ba1 33b2a07 eef3408 33b2a07 eef3408 5e29ba1 eef3408 5e29ba1 eef3408 33b2a07 eef3408 33b2a07 eef3408 5e29ba1 eef3408 33b2a07 5e29ba1 33b2a07 eef3408 33b2a07 5e29ba1 33b2a07 5e29ba1 eef3408 5a78f5f 33b2a07 5e29ba1 33b2a07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "fastapi>=0.100.0",
# "uvicorn[standard]>=0.20.0",
# "pydantic>=2.0.0",
# "httpx>=0.25.0",
# "typer>=0.9.0",
# "numpy>=1.24,<1.26",
# "torch",
# "torchaudio",
# "peft",
# "chatterbox-tts @ git+https://github.com/abidlabs/chatterbox",
# ]
# ///
"""
Chatterbox TTS Model Server
Compatible with HuggingFace InferenceClient text_to_speech API
"""
import argparse
import os
import tempfile
from typing import Optional, Dict, Any
import uvicorn
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from chatterbox.tts import ChatterboxTTS
class TTSRequest(BaseModel):
    """Request body for POST /api/v1/, matching the HuggingFace
    InferenceClient text_to_speech payload shape."""
    # Text to synthesize into speech.
    inputs: str  # text to synthesize
    # Accepted for API compatibility; currently unused by the handler.
    parameters: Optional[Dict[str, Any]] = None
    # Optional extras; the handler reads "audio_url" from here as a
    # voice-cloning audio prompt path/URL.
    extra_body: Optional[Dict[str, Any]] = None
# FastAPI application instance; routes are registered via decorators below.
app = FastAPI(title="Chatterbox TTS Server", version="1.0.0")
# Add CORS middleware so browser-based clients on other origins can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# permissive — confirm this server is only exposed in trusted environments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model instance — None until load_model() populates it at startup;
# the /api/v1/ handler returns 503 while it is still None.
model = None
@app.get("/")
async def health_check():
    """Liveness probe: report server status and the served model name."""
    return dict(status="ok", model="chatterbox")
@app.post("/api/v1/")
async def text_to_speech(request: TTSRequest):
    """
    Text-to-speech endpoint compatible with HuggingFace InferenceClient.

    Generates speech for ``request.inputs`` using the global Chatterbox TTS
    model. If ``request.extra_body`` contains an ``audio_url``, it is passed
    to the model as a voice-cloning audio prompt.

    Returns:
        A ``Response`` with ``audio/wav`` bytes.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 500 on any
            generation/serialization failure.
    """
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for server to initialize.",
        )
    text = request.inputs
    extra_body = request.extra_body or {}
    print(f"TTS Request - Text: '{text[:50]}...' Extra body: {extra_body}")
    try:
        # Optional voice-cloning prompt supplied by the client.
        audio_prompt_path = extra_body.get("audio_url")
        if audio_prompt_path:
            # Use voice cloning with the provided audio prompt
            wav = model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            # Use the model's default voice
            wav = model.generate(text)
        # torchaudio.save writes by filename, so we only need the temp file's
        # name; delete=False lets us reopen it after the context closes
        # (required on Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name
        try:
            ta.save(tmp_path, wav, model.sr)
            with open(tmp_path, "rb") as f:
                audio_data = f.read()
        finally:
            # Always remove the temp file, even if save/read fails —
            # the original code leaked it on any exception after creation.
            os.unlink(tmp_path)
        return Response(content=audio_data, media_type="audio/wav")
    except Exception as e:
        print(f"Error generating TTS: {str(e)}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
def auto_detect_device():
    """Return the best available torch device name: 'cuda' if a GPU is
    visible, otherwise 'cpu'. Logs which one was chosen."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        print("π CUDA detected, using GPU acceleration")
    else:
        print("π» CUDA not available, using CPU")
    return device
def load_model(device=None):
    """Populate the module-level ``model`` with a pretrained ChatterboxTTS
    instance, auto-selecting the device when none is specified."""
    global model
    target_device = auto_detect_device() if device is None else device
    model = ChatterboxTTS.from_pretrained(device=target_device)
def main():
    """Parse CLI options, load the TTS model, and serve the FastAPI app."""
    arg_parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
    arg_parser.add_argument(
        "--port", "-p", type=int, default=7861, help="Port to run server on"
    )
    arg_parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    opts = arg_parser.parse_args()
    # Startup banner with the reachable endpoints.
    for line in (
        f"ποΈ Starting Chatterbox TTS Server on {opts.host}:{opts.port}",
        f"π Health check: http://localhost:{opts.port}/",
        f"π TTS endpoint: http://localhost:{opts.port}/api/v1/",
    ):
        print(line)
    load_model()
    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")
if __name__ == "__main__":
    main()
|