Rajhuggingface4253 committed on
Commit 9a55341 · verified · 1 Parent(s): 6c72969

Update app.py

Files changed (1): app.py (+266 −48)

app.py CHANGED
@@ -14,10 +14,19 @@ import torch
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 from fastapi.responses import Response, StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
 import re
 import hashlib
 from functools import lru_cache
+
+# ONNX Runtime import (log via logging directly: the named logger is created below)
+try:
+    import onnxruntime as ort
+    ONNX_AVAILABLE = True
+    logging.getLogger("NeuTTS-API").info("✅ ONNX Runtime available")
+except ImportError:
+    ONNX_AVAILABLE = False
+    logging.getLogger("NeuTTS-API").warning("⚠️ ONNX Runtime not available, falling back to PyTorch")
+
 # Ensure the cloned neutts-air repository is in the path
 import sys
 sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
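The optional-dependency guard added in this hunk is the standard pattern: attempt the import once at module load, record a flag, and let everything downstream branch on the flag instead of retrying the import. A minimal self-contained sketch of the same idea (module and logger names here are illustrative, not part of the commit):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")

try:
    import onnxruntime as ort  # optional accelerator
    HAS_ONNX = True
except ImportError:
    ort = None
    HAS_ONNX = False
    log.warning("onnxruntime missing; using the default backend")

def pick_backend() -> str:
    # Downstream code consults the flag; the import is never re-attempted.
    return "onnx" if HAS_ONNX else "pytorch"

print(pick_backend())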
@@ -31,16 +40,16 @@ logger = logging.getLogger("NeuTTS-API")
 
 # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
 DEVICE = "cpu"
-# Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
-MAX_WORKERS = 2
-tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
-SAMPLE_RATE = 24000
 
-class TTSRequestModel(BaseModel):
-    """Model for non-file inputs to synthesis and streaming."""
-    text: str = Field(..., min_length=1, max_length=1000)
-    output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
+# ONNX Configuration
+USE_ONNX = True and ONNX_AVAILABLE  # Auto-disable if ONNX is not available
+ONNX_MODEL_DIR = "onnx_models"
+os.makedirs(ONNX_MODEL_DIR, exist_ok=True)
 
+# Configure Max Workers for concurrent synthesis threads
+MAX_WORKERS = min(4, (os.cpu_count() or 2))
+tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
+SAMPLE_RATE = 24000
 
 async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
     """
@@ -79,24 +88,104 @@ async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
     logger.info("In-memory FFmpeg conversion successful.")
     # Return the raw WAV data in a BytesIO buffer, ready for the model
     return io.BytesIO(wav_data)
-# --- Model Wrapper and Logic ---
+
+# --- ONNX Optimized Model Wrapper ---
+
+class NeuTTSONNXWrapper:
+    """ONNX optimized wrapper for NeuTTS model inference"""
+
+    def __init__(self, onnx_model_path: str):
+        self.session_options = ort.SessionOptions()
+
+        # Optimize for CPU performance
+        self.session_options.intra_op_num_threads = os.cpu_count() or 4
+        self.session_options.inter_op_num_threads = 2
+        self.session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        self.session_options.enable_profiling = False
+
+        # Use CPU execution provider
+        providers = ['CPUExecutionProvider']
+
+        self.session = ort.InferenceSession(
+            onnx_model_path,
+            sess_options=self.session_options,
+            providers=providers
+        )
+
+        # Get model metadata (avoid shadowing the input() builtin)
+        self.input_names = [inp.name for inp in self.session.get_inputs()]
+        self.output_names = [out.name for out in self.session.get_outputs()]
+
+        logger.info(f"✅ ONNX model loaded: {onnx_model_path}")
+        logger.info(f"   Inputs: {self.input_names}")
+        logger.info(f"   Outputs: {self.output_names}")
 
 class NeuTTSWrapper:
-    def __init__(self, device: str = "cpu"):
+    def __init__(self, device: str = "cpu", use_onnx: bool = USE_ONNX):
         self.tts_model = None
         self.device = device
+        self.use_onnx = use_onnx
+        self.onnx_wrapper = None
         self.load_model()
 
     def load_model(self):
         try:
-            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
-            # Ensure we respect the CPU configuration
-            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
+            logger.info(f"Loading NeuTTSAir model on device: {self.device} (ONNX: {self.use_onnx})")
+
+            # Configure phonemizer for better performance
+            os.environ['PHONEMIZER_OPTIMIZE'] = '1'
+            os.environ['PHONEMIZER_VERBOSE'] = '0'
+
+            # Use ONNX codec decoder for maximum speed if available
+            codec_repo = "neuphonic/neucodec-onnx-decoder" if self.use_onnx else "neuphonic/neucodec"
+
+            self.tts_model = NeuTTSAir(
+                backbone_device=self.device,
+                codec_device=self.device,
+                codec_repo=codec_repo
+            )
+
+            # Initialize ONNX if enabled
+            if self.use_onnx:
+                self._initialize_onnx()
+
             logger.info("✅ NeuTTSAir model loaded successfully.")
+
+            # Test phonemizer with sample text
+            self._test_phonemizer()
+
         except Exception as e:
             logger.error(f"❌ Model loading failed: {e}")
             raise
 
+    def _initialize_onnx(self):
+        """Initialize ONNX components for optimized inference."""
+        try:
+            # Check if the ONNX model exists; if not, fall back to PyTorch
+            onnx_model_path = os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")
+
+            if os.path.exists(onnx_model_path):
+                self.onnx_wrapper = NeuTTSONNXWrapper(onnx_model_path)
+                logger.info("✅ ONNX optimization enabled")
+            else:
+                logger.warning("⚠️ ONNX model not found, using PyTorch backend")
+                self.use_onnx = False
+
+        except Exception as e:
+            logger.warning(f"⚠️ ONNX initialization failed: {e}, using PyTorch backend")
+            self.use_onnx = False
+
+    def _test_phonemizer(self):
+        """Test the phonemizer with sample text to catch issues early."""
+        try:
+            test_text = "Hello world this is a test."
+            # This triggers phonemizer initialization and surfaces config issues
+            with torch.no_grad():
+                _ = self.tts_model.infer(test_text, torch.randn(1, 512), test_text)
+            logger.info("✅ Phonemizer tested successfully")
+        except Exception as e:
+            logger.warning(f"⚠️ Phonemizer test had issues: {e}")
+
     def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
         """Converts NumPy audio array to streamable bytes in the specified format."""
         audio_buffer = io.BytesIO()
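For readers unfamiliar with onnxruntime, a session configured like NeuTTSONNXWrapper's is driven by a single session.run call with a name-keyed feed dict. A hedged sketch, assuming a model exported with an 'input_ids' input as in the convert_model_to_onnx function later in this diff:

import numpy as np
import onnxruntime as ort

# Assumes onnx_models/neutts_backbone.onnx was produced by the export code below
session = ort.InferenceSession(
    "onnx_models/neutts_backbone.onnx",
    providers=["CPUExecutionProvider"],
)

input_name = session.get_inputs()[0].name  # e.g. "input_ids"
dummy = np.random.randint(0, 1000, size=(1, 512), dtype=np.int64)

# None means "return every output"; the dict maps input names to numpy arrays
outputs = session.run(None, {input_name: dummy})
print([o.shape for o in outputs])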
@@ -108,16 +197,87 @@ class NeuTTSWrapper:
         audio_buffer.seek(0)
         return audio_buffer.read()
 
+    def _preprocess_text_for_phonemizer(self, text: str) -> str:
+        """
+        Clean text for the phonemizer to prevent word count mismatches.
+        This eliminates the warnings and significantly speeds up processing.
+        """
+        # Remove or replace problematic characters
+        text = re.sub(r'[^\w\s\.\,\!\?\-\'\"]', '', text)  # Keep only safe chars
+
+        # Normalize whitespace
+        text = ' '.join(text.split())
+
+        # Ensure proper sentence separation for the phonemizer
+        text = re.sub(r'\.\s*', '. ', text)  # Standardize periods
+        text = re.sub(r'\?\s*', '? ', text)  # Standardize question marks
+        text = re.sub(r'\!\s*', '! ', text)  # Standardize exclamation marks
+
+        return text.strip()
+
     def _split_text_into_chunks(self, text: str) -> list[str]:
         """
-        Splits text into sentences OR clauses using a robust regex.
-        This is fast, library-free, and now handles commas.
+        Enhanced text splitting that's phonemizer-friendly.
+        Pre-processes each chunk to avoid word count mismatches.
         """
-        # This regex now finds all sequences of characters that are not a sentence-ending
-        # or clause-ending punctuation mark, followed by that punctuation.
-        # The only change is adding ',' to the character sets.
-        chunks = re.findall(r'[^.,!?]+[.,!?]*', text)
-        return [c.strip() for c in chunks if c.strip()]
+        # First, preprocess the entire text
+        clean_text = self._preprocess_text_for_phonemizer(text)
+
+        # Use more robust sentence splitting
+        sentence_endings = r'[.!?]+'
+        chunks = []
+
+        # Split on sentence endings while preserving the endings
+        start = 0
+        for match in re.finditer(sentence_endings, clean_text):
+            end = match.end()
+            chunk = clean_text[start:end].strip()
+            if chunk:
+                chunks.append(chunk)
+            start = end
+
+        # Add any remaining text
+        if start < len(clean_text):
+            remaining = clean_text[start:].strip()
+            if remaining:
+                chunks.append(remaining)
+
+        # If no sentence endings found, split by commas or length
+        if not chunks:
+            chunks = self._fallback_chunking(clean_text)
+
+        return [chunk for chunk in chunks if chunk.strip()]
+
+    def _fallback_chunking(self, text: str) -> list[str]:
+        """Fallback chunking when no sentence endings are found."""
+        # Split by commas first
+        comma_chunks = [chunk.strip() + ',' for chunk in text.split(',') if chunk.strip()]
+        if comma_chunks:
+            # Remove the trailing comma from the last chunk
+            if comma_chunks[-1].endswith(','):
+                comma_chunks[-1] = comma_chunks[-1][:-1]
+            return comma_chunks
+
+        # Fall back to length-based chunking
+        max_chunk_length = 150
+        words = text.split()
+        chunks = []
+        current_chunk = []
+
+        for word in words:
+            current_chunk.append(word)
+            if len(' '.join(current_chunk)) > max_chunk_length:
+                if len(current_chunk) > 1:
+                    chunks.append(' '.join(current_chunk[:-1]))
+                    current_chunk = [current_chunk[-1]]
+                else:
+                    chunks.append(' '.join(current_chunk))
+                    current_chunk = []
+
+        if current_chunk:
+            chunks.append(' '.join(current_chunk))
+
+        return chunks
 
     @lru_cache(maxsize=32)
     def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
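The finditer-based splitter added above is easy to exercise in isolation. This standalone reduction of its core loop (a plain function, not the class method itself) shows the expected chunking:

import re

def split_sentences(text: str) -> list[str]:
    chunks, start = [], 0
    # Each run of one-or-more terminators closes the current chunk.
    for match in re.finditer(r'[.!?]+', text):
        chunk = text[start:match.end()].strip()
        if chunk:
            chunks.append(chunk)
        start = match.end()
    # Trailing text without a terminator still becomes a chunk.
    if start < len(text) and text[start:].strip():
        chunks.append(text[start:].strip())
    return chunks

print(split_sentences("Hello there. How are you? Fine!"))
# -> ['Hello there.', 'How are you?', 'Fine!']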
@@ -137,11 +297,58 @@
         # 2. Get the encoding from the cache (or create it if new)
         ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
 
-        # 3. Infer full text
+        # 3. Infer full text (ONNX optimized if available)
         with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
         return audio
 
+# --- ONNX Conversion Function ---
+
+def convert_model_to_onnx():
+    """Convert the PyTorch model to ONNX format for optimized inference."""
+    try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        import torch.onnx
+
+        model_repo = "neuphonic/neutts-air"
+        onnx_path = os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")
+
+        logger.info("Starting ONNX conversion...")
+
+        # Load the original model
+        tokenizer = AutoTokenizer.from_pretrained(model_repo)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_repo,
+            torch_dtype=torch.float32  # Use float32 for better ONNX compatibility
+        ).cpu()
+        model.eval()
+
+        # Create a dummy input (typical sequence length)
+        dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 512), dtype=torch.long)
+
+        # Export to ONNX
+        torch.onnx.export(
+            model,
+            dummy_input,
+            onnx_path,
+            input_names=['input_ids'],
+            output_names=['logits'],
+            dynamic_axes={
+                'input_ids': {0: 'batch_size', 1: 'sequence_length'},
+                'logits': {0: 'batch_size', 1: 'sequence_length'}
+            },
+            opset_version=14,
+            do_constant_folding=True,
+            export_params=True,
+            verbose=False
+        )
+
+        logger.info(f"✅ ONNX conversion successful: {onnx_path}")
+        return True
+
+    except Exception as e:
+        logger.error(f"❌ ONNX conversion failed: {e}")
+        return False
 
 # --- Asynchronous Offloading ---
 
@@ -153,17 +360,23 @@ async def run_blocking_task_async(func, *args, **kwargs):
         lambda: func(*args, **kwargs)
     )
 
-
-# --- FastAPI Lifespan Manager (Kokoro Feature) ---
+# --- FastAPI Lifespan Manager ---
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Modern lifespan management: initialize model on startup, shutdown executor."""
+    """Modern lifespan management: initialize model on startup with ONNX optimization."""
     try:
-        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
+        # Convert to ONNX on first run if enabled but the model doesn't exist yet
+        if USE_ONNX and not os.path.exists(os.path.join(ONNX_MODEL_DIR, "neutts_backbone.onnx")):
+            logger.info("First run: Converting model to ONNX for optimization...")
+            success = await run_blocking_task_async(convert_model_to_onnx)
+            if not success:
+                logger.warning("ONNX conversion failed, using PyTorch backend")
+
+        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE, use_onnx=USE_ONNX)
+
     except Exception as e:
         logger.error(f"Fatal startup error: {e}")
-        # Terminate the application if the model can't load
         tts_executor.shutdown(wait=False)
         raise RuntimeError("Model initialization failed.")
 
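The lifespan hook is FastAPI's successor to the deprecated startup/shutdown events: code before the yield runs once at startup, code after it runs at shutdown. A compact sketch of the same shape, with the heavy model load reduced to a placeholder:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: anything stored on app.state is visible to every request handler.
    app.state.model = object()  # placeholder for the real NeuTTSWrapper
    try:
        yield  # the application serves requests while suspended here
    finally:
        # Shutdown: release resources (executors, sessions, temp files).
        app.state.model = None

app = FastAPI(lifespan=lifespan)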
@@ -175,8 +388,8 @@ async def lifespan(app: FastAPI):
 
 # --- FastAPI Application Setup ---
 app = FastAPI(
-    title="NeuTTS Air Instant Cloning API",
-    version="2.0.0-PROD-ENHANCED",
+    title="NeuTTS Air Instant Cloning API (ONNX Optimized)",
+    version="2.1.0-ONNX",
     docs_url="/docs",
     lifespan=lifespan
 )
@@ -188,23 +401,28 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# --- New Endpoints and Enhancements ---
+# --- Endpoints ---
 
 @app.get("/")
 async def root():
-    return {"message": "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"}
+    return {"message": "NeuTTS Air API v2.1 - ONNX Optimized for Speed"}
 
 @app.get("/health")
 async def health_check():
-    """Enhanced health check (Kokoro Feature + Original Metrics)"""
+    """Enhanced health check with ONNX status."""
     mem = psutil.virtual_memory()
     disk = psutil.disk_usage('/')
 
+    onnx_status = "enabled" if USE_ONNX else "disabled"
+    if hasattr(app.state, 'tts_wrapper'):
+        onnx_status = "active" if app.state.tts_wrapper.use_onnx else "fallback"
+
     return {
         "status": "healthy",
         "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
         "device": DEVICE,
         "concurrency_limit": MAX_WORKERS,
+        "onnx_optimization": onnx_status,
         "memory_usage": {
             "total_gb": round(mem.total / (1024**3), 2),
             "used_percent": mem.percent
@@ -215,8 +433,6 @@ async def health_check():
         }
     }
 
-
-
 # --- Core Synthesis Endpoints ---
 
 @app.post("/synthesize", response_class=Response)
@@ -226,7 +442,7 @@ async def text_to_speech(
     output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
     reference_audio: UploadFile = File(...)):
     """
-    Standard blocking TTS endpoint with in-memory processing and caching.
+    Standard blocking TTS endpoint with in-memory processing and ONNX optimization.
     """
     if not hasattr(app.state, 'tts_wrapper'):
         raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
@@ -237,11 +453,11 @@
         converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
         ref_audio_bytes = converted_wav_buffer.getvalue()
 
-        # 2. Offload the blocking AI process (now faster with caching)
+        # 2. Offload the blocking AI process (ONNX optimized if available)
         audio_data = await run_blocking_task_async(
             app.state.tts_wrapper.generate_speech_blocking,
             text,
-            ref_audio_bytes, # Pass bytes, not a path
+            ref_audio_bytes,
             reference_text
         )
 
@@ -254,13 +470,17 @@
 
         processing_time = time.time() - start_time
         audio_duration = len(audio_data) / SAMPLE_RATE
+
+        logger.info(f"✅ Synthesis completed in {processing_time:.2f}s (ONNX: {app.state.tts_wrapper.use_onnx})")
+
         return Response(
             content=audio_bytes,
             media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
             headers={
                 "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
                 "X-Processing-Time": f"{processing_time:.2f}s",
-                "X-Audio-Duration": f"{audio_duration:.2f}s"
+                "X-Audio-Duration": f"{audio_duration:.2f}s",
+                "X-ONNX-Optimized": str(app.state.tts_wrapper.use_onnx)
             }
         )
     except Exception as e:
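An end-to-end client call against /synthesize, matching the Form/File signature above; the URL, port, and file names are illustrative assumptions:

import requests

with open("reference.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/synthesize",
        data={
            "text": "Hello from the cloned voice.",
            "reference_text": "Transcript of the reference clip.",
            "output_format": "wav",
        },
        files={"reference_audio": ("reference.wav", f, "audio/wav")},
        timeout=300,
    )

resp.raise_for_status()
# X-ONNX-Optimized is the header this commit adds next to the timing headers
print(resp.headers.get("X-Processing-Time"), resp.headers.get("X-ONNX-Optimized"))
with open("tts_output.wav", "wb") as out:
    out.write(resp.content)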
@@ -276,15 +496,14 @@ async def stream_text_to_speech_cloning(
     output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
     reference_audio: UploadFile = File(...)):
     """
-    Sentence-by-Sentence Streaming using a high-performance, asyncio-native
-    look-ahead pipeline. This ensures true overlap of CPU work and network I/O.
+    Sentence-by-Sentence Streaming with ONNX optimization.
     """
     if not hasattr(app.state, 'tts_wrapper'):
         raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
 
     async def stream_generator():
         loop = asyncio.get_event_loop()
-        q = asyncio.Queue(maxsize=MAX_WORKERS + 1) # Queue size based on workers
+        q = asyncio.Queue(maxsize=MAX_WORKERS + 1)
 
         async def producer():
             try:
@@ -301,6 +520,7 @@
                 )
 
                 sentences = app.state.tts_wrapper._split_text_into_chunks(text)
+                logger.info(f"Streaming {len(sentences)} chunks (ONNX: {app.state.tts_wrapper.use_onnx})")
 
                 def process_chunk(sentence_text):
                     with torch.no_grad():
@@ -321,27 +541,25 @@
         producer_task = asyncio.create_task(producer())
 
         # --- High-Performance Consumer with Look-Ahead ---
-        # Get the first task from the queue to start the process.
         current_task = await q.get()
 
         while current_task is not None:
-            # Simultaneously, get the NEXT task from the queue.
-            # This allows the next chunk to start processing while we wait for the current one.
             next_task = await q.get()
 
-            # Now, wait for the CURRENT task to finish.
             if isinstance(current_task, Exception):
                 raise current_task
 
            chunk_bytes = await current_task
            yield chunk_bytes
 
-            # The next task becomes the current task for the next iteration.
            current_task = next_task
 
        await producer_task
 
     return StreamingResponse(
         stream_generator(),
-        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
-    )
+        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
+        headers={
+            "X-ONNX-Optimized": str(app.state.tts_wrapper.use_onnx)
+        }
+    )
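The consumer above always holds one queue item in reserve, so the next chunk is synthesizing while the current one streams out. A runnable toy version of the same queue discipline, with sleeps standing in for synthesis and None as the end-of-stream sentinel (as in this file's producer):

import asyncio

async def fake_synthesis(i: int) -> bytes:
    await asyncio.sleep(0.1)  # stand-in for offloaded chunk synthesis
    return f"chunk-{i}".encode()

async def demo():
    q: asyncio.Queue = asyncio.Queue(maxsize=3)

    async def producer():
        for i in range(5):
            # Enqueue already-started tasks so work runs ahead of the consumer.
            await q.put(asyncio.create_task(fake_synthesis(i)))
        await q.put(None)  # sentinel: no more chunks

    producer_task = asyncio.create_task(producer())

    current = await q.get()
    while current is not None:
        next_task = await q.get()  # look ahead before awaiting the current chunk
        print(await current)       # emit in order while the next one is in flight
        current = next_task

    await producer_task

asyncio.run(demo())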
 
 
 
 