File size: 3,749 Bytes
5e29ba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.8"
# dependencies = [
#     "fastapi>=0.100.0",
#     "uvicorn[standard]>=0.20.0",
#     "pydantic>=2.0.0",
#     "httpx>=0.25.0",
#     "typer>=0.9.0",
# ]
# ///

"""
Chatterbox TTS Model Server - Mock Implementation
Compatible with HuggingFace InferenceClient text_to_speech API
"""

import argparse
import os
from pathlib import Path
from typing import Optional, Dict, Any

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


class TTSRequest(BaseModel):
    """JSON body accepted by the root ``POST /`` text-to-speech endpoint."""

    # Text to synthesize.
    inputs: str
    # Optional free-form options; the handler only logs them.
    parameters: Optional[Dict[str, Any]] = None


class InferenceClientTTSRequest(BaseModel):
    """JSON body accepted by the ``POST /v1/text-to-speech`` endpoint."""

    # Text to synthesize.
    inputs: str
    # Optional free-form extras; the handler only logs them.
    extra_body: Optional[Dict[str, Any]] = None


# Path of the audio file served for every request. It stays None until
# main() stores the value of the --sample-audio command-line option.
SAMPLE_AUDIO_PATH = None

# Application object that the route decorators below register against.
app = FastAPI(version="1.0.0", title="Chatterbox TTS Server")

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# the FastAPI CORS documentation advises against that pairing - confirm
# whether credentialed cross-origin requests are really needed here.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)


@app.get("/")
async def health_check():
    """Return a fixed status payload naming the model this server mimics."""
    payload = {"status": "ok", "model": "ResembleAI/chatterbox"}
    return payload


@app.post("/")
async def text_to_speech(request: TTSRequest):
    """
    Mock text-to-speech handler for ``POST /``.

    Logs the first 50 characters of the requested text together with the
    supplied parameters, then answers with the configured sample audio
    file regardless of the input.

    Raises:
        HTTPException: status 500 when no sample audio path is configured
            or the configured path does not exist.
    """
    audio_path = SAMPLE_AUDIO_PATH
    sample_available = bool(audio_path) and os.path.exists(audio_path)
    if not sample_available:
        raise HTTPException(
            detail="Sample audio file not found. Please provide --sample-audio path.",
            status_code=500,
        )

    # Only a short preview of the text is logged.
    preview = request.inputs[:50]
    print(f"TTS Request - Text: '{preview}...' Parameters: {request.parameters}")

    # The response is always the same file, labelled as WAV audio.
    response = FileResponse(
        audio_path,
        filename="generated_audio.wav",
        media_type="audio/wav",
    )
    return response


@app.post("/v1/text-to-speech")
async def inference_client_text_to_speech(request: InferenceClientTTSRequest):
    """
    Mock text-to-speech handler for ``POST /v1/text-to-speech``.

    Logs the first 50 characters of the requested text together with the
    supplied ``extra_body``, then answers with the configured sample audio
    file regardless of the input.

    Raises:
        HTTPException: status 500 when no sample audio path is configured
            or the configured path does not exist.
    """
    audio_path = SAMPLE_AUDIO_PATH
    sample_available = bool(audio_path) and os.path.exists(audio_path)
    if not sample_available:
        raise HTTPException(
            detail="Sample audio file not found. Please provide --sample-audio path.",
            status_code=500,
        )

    # Only a short preview of the text is logged.
    preview = request.inputs[:50]
    print(f"InferenceClient TTS Request - Text: '{preview}...' Extra body: {request.extra_body}")

    # The response is always the same file, labelled as WAV audio.
    response = FileResponse(
        audio_path,
        filename="generated_audio.wav",
        media_type="audio/wav",
    )
    return response


def main():
    """Parse command-line options and start the mock TTS server.

    Options:
        --port / -p: TCP port to listen on (default 7860).
        --host: interface to bind to (default "0.0.0.0").
        --sample-audio: required path of the audio file returned for every
            request; stored in the module-level ``SAMPLE_AUDIO_PATH``.

    Raises:
        SystemExit: with status 1 when ``--sample-audio`` does not point to
            an existing regular file.
    """
    global SAMPLE_AUDIO_PATH

    # Local import: only this entry point needs sys, and the interactive
    # ``exit`` helper is not guaranteed to exist (e.g. under ``python -S``).
    import sys

    parser = argparse.ArgumentParser(description="Start Chatterbox TTS Server")
    parser.add_argument("--port", "-p", type=int, default=7860, help="Port to run server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--sample-audio", required=True, help="Path to sample audio file to return")

    args = parser.parse_args()

    # Require a regular file: a directory would pass a plain existence check
    # but could not be served by the request handlers.
    if not os.path.isfile(args.sample_audio):
        print(f"Error: Sample audio file not found: {args.sample_audio}")
        sys.exit(1)

    SAMPLE_AUDIO_PATH = args.sample_audio

    # NOTE(review): the banner glyphs were mis-encoded in the original text;
    # the microphone and globe were recovered from their byte sequences, the
    # folder icon is a best guess - confirm against the upstream file.
    print(f"🎙️ Starting Chatterbox TTS Server on {args.host}:{args.port}")
    print(f"📁 Using sample audio: {args.sample_audio}")
    print(f"🌐 API endpoint: http://localhost:{args.port}/")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info"
    )


# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()