Rajhuggingface4253 committed on
Commit
c63a379
·
verified ·
1 Parent(s): 308a219

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -240
app.py CHANGED
@@ -2,332 +2,454 @@ import os
2
  import io
3
  import asyncio
4
  import time
 
5
  import numpy as np
 
6
  import soundfile as sf
7
  import subprocess
8
  import tempfile
9
  from concurrent.futures import ThreadPoolExecutor
10
- from typing import Optional, AsyncGenerator
11
  from contextlib import asynccontextmanager
12
  import logging
13
- import aiofiles
14
  import torch
15
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
16
  from fastapi.responses import Response, StreamingResponse
17
  from fastapi.middleware.cors import CORSMiddleware
 
18
 
19
# Logging is configured first so every later module-level step can report.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("NeuTTS-Perf")

# Runtime settings: CPU-only inference, one synthesis worker, 24 kHz audio.
DEVICE = "cpu"
MAX_WORKERS = 1
SAMPLE_RATE = 24000

# Scratch directory for uploaded / converted reference audio.
TEMP_AUDIO_DIR = "temp_audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
class HighPerformanceTTSWrapper:
    """Hold one NeuTTSAir model in memory and expose complete and chunked synthesis.

    Reference encodings are cached by the SHA-256 of the reference file's
    bytes, so re-uploading the same audio reuses the encoding. The cache is
    bounded by ``_REF_CACHE_LIMIT``.
    """

    # Maximum number of cached reference encodings; the oldest entry is evicted first.
    _REF_CACHE_LIMIT = 16

    def __init__(self, device: str = "cpu"):
        self.tts_model = None
        self.device = device
        self._ref_cache = {}  # content digest -> encoded reference
        self.load_model()

    def load_model(self):
        """Load model once and keep in memory.

        NOTE(review): ``NeuTTSAir`` must be importable at module level; the
        import is not visible in this part of the file.
        """
        try:
            logger.info("🚀 Loading NeuTTSAir model...")
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ Model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def encode_reference_audio(self, audio_path: str) -> torch.Tensor:
        """Encode reference audio, reusing a cached encoding for identical content.

        The previous key (file size + mtime) could never match because every
        request writes a new temp file, so the cache only grew. Keying on the
        file's digest makes hits possible; eviction keeps memory bounded.
        """
        import hashlib  # local import: not part of the file-level imports

        with open(audio_path, "rb") as fh:
            cache_key = hashlib.sha256(fh.read()).hexdigest()

        if cache_key in self._ref_cache:
            return self._ref_cache[cache_key]

        ref_s = self.tts_model.encode_reference(audio_path)

        # dicts preserve insertion order, so the first key is the oldest entry.
        while len(self._ref_cache) >= self._REF_CACHE_LIMIT:
            self._ref_cache.pop(next(iter(self._ref_cache)))
        self._ref_cache[cache_key] = ref_s
        return ref_s

    def _infer_no_grad(self, text: str, ref_s, reference_text: str) -> np.ndarray:
        """Run one model inference with autograd disabled."""
        with torch.no_grad():
            return self.tts_model.infer(text, ref_s, reference_text)

    def synthesize_complete(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
        """Synthesize the whole text in one inference call and return the audio."""
        start_time = time.time()

        ref_s = self.encode_reference_audio(ref_audio_path)
        audio = self._infer_no_grad(text, ref_s, reference_text)

        logger.info(f"🎯 Complete synthesis: {time.time() - start_time:.2f}s")
        return audio

    async def synthesize_streaming(self, text: str, ref_audio_path: str, reference_text: str) -> AsyncGenerator[np.ndarray, None]:
        """Yield audio chunk by chunk as an async generator.

        This was a plain generator although it is annotated ``AsyncGenerator``
        and consumed with ``async for``. It is now a real async generator, and
        the blocking encode/infer calls run in the event loop's default
        executor so the loop stays responsive between chunks.
        """
        start_time = time.time()
        loop = asyncio.get_running_loop()

        # Encode reference once
        ref_s = await loop.run_in_executor(None, self.encode_reference_audio, ref_audio_path)
        encoding_time = time.time() - start_time
        logger.info(f"🔧 Reference encoded: {encoding_time:.2f}s")

        chunks = self._optimized_text_chunking(text)
        logger.info(f"📝 Split into {len(chunks)} chunks")

        for i, chunk in enumerate(chunks):
            chunk_start = time.time()
            audio_chunk = await loop.run_in_executor(
                None, self._infer_no_grad, chunk, ref_s, reference_text
            )
            chunk_time = time.time() - chunk_start
            logger.info(f"🎵 Chunk {i+1}/{len(chunks)}: {chunk_time:.2f}s")
            yield audio_chunk

        total_time = time.time() - start_time
        logger.info(f"✅ Streaming complete: {total_time:.2f}s")

    def _optimized_text_chunking(self, text: str, max_chars: int = 200) -> list[str]:
        """Group '.'-separated sentences into chunks of at most ``max_chars``.

        Text that already fits is returned as a single chunk. A single
        sentence longer than ``max_chars`` is kept whole, not split further.
        """
        if len(text) <= max_chars:
            return [text]

        sentences = [s.strip() for s in text.split('.') if s.strip()]
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            # +1 accounts for the joining space.
            if len(current_chunk) + len(sentence) + 1 <= max_chars:
                current_chunk += (" " + sentence) if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence

        if current_chunk:
            chunks.append(current_chunk)

        return chunks if chunks else [text]

    def audio_to_bytes(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Serialize audio samples to ``audio_format`` at ``SAMPLE_RATE`` and return the bytes."""
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        return audio_buffer.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
# Process-wide singletons: the loaded model wrapper and the synthesis thread pool.
tts_wrapper = HighPerformanceTTSWrapper(DEVICE)
executor = ThreadPoolExecutor(MAX_WORKERS)
131
 
132
- # FastAPI app with minimal overhead
133
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: the model is loaded at import, so only shutdown work remains.

    The executor is shut down in ``finally`` so worker threads are released
    even if the lifespan context exits through an exception or cancellation.
    """
    try:
        yield  # Model already loaded
    finally:
        executor.shutdown(wait=True)
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
# Application object; CORS is wide open (any origin, method and header).
app = FastAPI(title="NeuTTS High-Performance API", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- # Performance monitoring
142
@app.get("/performance")
async def performance_status():
    """Report whether the model is loaded plus the static runtime settings."""
    model_ready = tts_wrapper.tts_model is not None
    cached_references = len(tts_wrapper._ref_cache)
    report = {
        "status": "operational",
        "model_loaded": model_ready,
        "device": DEVICE,
        "max_workers": MAX_WORKERS,
        "reference_cache_size": cached_references,
    }
    return report
151
 
152
- # High-performance file operations
153
async def save_and_convert_audio(upload_file: UploadFile) -> str:
    """Save an uploaded audio file and convert it to a mono 16-bit PCM WAV.

    Returns the path of the converted WAV inside ``TEMP_AUDIO_DIR``; the
    caller owns that file and must delete it.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status. This was
            previously ignored, so a failed conversion returned the path of
            an empty file.
    """
    # Reserve a unique output path; ffmpeg overwrites it (-y).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir=TEMP_AUDIO_DIR) as tmp:
        temp_wav_path = tmp.name
    temp_upload_path = f"{temp_wav_path}.upload"

    try:
        async with aiofiles.open(temp_upload_path, 'wb') as f:
            content = await upload_file.read()  # Read once
            await f.write(content)

        cmd = [
            "ffmpeg", "-y", "-i", temp_upload_path,
            "-f", "wav", "-ar", str(SAMPLE_RATE), "-ac", "1",
            "-c:a", "pcm_s16le", temp_wav_path
        ]

        process = await asyncio.create_subprocess_exec(
            *cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
        return_code = await process.wait()
        if return_code != 0:
            raise RuntimeError(
                f"ffmpeg exited with status {return_code} while converting reference audio"
            )

        return temp_wav_path

    except Exception:
        # The converted file is useless on any failure; remove it and re-raise as-is.
        if os.path.exists(temp_wav_path):
            os.unlink(temp_wav_path)
        raise
    finally:
        # The raw upload is never needed once conversion has been attempted.
        if os.path.exists(temp_upload_path):
            os.unlink(temp_upload_path)
189
-
190
async def cleanup_file(path: str):
    """Best-effort removal of a temporary file.

    A missing file or empty path is not an error. Other OS-level failures
    are logged instead of being swallowed by a bare ``except``.
    """
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        # Already gone; nothing to do.
        pass
    except OSError as exc:
        logger.warning(f"Could not remove temp file {path}: {exc}")
197
 
198
- # High-performance endpoints
199
@app.post("/synthesize", response_class=Response)
async def synthesize_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("wav"),
    reference_audio: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """Synthesize the full text with the uploaded reference voice and return the audio.

    ``background_tasks`` is kept in the signature for compatibility but is no
    longer used for cleanup. Cleanup was previously queued as a background
    task, and queued tasks appear not to run when the handler raises, so the
    temp file leaked on errors. The file is now removed directly in
    ``finally``.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    start_time = time.time()
    temp_path = None

    try:
        # 1. Save and convert the reference audio.
        temp_path = await save_and_convert_audio(reference_audio)
        process_time = time.time() - start_time
        logger.info(f"📁 Audio processed: {process_time:.2f}s")

        # 2. Run the blocking synthesis on the worker pool.
        audio_data = await asyncio.get_running_loop().run_in_executor(
            executor,
            tts_wrapper.synthesize_complete,
            text, temp_path, reference_text
        )

        # 3. Serialize to the requested container.
        audio_bytes = tts_wrapper.audio_to_bytes(audio_data, output_format)

        total_time = time.time() - start_time
        logger.info(f"✅ Complete request: {total_time:.2f}s")

        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "X-Processing-Time": f"{total_time:.2f}s",
                "X-Audio-Length": f"{len(audio_data)/SAMPLE_RATE:.2f}s"
            }
        )
    except Exception as e:
        logger.error(f"Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # The reference file is only needed during synthesis; remove it on every path.
        if temp_path:
            await cleanup_file(temp_path)
248
 
249
@app.post("/synthesize/stream")
async def stream_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("mp3"),
    reference_audio: UploadFile = File(...)
):
    """Stream synthesized audio chunk by chunk as each text chunk is generated.

    The download filename now uses ``output_format`` as its extension; it was
    hard-coded to ``.mp3`` even for other formats.

    NOTE(review): ``async for`` below requires
    ``tts_wrapper.synthesize_streaming`` to return an async iterator.

    Raises:
        HTTPException: 500 if setup (saving/converting the reference) fails.
    """
    start_time = time.time()
    temp_path = None

    try:
        temp_path = await save_and_convert_audio(reference_audio)
        setup_time = time.time() - start_time
        logger.info(f"🎯 Streaming setup: {setup_time:.2f}s")

        async def generate_stream():
            """Encode each synthesized chunk and yield it; remove the reference file at the end."""
            try:
                first_chunk_sent = False
                chunk_count = 0

                audio_chunks = tts_wrapper.synthesize_streaming(text, temp_path, reference_text)

                async for audio_chunk in audio_chunks:
                    chunk_count += 1

                    chunk_bytes = tts_wrapper.audio_to_bytes(audio_chunk, output_format)

                    # Time-to-first-chunk is the latency figure of interest.
                    if not first_chunk_sent:
                        first_chunk_time = time.time() - start_time
                        logger.info(f"🚀 FIRST CHUNK SENT: {first_chunk_time:.2f}s")
                        first_chunk_sent = True

                    logger.info(f"📦 Yielding chunk {chunk_count} ({len(chunk_bytes)} bytes)")
                    yield chunk_bytes

                total_time = time.time() - start_time
                logger.info(f"🎉 Streaming completed: {total_time:.2f}s, {chunk_count} chunks")

            except Exception as e:
                logger.error(f"Stream error: {e}")
                raise
            finally:
                # The reference file must outlive this handler's return, so it is removed here.
                if temp_path:
                    await cleanup_file(temp_path)

        return StreamingResponse(
            generate_stream(),
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename=stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache",
                "X-Streaming": "true"
            }
        )

    except Exception as e:
        logger.error(f"Stream setup error: {e}")
        if temp_path:
            await cleanup_file(temp_path)
        raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
318
 
319
@app.get("/")
async def root():
    """Landing endpoint: return a static banner identifying the service."""
    banner = "NeuTTS High-Performance API - Optimized for Speed"
    return {"message": banner}
322
-
323
if __name__ == "__main__":
    import uvicorn

    # Single worker on all interfaces, port 7860; access logging is off and
    # only warnings and above are emitted.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 1,
        "loop": "asyncio",
        "access_log": False,
        "log_level": "warning",
    }
    uvicorn.run(app, **server_options)
 
2
  import io
3
  import asyncio
4
  import time
5
+ import shutil
6
  import numpy as np
7
+ import psutil
8
  import soundfile as sf
9
  import subprocess
10
  import tempfile
11
  from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Optional, Generator
13
  from contextlib import asynccontextmanager
14
  import logging
15
+ import aiofiles
16
  import torch
17
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
18
  from fastapi.responses import Response, StreamingResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
+ from pydantic import BaseModel, Field
21
 
22
+ # Ensure the cloned neutts-air repository is in the path
23
+ import sys
24
+ sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
25
+ from neuttsair.neutts import NeuTTSAir
26
+
27
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Utility Functions ---

# Inference runs on CPU only.
DEVICE = "cpu"
# Audio sample rate in Hz used for conversion and output.
SAMPLE_RATE = 24000
# Files older than this many seconds (1 hour) are removed by the cleanup routine.
CLEANUP_THRESHOLD = 60 * 60

# Number of synthesis threads; the pool is shared by all blocking work.
MAX_WORKERS = 2
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

# Working directories, created at import time if missing.
TEMP_AUDIO_DIR = "temp_audio"
GENERATED_AUDIO_DIR = "generated_audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
44
+
45
class TTSRequestModel(BaseModel):
    """Validated non-file inputs shared by the synthesis and streaming requests."""

    # Text to speak: 1 to 1000 characters, required.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
    )
    # Playback speed multiplier, constrained to 0.5x-2.0x.
    speed: float = Field(
        default=1.0,
        ge=0.5,
        le=2.0,
    )
    # Output container; only wav, mp3 or flac pass the pattern.
    output_format: str = Field(
        default="wav",
        pattern="^(wav|mp3|flac)$",
    )
50
+
51
 
52
def convert_to_wav_blocking(input_path: str) -> str:
    """Convert an uploaded audio file to a mono 16-bit PCM WAV with FFmpeg.

    This call blocks on a subprocess, so it should be run off the event loop.

    Args:
        input_path: Path of the uploaded file in any format FFmpeg can read.

    Returns:
        Path of the converted WAV inside ``TEMP_AUDIO_DIR``. The caller owns
        the file and must delete it.

    Raises:
        HTTPException: 400 if FFmpeg rejects the input, 504 if it exceeds the
            30-second timeout, 500 for any other failure. The partial output
            file is removed in every failure case.
    """
    # Reserve a unique output path; the handle is closed so FFmpeg can write to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
        output_path = tmp.name

    logger.info(f"Converting '{os.path.basename(input_path)}' to WAV (24kHz, mono) at {os.path.basename(output_path)}")

    command = [
        "ffmpeg",
        "-y",                     # overwrite the reserved output file
        "-i", input_path,
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),  # resample to the service sample rate
        "-ac", "1",               # mono
        "-c:a", "pcm_s16le",      # uncompressed 16-bit PCM
        output_path
    ]

    def _discard_output() -> None:
        """Remove the reserved output file after a failed conversion."""
        if os.path.exists(output_path):
            os.unlink(output_path)

    try:
        subprocess.run(command, check=True, capture_output=True, text=True, timeout=30)
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg conversion failed: {e.stderr}")
        _discard_output()
        # Surface only the last stderr line; the full output is in the log.
        error_detail = e.stderr.splitlines()[-1] if e.stderr else "Unknown FFmpeg error."
        raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}") from e
    except subprocess.TimeoutExpired as e:
        logger.error("FFmpeg conversion timed out.")
        _discard_output()
        raise HTTPException(status_code=504, detail="Audio conversion timed out after 30 seconds.") from e
    except Exception as e:
        logger.error(f"General conversion error: {e}")
        _discard_output()
        raise HTTPException(status_code=500, detail="An unexpected error occurred during audio conversion.") from e

    logger.info("FFmpeg conversion successful.")
    return output_path
108
+ # --- Model Wrapper and Logic ---
109
 
110
class NeuTTSWrapper:
    """Own a single NeuTTSAir model and expose blocking and sentence-streamed synthesis."""

    def __init__(self, device: str = "cpu"):
        self.tts_model = None
        self.device = device
        self.load_model()

    def load_model(self):
        """Instantiate NeuTTSAir on ``self.device``; log and re-raise on failure."""
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Serialize audio samples to ``audio_format`` at ``SAMPLE_RATE`` and return the bytes.

        Each call writes a complete, self-contained file (with its own header).
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str) -> list[str]:
        """Split text into sentences for streaming.

        Splits on whitespace that follows ``.``, ``!`` or ``?``, so sentence
        punctuation is kept and decimals such as "3.14" stay intact. Splitting
        on every ``.`` used to break both. Text with no sentence boundary is
        returned as one stripped chunk.
        """
        import re  # local import: not part of the file-level imports

        stripped = text.strip()
        sentences = [s for s in re.split(r"(?<=[.!?])\s+", stripped) if s]
        if not sentences:
            sentences = [stripped]
        return sentences

    def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
        """Encode the reference audio and synthesize the whole text in one inference call."""
        ref_s = self.tts_model.encode_reference(ref_audio_path)

        with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
        return audio

    def stream_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
        """Synthesize sentence by sentence, yielding each sentence's encoded audio.

        NOTE(review): ``speed`` is accepted but not passed to the model.
        NOTE(review): every yielded chunk is a complete file, so concatenated
        WAV/FLAC chunks may not decode as one stream — confirm with clients.
        """
        logger.info(f"Starting streaming synthesis for text length: {len(text)}")

        # Encode the reference once and reuse it for every sentence.
        ref_s = self.tts_model.encode_reference(ref_audio_path)

        sentences = self._split_text_into_chunks(text)

        for i, sentence in enumerate(sentences):
            if not sentence.strip():
                continue

            logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")

            with torch.no_grad():
                audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)

            yield self._convert_to_streamable_format(audio_chunk, audio_format)

        logger.info("Streaming synthesis complete.")
 
 
181
 
182
+ # --- Asynchronous Offloading ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
async def run_blocking_task_async(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on the shared TTS thread pool and await its result."""

    def _invoke():
        # Executors only forward positional arguments, so bind everything here.
        return func(*args, **kwargs)

    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(tts_executor, _invoke)
191
+
192
async def save_upload_file_async(upload_file: UploadFile) -> str:
    """Stream an uploaded file to a uniquely named temp file and return its path.

    Only a sanitized extension of the client-supplied filename is used. The
    raw name may be missing or contain path separators, and the previous
    ``time.time()`` prefix could collide between concurrent requests.

    Raises:
        HTTPException: 500 if the file cannot be written. Any partial file is
            removed first.
    """
    client_name = os.path.basename(upload_file.filename or "")
    raw_ext = os.path.splitext(client_name)[1]
    # Keep the extension only as a hint; drop anything that is not alphanumeric.
    safe_ext = "." + "".join(ch for ch in raw_ext if ch.isalnum())[:16] if raw_ext else ""

    fd, temp_filename = tempfile.mkstemp(suffix=safe_ext, dir=TEMP_AUDIO_DIR)
    os.close(fd)  # aiofiles reopens the path; only the unique name is needed

    try:
        async with aiofiles.open(temp_filename, 'wb') as out_file:
            # Read in 1 MiB chunks so large uploads are not held in memory at once.
            while content := await upload_file.read(1024 * 1024):
                await out_file.write(content)
        return temp_filename
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        if os.path.exists(temp_filename):
            os.unlink(temp_filename)
        raise HTTPException(status_code=500, detail="Could not save reference audio file") from e
204
 
205
+ # --- FastAPI Lifespan Manager (Kokoro Feature) ---
 
 
206
 
 
207
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and shut the thread pool down on exit.

    Raises:
        RuntimeError: if the model cannot be initialized. The original
            exception is chained as the cause.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        # Release the pool before aborting startup.
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed.") from e

    try:
        yield  # Application serves requests
    finally:
        # Runs even if the lifespan context exits via an exception or cancellation.
        logger.info("Shutting down ThreadPoolExecutor.")
        tts_executor.shutdown(wait=False)
223
 
224
# --- FastAPI Application Setup ---
# The lifespan handler attaches the model wrapper to app.state at startup.
app = FastAPI(
    lifespan=lifespan,
    title="NeuTTS Air Instant Cloning API",
    version="2.0.0-PROD-ENHANCED",
    docs_url="/docs",
)

# CORS is wide open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
238
+
239
+ # --- New Endpoints and Enhancements ---
240
+
241
@app.get("/")
async def root():
    """Landing endpoint: return a static banner identifying the service."""
    banner = "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"
    return {"message": banner}
244
 
245
@app.get("/health")
async def health_check():
    """Report model readiness, configuration, and host memory/disk usage."""
    bytes_per_gb = 1024**3
    memory = psutil.virtual_memory()
    root_disk = psutil.disk_usage('/')

    # The wrapper only exists on app.state once startup has completed.
    model_ready = False
    if hasattr(app.state, 'tts_wrapper'):
        model_ready = app.state.tts_wrapper.tts_model is not None

    memory_report = {
        "total_gb": round(memory.total / bytes_per_gb, 2),
        "used_percent": memory.percent,
    }
    disk_report = {
        "total_gb": round(root_disk.total / bytes_per_gb, 2),
        "used_percent": root_disk.percent,
    }
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "memory_usage": memory_report,
        "disk_usage": disk_report,
    }
265
 
266
@app.delete("/cleanup")
async def cleanup_files():
    """Maintenance endpoint: run the stale-file sweep on the thread pool, then acknowledge."""
    # The sweep's deleted-file count is not included in the response.
    await run_blocking_task_async(cleanup_files_blocking)
    acknowledgement = {"message": "Cleanup initiated successfully."}
    return acknowledgement
271
+
272
def cleanup_files_blocking():
    """Delete files older than ``CLEANUP_THRESHOLD`` seconds from both audio directories.

    Age is measured from ``os.path.getctime``. Per-file failures are logged
    and skipped. Returns the number of files removed.
    """
    reference_time = time.time()
    removed = 0

    for folder in (GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR):
        for entry_name in os.listdir(folder):
            candidate = os.path.join(folder, entry_name)
            if not os.path.isfile(candidate):
                continue  # sub-directories and other non-files are left alone
            try:
                age_seconds = reference_time - os.path.getctime(candidate)
                if age_seconds > CLEANUP_THRESHOLD:
                    os.remove(candidate)
                    removed += 1
            except Exception as e:
                logger.warning(f"Failed to delete {candidate}: {e}")

    logger.info(f"Cleanup completed: {removed} files removed.")
    return removed
291
+
292
+
293
+ # --- Core Synthesis Endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
 
295
@app.post("/synthesize", response_class=Response)
async def text_to_speech(
        text: str = Form(...),
        reference_text: str = Form(...),
        speed: float = Form(1.0, ge=0.5, le=2.0),
        output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
        reference_audio: UploadFile = File(...)):
    """Synthesize the full text with the uploaded reference voice.

    The reference is saved, converted to WAV with FFmpeg and used for
    synthesis. The result is encoded to ``output_format``, stored in
    ``GENERATED_AUDIO_DIR`` and returned as the response body.

    NOTE(review): ``speed`` is validated but not passed to the synthesis call.

    Raises:
        HTTPException: 503 if the model is not loaded, the conversion step's
            own HTTPException if conversion fails, 500 for anything else.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # 1. Save the original upload.
    temp_ref_path = await save_upload_file_async(reference_audio)
    converted_wav_path = None  # set once conversion succeeds; checked during cleanup
    start_time = time.time()

    try:
        # 2. Convert the upload to a WAV the model can read.
        converted_wav_path = await run_blocking_task_async(
            convert_to_wav_blocking,
            temp_ref_path
        )

        # 3. Run encode + inference on the thread pool.
        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.generate_speech_blocking,
            text,
            converted_wav_path,
            reference_text
        )

        # 4. Encode to the requested container.
        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper._convert_to_streamable_format,
            audio_data,
            output_format
        )

        # 5. Persist a copy. The file is now opened in a context manager; the
        #    previous ``open(...).write(...)`` never closed the handle.
        audio_filename = f"tts_{time.time()}.{output_format}"
        final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)

        def _write_generated_audio() -> None:
            with open(final_path, 'wb') as generated_file:
                generated_file.write(audio_bytes)

        await run_blocking_task_async(_write_generated_audio)

        processing_time = time.time() - start_time
        audio_duration = len(audio_data) / SAMPLE_RATE
        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename={audio_filename}",
                "X-Processing-Time": f"{processing_time:.2f}s",
                "X-Audio-Duration": f"{audio_duration:.2f}s"
            }
        )
    except Exception as e:
        logger.error(f"Synthesis error: {e}")
        # HTTPExceptions from the conversion step already carry the right status.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
    finally:
        # 6. Remove both the original upload and the converted WAV.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            os.unlink(converted_wav_path)
366
 
367
@app.post("/synthesize/stream")
async def stream_text_to_speech_cloning(
        text: str = Form(..., min_length=1, max_length=5000),
        reference_text: str = Form(...),
        speed: float = Form(1.0, ge=0.5, le=2.0),
        output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
        reference_audio: UploadFile = File(...)):
    """Stream synthesized audio sentence by sentence.

    The converted reference WAV must outlive this handler's return, so it is
    deleted by the streaming generator once generation ends. The download
    filename now uses ``output_format`` as its extension; it was hard-coded
    to ``.mp3``.

    NOTE(review): if the response body is never iterated, the generator's
    ``finally`` does not run and the converted WAV is left behind.

    Raises:
        HTTPException: 503 if the model is not loaded, the conversion step's
            own HTTPException if conversion fails, 500 for other setup errors.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # 1. Save the original upload.
    temp_ref_path = await save_upload_file_async(reference_audio)
    converted_wav_path = None  # set once conversion succeeds; checked during cleanup

    try:
        # 2. Convert the upload to a WAV the model can read.
        converted_wav_path = await run_blocking_task_async(
            convert_to_wav_blocking,
            temp_ref_path
        )

        # 2.5. The original upload is no longer needed after conversion.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)

        # 3. Synchronous generator passed to StreamingResponse.
        def stream_generator(path_to_delete: str):
            """Yield encoded sentence chunks, then remove the reference WAV."""
            try:
                for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
                    text,
                    path_to_delete,  # the CONVERTED WAV path
                    reference_text,
                    speed,
                    output_format
                ):
                    yield chunk_bytes
            except Exception as e:
                logger.error(f"Streaming generator error: {e}")
                raise  # re-raise so the stream terminates
            finally:
                # 4. Delete the converted file only after generation is done.
                if os.path.exists(path_to_delete):
                    os.unlink(path_to_delete)
                    logger.info(f"Cleaned up converted file: {path_to_delete}")

        return StreamingResponse(
            stream_generator(converted_wav_path),
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no"
            }
        )

    except Exception as e:
        logger.error(f"Streaming setup error: {e}")
        # Setup failed before the generator took ownership of the files.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            os.unlink(converted_wav_path)

        # HTTPExceptions from the conversion step already carry the right status.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}")
    # Note: there is no outer 'finally'; cleanup is handled in steps 2.5 and 4.
443
 
444
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a previously generated audio file from ``GENERATED_AUDIO_DIR``.

    The name must be a plain file name with no directory component, and it
    must refer to a regular file; anything else is a 404. The file handle is
    closed after reading, and ``.mp3`` files are served as ``audio/mpeg`` to
    match the synthesis endpoints.

    Raises:
        HTTPException: 404 if the name is invalid or the file does not exist.
    """
    # Reject names that could point outside the generated-audio directory.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
    # isfile (not exists) so a directory yields a 404 and not a read error.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    with open(file_path, "rb") as audio_file:
        payload = audio_file.read()

    extension = filename.split('.')[-1]
    subtype = "mpeg" if extension == "mp3" else extension
    return Response(
        content=payload,
        media_type=f"audio/{subtype}",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )