""" |
|
|
Chatterbox TTS Model Server |
|
|
Compatible with HuggingFace InferenceClient text_to_speech API |
|
|
""" |

import argparse
import os
import tempfile
from typing import Any, Dict, Optional

import torch
import torchaudio as ta
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from pydantic import BaseModel

from chatterbox.tts import ChatterboxTTS


class TTSRequest(BaseModel):
    """Request body in the HuggingFace Inference text_to_speech shape."""

    inputs: str
    # `parameters` is accepted for API compatibility but not used by this server.
    parameters: Optional[Dict[str, Any]] = None
    extra_body: Optional[Dict[str, Any]] = None


app = FastAPI(title="Chatterbox TTS Server", version="1.0.0")

# Allow cross-origin requests so browser-based clients can call the server.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set by load_model() at startup; requests arriving before then get a 503.
model = None


@app.get("/")
async def health_check():
    return {"status": "ok", "model": "chatterbox"}
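

# Quick smoke test once the server is running (assumes the default port):
#   curl http://localhost:7861/
#   -> {"status":"ok","model":"chatterbox"}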


@app.post("/api/v1/")
async def text_to_speech(request: TTSRequest):
    """
    Text-to-speech endpoint compatible with the HuggingFace InferenceClient.

    Generates speech with the Chatterbox TTS model and returns WAV bytes.
    """
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for the server to initialize.",
        )

    text = request.inputs
    extra_body = request.extra_body or {}

    print(f"TTS Request - Text: '{text[:50]}...' Extra body: {extra_body}")

    try:
        # Optional reference audio for voice cloning, passed by clients as
        # extra_body={"audio_url": ...}.
        audio_prompt_path = extra_body.get("audio_url")

        if audio_prompt_path:
            wav = model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = model.generate(text)

        # Write the waveform to a temporary WAV file and read it back,
        # deleting the file even if saving or reading fails partway through.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name
        try:
            ta.save(tmp_path, wav, model.sr)
            with open(tmp_path, "rb") as f:
                audio_data = f.read()
        finally:
            os.unlink(tmp_path)

        return Response(content=audio_data, media_type="audio/wav")

    except Exception as e:
        print(f"Error generating TTS: {e}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {e}") from e
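

# Example raw request (a sketch using `requests`; assumes the server runs on
# localhost:7861 and that "ref.wav" names an audio file the *server* can read,
# since extra_body["audio_url"] is passed to the model as a prompt path):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7861/api/v1/",
#         json={"inputs": "Hello world", "extra_body": {"audio_url": "ref.wav"}},
#     )
#     with open("out.wav", "wb") as f:
#         f.write(resp.content)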


def auto_detect_device():
    """Auto-detect the best available device."""
    if torch.cuda.is_available():
        print("🚀 CUDA detected, using GPU acceleration")
        return "cuda"
    print("💻 CUDA not available, using CPU")
    return "cpu"


def load_model(device=None):
    """Load the Chatterbox TTS model with automatic device detection."""
    global model

    if device is None:
        device = auto_detect_device()

    model = ChatterboxTTS.from_pretrained(device=device)


def main():
    parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
    parser.add_argument(
        "--port", "-p", type=int, default=7861, help="Port to run the server on"
    )
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()

    print(f"🎙️ Starting Chatterbox TTS Server on {args.host}:{args.port}")
    print(f"📋 Health check: http://localhost:{args.port}/")
    print(f"🎤 TTS endpoint: http://localhost:{args.port}/api/v1/")

    load_model()

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
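

# Typical invocations (the filename is whatever this script is saved as):
#   python chatterbox_server.py              # defaults: host 0.0.0.0, port 7861
#   python chatterbox_server.py --port 8000  # custom port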


if __name__ == "__main__":
    main()