File size: 5,767 Bytes
3b32b80
3dd9c50
 
63cf9bb
 
 
 
 
26c5cf5
63cf9bb
 
 
3dd9c50
63cf9bb
3b32b80
63cf9bb
cc02ee6
63cf9bb
cc02ee6
63cf9bb
 
 
 
cc02ee6
63cf9bb
 
 
 
 
 
 
 
 
 
 
 
 
 
26c5cf5
63cf9bb
 
 
 
3b32b80
63cf9bb
3b32b80
 
 
63cf9bb
 
3b32b80
 
63cf9bb
 
 
 
 
3b32b80
63cf9bb
3b32b80
 
63cf9bb
3b32b80
 
 
63cf9bb
 
26c5cf5
3b32b80
 
63cf9bb
 
 
3b32b80
63cf9bb
 
 
3b32b80
63cf9bb
cc02ee6
63cf9bb
 
 
3b32b80
63cf9bb
 
 
 
3b32b80
63cf9bb
 
3b32b80
63cf9bb
 
3b32b80
 
 
 
63cf9bb
3b32b80
 
63cf9bb
 
 
3b32b80
63cf9bb
3b32b80
63cf9bb
 
 
3b32b80
63cf9bb
3b32b80
63cf9bb
3b32b80
 
 
 
63cf9bb
 
3b32b80
63cf9bb
 
3b32b80
3dd9c50
 
 
 
63cf9bb
3dd9c50
63cf9bb
 
 
3dd9c50
 
63cf9bb
 
 
 
3dd9c50
63cf9bb
 
3dd9c50
 
63cf9bb
 
 
 
 
 
 
 
 
 
 
 
 
 
3dd9c50
63cf9bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import io
import base64
import json
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, JSONResponse
import soundfile as sf

from neutts_wrapper import NeuTTSWrapper

# --- Configuration & Global Objects ---
# Module-wide logging setup; all service log lines flow through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read device from environment variable, defaulting to 'cpu'
DEVICE = os.getenv("MODEL_DEVICE", "cpu")
# Use a ThreadPoolExecutor to run blocking ML code in a separate thread
# NOTE(review): max_workers=1 serializes all inference calls — presumably
# because the model is not thread-safe; confirm before raising this.
tts_executor = ThreadPoolExecutor(max_workers=1)

# --- Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Startup/shutdown hook for the FastAPI application.

    On startup, construct the NeuTTS wrapper and stash it on ``app.state``.
    A load failure is logged and leaves the state as ``None`` so the health
    endpoint can report a degraded service instead of the process dying.
    On shutdown, drain the inference thread pool.
    """
    logger.info("Application startup...")
    try:
        wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"FATAL: Model could not be loaded. {e}")
        wrapper = None
    app.state.tts_wrapper = wrapper

    yield  # serve requests while suspended here

    logger.info("Application shutdown...")
    tts_executor.shutdown(wait=True)

# --- FastAPI App Initialization ---
# `lifespan` wires model loading/unloading into app startup and shutdown.
app = FastAPI(
    title="NeuTTS Air Production API",
    description="Production-ready Text-to-Speech with Voice Cloning",
    version="2.0.0",
    lifespan=lifespan
)

# --- Helper function for running blocking code ---
async def run_in_executor(func, *args):
    """Run blocking *func* in the TTS thread pool without blocking the event loop.

    Args:
        func: Blocking callable (e.g. model inference).
        *args: Positional arguments forwarded to *func*.

    Returns:
        Whatever *func* returns, awaited from the worker thread.
    """
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tts_executor, func, *args)

# --- API Endpoints ---
@app.get("/")
async def root():
    """Liveness probe: confirm the API process is up and responding."""
    payload = {
        "status": "online",
        "service": "NeuTTS Air API v2",
    }
    return payload

@app.get("/health")
async def health_check():
    """Report service health, model load state, and the configured device."""
    if app.state.tts_wrapper:
        model_status = "loaded"
    else:
        model_status = "degraded"
    return {"status": "healthy", "model_status": model_status, "device": DEVICE}

@app.post("/api/v1/synthesize")
async def synthesize_speech(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Clone the voice in ``ref_audio`` and speak ``gen_text``.

    Args:
        ref_text: Transcript of the reference audio clip.
        gen_text: Text to synthesize in the cloned voice.
        ref_audio: Uploaded reference audio file.

    Returns:
        Raw WAV bytes as an ``audio/wav`` response.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on synthesis failure.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    try:
        ref_audio_bytes = await ref_audio.read()

        # Blocking ML work runs in the dedicated thread pool so the
        # event loop stays responsive.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)

        # Encode to WAV entirely in memory.
        # NOTE(review): 24000 Hz assumed to match the model's output rate —
        # confirm against NeuTTSWrapper.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')

        return Response(content=buffer.getvalue(), media_type="audio/wav")

    except Exception as e:
        # logger.exception preserves the traceback; plain error() did not.
        logger.exception("Synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e

@app.post("/api/v1/synthesize/b64")
async def synthesize_speech_base64(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Clone the voice in ``ref_audio``, speak ``gen_text``, return base64 JSON.

    Args:
        ref_text: Transcript of the reference audio clip.
        gen_text: Text to synthesize in the cloned voice.
        ref_audio: Uploaded reference audio file.

    Returns:
        JSON body ``{"audio_data": <base64 WAV>, "format": "wav"}``.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 on synthesis failure.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    try:
        ref_audio_bytes = await ref_audio.read()

        # Blocking ML work runs in the dedicated thread pool so the
        # event loop stays responsive.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)

        # Encode to WAV entirely in memory.
        # NOTE(review): 24000 Hz assumed to match the model's output rate —
        # confirm against NeuTTSWrapper.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')

        audio_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        return JSONResponse({"audio_data": audio_b64, "format": "wav"})

    except Exception as e:
        # logger.exception preserves the traceback; plain error() did not.
        logger.exception("Base64 synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Base64 synthesis failed: {str(e)}") from e

@app.post("/api/v1/batch-synthesize")
async def batch_synthesize(
    ref_text: str = Form(...),
    ref_audio: UploadFile = File(...),
    texts: str = Form(...)
):
    """Synthesize many texts with one cloned voice, encoding the reference once.

    Args:
        ref_text: Transcript of the reference audio clip.
        ref_audio: Uploaded reference audio file.
        texts: JSON-encoded array of strings to synthesize.

    Returns:
        JSON body ``{"generated_clips": [{"text": ..., "audio_data": <base64 WAV>}, ...]}``.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 on malformed
            input, 500 on synthesis failure.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # Validate input BEFORE the broad try/except: previously the
    # non-list ValueError was swallowed by `except Exception` and
    # surfaced as a 500 although it is a client error (should be 400).
    try:
        text_list = json.loads(texts)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in 'texts' field.")
    if not isinstance(text_list, list):
        raise HTTPException(status_code=400, detail="Texts must be a JSON array of strings.")

    try:
        ref_audio_bytes = await ref_audio.read()

        # Encode the reference once, in the thread pool; it is reused
        # for every text in the batch.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)

        results = []
        for text in text_list:
            # Inference for each text (sequential: the pool has one worker).
            wav_data = await run_in_executor(app.state.tts_wrapper.infer, text, ref_codes, ref_text)

            # NOTE(review): 24000 Hz assumed to match the model's output
            # rate — confirm against NeuTTSWrapper.
            buffer = io.BytesIO()
            sf.write(buffer, wav_data, 24000, format='WAV')
            audio_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            results.append({"text": text, "audio_data": audio_b64})

        return JSONResponse({"generated_clips": results})

    except Exception as e:
        # logger.exception preserves the traceback; plain error() did not.
        logger.exception("Batch synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch synthesis failed: {str(e)}") from e