File size: 9,639 Bytes
a1eb108
 
37a7804
6caac4d
 
37a7804
c63a379
7e59013
b01e4fa
a1eb108
 
6caac4d
7e59013
6caac4d
a1eb108
 
6caac4d
a1eb108
ac0fe7c
6caac4d
a1eb108
 
 
c63a379
 
 
b307da8
a1eb108
c63a379
a1eb108
c63a379
a1eb108
 
 
 
308a219
a1eb108
c63a379
 
 
a1eb108
6caac4d
5d68bda
a1eb108
5d68bda
a1eb108
 
c63a379
5d68bda
a1eb108
 
5d68bda
 
 
 
 
a1eb108
 
5d68bda
a1eb108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8d6527
c63a379
a1eb108
 
6caac4d
7e59013
6caac4d
7e59013
a1eb108
 
 
 
 
 
 
 
 
 
 
 
7e59013
a1eb108
6caac4d
c2ab408
a1eb108
 
 
c63a379
a1eb108
5d68bda
a1eb108
7e59013
 
 
a1eb108
c63a379
a1eb108
c63a379
a1eb108
 
 
 
 
c63a379
a1eb108
7e59013
c63a379
a1eb108
 
c63a379
 
 
a1eb108
c63a379
 
a1eb108
c63a379
 
 
a1eb108
6caac4d
c63a379
 
 
7e59013
c63a379
 
a1eb108
 
 
 
7e59013
37a7804
6caac4d
c63a379
6caac4d
4932f88
c63a379
a1eb108
 
 
c63a379
 
78223d3
 
a1eb108
 
 
 
 
78223d3
a1eb108
 
c63a379
a1eb108
78223d3
a1eb108
 
c63a379
a1eb108
c63a379
6caac4d
 
 
a1eb108
7e59013
 
a1eb108
 
 
78223d3
7e59013
c63a379
a1eb108
4932f88
c63a379
a1eb108
 
 
 
 
 
 
 
 
 
 
 
af37689
e94e39e
a1eb108
e94e39e
a1eb108
faa5294
a1eb108
faa5294
a1eb108
 
 
 
 
 
faa5294
a1eb108
 
faa5294
a1eb108
faa5294
a1eb108
 
faa5294
a1eb108
faa5294
a1eb108
 
faa5294
a1eb108
 
 
 
e94e39e
 
 
faa5294
a1eb108
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# app.py

import os
import io
import asyncio
import time
import psutil
import soundfile as sf
import subprocess
import numpy as np
import librosa # Needed for monkey-patching
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
import logging
from types import MethodType

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# This will now work because the Dockerfile clones the repo
# and we add it to the path
import sys
sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
from neuttsair.neutts import NeuTTSAir

# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-GGUF-API")

# Production-ready configuration via Environment Variables
BACKBONE_MODEL_PATH = os.getenv("BACKBONE_MODEL_PATH", "/app/models/neutts-air.gguf")
CODEC_REPO = os.getenv("CODEC_REPO", "neuphonic/neucodec-onnx-decoder") # Using ONNX for performance
DEVICE = "cpu" # llama-cpp handles its own device (CPU/GPU) management

# Size of the thread pool that runs blocking model calls; keep small —
# each worker may hold a full inference in flight.
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Output sample rate of the NeuTTS codec (Hz); used when serializing audio.
SAMPLE_RATE = 24000

# --- Core Utility Functions ---

async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """Convert an uploaded audio file to 16 kHz mono PCM WAV, entirely in memory.

    ffmpeg reads the raw upload from stdin and writes the converted WAV to
    stdout, so no temporary files are created.

    Args:
        upload_file: The client-supplied reference audio (any ffmpeg-readable format).

    Returns:
        A BytesIO containing a 16 kHz, mono, 16-bit PCM WAV.

    Raises:
        HTTPException: 400 if ffmpeg fails to convert the input.
    """
    ffmpeg_command = [
        "ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000",
        "-ac", "1", "-c:a", "pcm_s16le", "pipe:1"
    ]
    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_command, stdin=subprocess.PIPE,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    wav_data, stderr_data = await proc.communicate(input=await upload_file.read())
    if proc.returncode != 0:
        # errors="replace": ffmpeg stderr is not guaranteed to be valid UTF-8.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # BUG FIX: the last stderr line is usually the real error, but stderr
        # can be empty — guard so we don't raise IndexError while reporting.
        error_lines = error_message.strip().splitlines()
        error_detail = error_lines[-1] if error_lines else "unknown ffmpeg error"
        raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")
    return io.BytesIO(wav_data)

async def run_blocking_task_async(func, *args, **kwargs):
    """Offload a blocking function call to the shared ThreadPoolExecutor.

    Keeps the event loop responsive while model inference or audio
    serialization runs in a worker thread.

    Args:
        func: The blocking callable to run.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    # BUG FIX: asyncio.get_event_loop() is deprecated inside coroutines since
    # Python 3.10; get_running_loop() is the correct call here and fails fast
    # if no loop is running.
    loop = asyncio.get_running_loop()
    # Lambda binds args/kwargs now, since run_in_executor only takes positionals.
    return await loop.run_in_executor(tts_executor, lambda: func(*args, **kwargs))

# --- Model Wrapper and Professional Integration ---

def _encode_reference_from_memory(self, ref_audio: io.BytesIO):
    """Encode an in-memory reference clip into codec tokens.

    Drop-in replacement for the library's file-path based encode_reference:
    it accepts a BytesIO stream directly, avoiding a temp-file round trip
    in the API hot path.
    """
    # librosa accepts file-like objects, so the BytesIO is decoded directly
    # at the 16 kHz mono rate the codec encoder expects.
    samples, _sr = librosa.load(ref_audio, sr=16000, mono=True)
    # Add batch and channel dims -> shape (1, 1, T).
    audio_tensor = torch.from_numpy(samples).float()[None, None, :]
    with torch.no_grad():
        codes = self.codec.encode_code(audio_or_path=audio_tensor)
    # Strip the batch/channel dims back off before returning.
    return codes.squeeze(0).squeeze(0)

class NeuTTSWrapper:
    """Owns the NeuTTSAir GGUF model instance and audio serialization helpers."""

    def __init__(self):
        # Populated by load_model(); stays None only if loading raises.
        self.tts_model: NeuTTSAir | None = None
        self.load_model()

    def load_model(self):
        """Instantiate the GGUF backbone + codec and patch in the in-memory encoder.

        Raises:
            Exception: re-raised after logging if model construction fails.
        """
        try:
            logger.info(f"Loading NeuTTSAir GGUF model from: {BACKBONE_MODEL_PATH}")
            self.tts_model = NeuTTSAir(
                backbone_repo=BACKBONE_MODEL_PATH,
                codec_repo=CODEC_REPO,
                backbone_device=DEVICE,
                codec_device=DEVICE
            )
            # Monkey-patch: bind our BytesIO-based encoder in place of the
            # library's file-path based encode_reference, so we can adapt the
            # library without modifying its source.
            self.tts_model.encode_reference = MethodType(_encode_reference_from_memory, self.tts_model)
            logger.info("✅ NeuTTSAir GGUF model loaded and patched successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}", exc_info=True)
            raise

    def convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Serialize a NumPy audio array to encoded bytes in the given format."""
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, SAMPLE_RATE, format=audio_format)
        return buffer.getvalue()

# --- FastAPI Application Setup ---

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initializes the model on startup and shuts down the executor.

    Code before ``yield`` runs at startup; code after runs at shutdown.
    """
    try:
        # Store the wrapper on app.state so endpoints can reach the model.
        app.state.tts_wrapper = NeuTTSWrapper()
    except Exception as e:
        logger.error(f"Fatal startup error: Model could not be loaded. {e}")
        # Properly handle shutdown if model loading fails
        tts_executor.shutdown(wait=False, cancel_futures=True)
        raise RuntimeError("Model initialization failed. Application cannot start.") from e
    yield
    # wait=True lets any in-flight synthesis finish before the process exits.
    logger.info("Shutting down ThreadPoolExecutor.")
    tts_executor.shutdown(wait=True)

# FastAPI application; `lifespan` handles model load/teardown.
app = FastAPI(
    title="NeuTTS Air GGUF Cloning API",
    version="3.0.0-PROD-GGUF",
    lifespan=lifespan
)
# NOTE(review): wide-open CORS (any origin/method/header). Fine behind a
# trusted gateway, but tighten allow_origins before public exposure.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)

# --- API Endpoints ---

@app.get("/")
async def root():
    """Liveness banner for the API root."""
    banner = {"message": "NeuTTS Air GGUF API - Ready for High-Speed Voice Cloning"}
    return banner

@app.get("/health")
async def health_check():
    """Report service status, model availability, and host memory pressure."""
    wrapper = getattr(app.state, 'tts_wrapper', None)
    memory = psutil.virtual_memory()
    return {
        "status": "healthy",
        # True only when startup completed and the model object exists.
        "model_loaded": wrapper is not None and wrapper.tts_model is not None,
        "model_type": "GGUF",
        "backbone_path": BACKBONE_MODEL_PATH,
        "codec_repo": CODEC_REPO,
        "memory_usage_percent": memory.percent,
    }

@app.post("/synthesize", response_class=Response)
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """Standard blocking TTS endpoint optimized for GGUF.

    Pipeline: convert reference upload -> encode reference codes -> run
    inference -> serialize to the requested container. Each blocking stage
    is offloaded to the thread pool so the event loop stays responsive.

    Raises:
        HTTPException: 400 for a bad reference upload, 500 for internal errors.
    """
    start_time = time.time()
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)

        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )

        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.infer,
            text, ref_codes, reference_text
        )

        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper.convert_to_streamable_format,
            audio_data, output_format
        )

        processing_time = time.time() - start_time
        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={"X-Processing-Time": f"{processing_time:.2f}s"}
        )
    except HTTPException:
        # BUG FIX: previously the generic handler below caught the 400 raised
        # by convert_to_wav_in_memory and re-raised it as a 500 whose detail
        # was str(e). Propagate HTTPExceptions unchanged so clients see the
        # intended status code and detail.
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal error occurred during synthesis.")

@app.post("/synthesize/stream")
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """High-performance, sentence-by-sentence streaming using the GGUF backend.

    The reference audio is converted and encoded up front; chunks are then
    produced by the model's blocking generator in a worker thread and relayed
    to the client through an asyncio.Queue.

    Raises:
        HTTPException: 400 for a bad reference upload, 500 if pre-processing fails.
    """
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )
    except HTTPException:
        # BUG FIX: keep the original status (e.g. 400 for an unconvertible
        # upload) instead of masking it as a 500 in the handler below.
        raise
    except Exception as e:
        logger.error(f"Error during pre-processing for stream: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare reference audio for streaming.")

    async def stream_generator():
        """Bridge the model's blocking chunk generator into an async byte stream."""
        # get_running_loop() replaces the deprecated get_event_loop() inside
        # coroutines (deprecated since Python 3.10).
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def producer():
            # Runs inside the thread pool: this loop blocks the worker thread,
            # never the event loop.
            try:
                for audio_chunk in app.state.tts_wrapper.tts_model.infer_stream(text, ref_codes, reference_text):
                    # Convert each chunk to the requested format in this thread.
                    chunk_bytes = app.state.tts_wrapper.convert_to_streamable_format(audio_chunk, output_format)
                    # Queue operations must be marshalled back onto the loop thread.
                    loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes)
            except Exception as e:
                logger.error(f"Error in streaming producer thread: {e}", exc_info=True)
                # Forward the exception so the consumer fails loudly instead of hanging.
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

        # Start the blocking producer in the thread pool.
        producer_task = loop.run_in_executor(tts_executor, producer)

        # Consumer side: runs on the event loop, yielding chunks as they arrive.
        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
        await producer_task  # surface any late producer failure

    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
    )