File size: 4,099 Bytes
5e29ba1
 
eef3408
5e29ba1
 
 
 
 
 
eef3408
 
 
 
 
5e29ba1
 
 
 
eef3408
5e29ba1
 
 
 
 
eef3408
5e29ba1
 
 
eef3408
 
5e29ba1
eef3408
5e29ba1
 
 
eef3408
 
5e29ba1
 
 
 
eef3408
5e29ba1
 
 
 
 
 
 
 
 
 
 
 
 
eef3408
 
5e29ba1
 
 
 
5a78f5f
5e29ba1
 
eef3408
5e29ba1
 
 
eef3408
5e29ba1
eef3408
 
 
5e29ba1
eef3408
 
5e29ba1
33b2a07
eef3408
 
33b2a07
eef3408
5e29ba1
eef3408
 
 
5e29ba1
eef3408
 
 
 
 
 
 
33b2a07
eef3408
 
 
 
33b2a07
eef3408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e29ba1
 
 
eef3408
33b2a07
5e29ba1
33b2a07
eef3408
33b2a07
5e29ba1
 
33b2a07
5e29ba1
eef3408
 
 
5a78f5f
33b2a07
 
5e29ba1
 
 
33b2a07
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
#     "fastapi>=0.100.0",
#     "uvicorn[standard]>=0.20.0",
#     "pydantic>=2.0.0",
#     "httpx>=0.25.0",
#     "typer>=0.9.0",
#     "numpy>=1.24,<1.26",
#     "torch",
#     "torchaudio",
#     "peft",
#     "chatterbox-tts @ git+https://github.com/abidlabs/chatterbox",
# ]
# ///

"""
Chatterbox TTS Model Server
Compatible with HuggingFace InferenceClient text_to_speech API
"""

import argparse
import os
import tempfile
from typing import Optional, Dict, Any

import uvicorn
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from chatterbox.tts import ChatterboxTTS


class TTSRequest(BaseModel):
    """Request payload for the POST /api/v1/ endpoint.

    Field names mirror the HuggingFace InferenceClient text_to_speech
    payload shape (see module docstring).
    """

    inputs: str  # text to synthesize
    # Generation parameters; accepted for API compatibility but not read
    # by this server's handler.
    parameters: Optional[Dict[str, Any]] = None
    # Extra options; the handler reads extra_body["audio_url"] as an
    # optional audio-prompt path for voice cloning.
    extra_body: Optional[Dict[str, Any]] = None


# FastAPI application instance; routes are registered on it below.
app = FastAPI(title="Chatterbox TTS Server", version="1.0.0")

# Add CORS middleware
# NOTE(review): fully-open CORS (all origins, with credentials) — fine for
# local/dev use; confirm before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance
# Stays None until load_model() populates it; the TTS endpoint returns
# HTTP 503 while it is None.
model = None


@app.get("/")
async def health_check():
    """Liveness probe: report server status and the served model name."""
    status_payload = {"status": "ok", "model": "chatterbox"}
    return status_payload


@app.post("/api/v1/")
async def text_to_speech(request: TTSRequest):
    """
    Text-to-speech endpoint compatible with HuggingFace InferenceClient.

    Synthesizes ``request.inputs`` with the global Chatterbox model and
    returns the audio as an ``audio/wav`` response body. If
    ``request.extra_body["audio_url"]`` is provided, it is forwarded to
    ``model.generate`` as an audio prompt for voice cloning.

    Raises:
        HTTPException: 503 when the model has not finished loading,
            500 on any failure during generation or WAV encoding.
    """
    global model

    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please wait for server to initialize.",
        )

    text = request.inputs
    extra_body = request.extra_body or {}

    print(f"TTS Request - Text: '{text[:50]}...' Extra body: {extra_body}")

    try:
        # Get audio prompt from extra_body if provided
        audio_prompt_path = extra_body.get("audio_url")

        # Generate speech (voice cloning when an audio prompt is given)
        if audio_prompt_path:
            wav = model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = model.generate(text)

        # Serialize the waveform tensor to WAV bytes via a named temp file.
        # delete=False so the file can be reopened by name on all platforms;
        # the finally block guarantees cleanup even when ta.save() or the
        # read raises (the previous code leaked the temp file on that path).
        tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            tmp_file.close()
            ta.save(tmp_file.name, wav, model.sr)
            with open(tmp_file.name, "rb") as f:
                audio_data = f.read()
        finally:
            # Best-effort cleanup; never mask the original exception.
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass

        return Response(content=audio_data, media_type="audio/wav")

    except Exception as e:
        print(f"Error generating TTS: {str(e)}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")


def auto_detect_device():
    """Return the best available torch device string.

    Returns:
        str: ``"cuda"`` when a CUDA device is available, else ``"cpu"``.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print("πŸš€ CUDA detected, using GPU acceleration")
    else:
        print("πŸ’» CUDA not available, using CPU")
    return "cuda" if cuda_available else "cpu"


def load_model(device=None):
    """Load the Chatterbox TTS model into the module-level ``model`` global.

    Args:
        device: torch device string to load onto; when ``None``, the best
            available device is chosen via :func:`auto_detect_device`.
    """
    global model

    target_device = auto_detect_device() if device is None else device
    model = ChatterboxTTS.from_pretrained(device=target_device)


def main():
    """CLI entry point: parse arguments, load the model, start uvicorn.

    Blocks until the server is stopped. Note: the unused ``global model``
    declaration was removed — main() never assigns ``model``; load_model()
    manages that global itself.
    """
    parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
    parser.add_argument(
        "--port", "-p", type=int, default=7861, help="Port to run server on"
    )
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()

    print(f"πŸŽ™οΈ Starting Chatterbox TTS Server on {args.host}:{args.port}")
    print(f"🌐 Health check: http://localhost:{args.port}/")
    print(f"πŸ”Š TTS endpoint: http://localhost:{args.port}/api/v1/")

    # Load the model before binding the port so the first request can be
    # served immediately instead of hitting the 503 "not loaded" path.
    load_model()

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()