Rajhuggingface4253 commited on
Commit
6caac4d
·
verified ·
1 Parent(s): 2ed48de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -688
app.py CHANGED
@@ -1,770 +1,362 @@
1
  import os
2
- import sys
 
3
  import time
4
- import gc
5
- import torch
6
  import numpy as np
7
- import asyncio
8
- import aiofiles
9
- import re
10
- import io
11
- from concurrent.futures import ThreadPoolExecutor
12
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
13
- from fastapi.responses import JSONResponse, FileResponse, StreamingResponse, Response
14
- from fastapi.middleware.cors import CORSMiddleware
15
- from pydantic import BaseModel, Field
16
- from typing import Optional, Dict, Any, Generator, List
17
  import psutil
18
- import logging
19
  import soundfile as sf
 
 
20
  from contextlib import asynccontextmanager
 
 
 
 
 
 
21
 
22
-
23
-
24
- os.environ['HF_HOME'] = '/app/cache'
25
- os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/cache'
26
-
27
- # Add NeuTTS Air to path
28
- sys.path.append("neutts-air")
29
 
30
  # Configure logging
31
  logging.basicConfig(level=logging.INFO)
32
- logger = logging.getLogger(__name__)
33
 
34
- # Device detection and optimization
35
- def get_best_device():
36
- return "cuda" if torch.cuda.is_available() else "cpu"
37
 
38
- DEVICE = get_best_device()
39
- MAX_WORKERS = 1 if DEVICE == "cpu" else 2
 
 
40
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Global model instance
43
- tts_model = None
44
- model_loading = False
45
 
46
- # Pydantic models
47
- class TTSRequest(BaseModel):
48
- text: str = Field(..., min_length=1, max_length=5000)
49
- reference_text: str = Field(..., min_length=1, max_length=1000)
50
- reference_audio_path: Optional[str] = None
51
- output_format: str = Field(default="wav")
52
- speed: float = Field(default=1.0, ge=0.5, le=2.0)
53
 
54
- class StreamingRequest(BaseModel):
55
- text: str = Field(..., min_length=1, max_length=5000)
56
- reference_text: str = Field(..., min_length=1, max_length=1000)
57
- reference_audio_path: str
58
- speed: float = Field(default=1.0, ge=0.5, le=2.0)
59
- chunk_size: int = Field(default=2048, ge=512, le=8192)
60
-
61
- class TTSResponse(BaseModel):
62
- success: bool
63
- audio_url: Optional[str] = None
64
- message: Optional[str] = None
65
- processing_time: Optional[float] = None
66
- audio_duration: Optional[float] = None
67
-
68
- class HealthResponse(BaseModel):
69
- status: str
70
- model_loaded: bool
71
- device: str
72
- memory_usage: Dict[str, float]
73
- disk_usage: Dict[str, float]
74
- streaming_supported: bool = True
75
-
76
- def load_tts_model():
77
- global tts_model, model_loading
78
-
79
- if tts_model is not None or model_loading:
80
- return
81
-
82
- model_loading = True
83
- try:
84
- logger.info(f"Loading NeuTTS Air model on {DEVICE}...")
85
-
86
- # Try to import with fallbacks
87
- try:
88
- from neuttsair.neutts import NeuTTSAir
89
- except ImportError as e:
90
- logger.error(f"Failed to import NeuTTS Air: {e}")
91
- # Try alternative import path
92
- sys.path.insert(0, "/app/neutts-air")
93
- from neuttsair.neutts import NeuTTSAir
94
-
95
- # Use appropriate device with fallback
96
- device = DEVICE
97
  try:
98
- tts_model = NeuTTSAir(
99
- backbone_repo="neuphonic/neutts-air",
100
- backbone_device=device,
101
- codec_repo="neuphonic/neucodec",
102
- codec_device=device
103
- )
104
  except Exception as e:
105
- logger.warning(f"Failed to load on {device}, falling back to CPU: {e}")
106
- tts_model = NeuTTSAir(
107
- backbone_repo="neuphonic/neutts-air",
108
- backbone_device="cpu",
109
- codec_repo="neuphonic/neucodec",
110
- codec_device="cpu"
111
- )
112
-
113
- # Warm up the model
114
- warm_up_model()
115
-
116
- logger.info("NeuTTS Air model loaded successfully!")
117
-
118
- except Exception as e:
119
- logger.error(f"Failed to load model: {str(e)}")
120
- model_loading = False
121
- raise e
122
-
123
- model_loading = False
124
 
125
- def warm_up_model():
126
- """Warm up the model with a short inference"""
127
- try:
128
- if tts_model is None:
129
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- logger.info("Warming up model...")
132
- # Create a temporary warm-up audio file
133
- temp_dir = "temp_audio"
134
- os.makedirs(temp_dir, exist_ok=True)
135
-
136
- # Generate a simple sine wave as warm-up reference
137
- import scipy.io.wavfile as wavfile
138
- warmup_audio_path = os.path.join(temp_dir, "warmup_ref.wav")
139
-
140
- # Create 1 second of 440Hz sine wave
141
- sample_rate = 24000
142
- t = np.linspace(0, 1, sample_rate)
143
- audio_data = 0.3 * np.sin(2 * np.pi * 440 * t)
144
- audio_data = (audio_data * 32767).astype(np.int16)
145
-
146
- wavfile.write(warmup_audio_path, sample_rate, audio_data)
147
-
148
- # Perform warm-up inference
149
- ref_codes = tts_model.encode_reference(warmup_audio_path)
150
- wav = tts_model.infer("Hello, this is a warm-up.", ref_codes, "Hello warm up")
151
-
152
- # Clean up
153
- if os.path.exists(warmup_audio_path):
154
- os.remove(warmup_audio_path)
155
 
156
- logger.info(f"Model warm-up completed! Generated audio length: {len(wav)}")
157
-
158
- except Exception as e:
159
- logger.warning(f"Model warm-up failed: {e}")
160
-
161
- def validate_audio_file(audio_path: str):
162
- """
163
- Enhanced audio validation with strict NeuTTS Air requirements
164
- Reference: 3-15 seconds of clean, mono audio for optimal results
165
- """
166
- try:
167
- import librosa
168
-
169
- # Check file exists
170
- if not os.path.exists(audio_path):
171
- raise ValueError("Audio file not found")
172
-
173
- # Check file size (roughly 10MB limit)
174
- file_size = os.path.getsize(audio_path) / (1024 * 1024) # MB
175
- if file_size > 10:
176
- raise ValueError(f"Audio file too large: {file_size:.1f}MB. Maximum 10MB allowed.")
177
-
178
- # Load and validate audio properties
179
- audio_data, sample_rate = librosa.load(audio_path, sr=None, mono=False)
180
- audio_duration = librosa.get_duration(y=audio_data, sr=sample_rate)
181
-
182
- # Enhanced validation rules based on NeuTTS Air specifications
183
- if audio_duration < 3 or audio_duration > 15:
184
- raise ValueError(f"Audio duration ({audio_duration:.1f}s) must be between 3-15 seconds for optimal voice cloning")
185
-
186
- if len(audio_data.shape) > 1 and audio_data.shape[0] > 1:
187
- logger.warning("Stereo audio detected. For best results, use mono audio")
188
- # Convert to mono by averaging channels
189
- audio_data = np.mean(audio_data, axis=0)
190
-
191
- if sample_rate < 16000 or sample_rate > 44100:
192
- logger.warning(f"Sample rate {sample_rate}Hz should ideally be between 16-44kHz")
193
-
194
- # Check for sufficient audio quality (basic RMS check)
195
- rms = np.sqrt(np.mean(audio_data**2))
196
- if rms < 0.01: # Too quiet
197
- raise ValueError("Audio signal is too quiet. Please use a clearer recording.")
198
-
199
- logger.info(f"Audio validation passed: {audio_duration:.1f}s, {sample_rate}Hz")
200
- return audio_duration
201
-
202
- except Exception as e:
203
- logger.error(f"Audio validation failed: {str(e)}")
204
- raise ValueError(f"Invalid audio file: {str(e)}")
205
-
206
- def intelligent_text_chunking(text: str) -> List[str]:
207
- """
208
- Intelligent text chunking for optimal streaming
209
- Splits text into meaningful chunks for sequential processing
210
- """
211
- # Clean and normalize text
212
- text = re.sub(r'\s+', ' ', text.strip())
213
-
214
- # First, split by sentences (., !, ?)
215
- sentences = re.split(r'(?<=[.!?])\s+', text)
216
-
217
- chunks = []
218
- for sentence in sentences:
219
- sentence = sentence.strip()
220
- if not sentence:
221
- continue
222
 
223
- # If sentence is too long, split by clauses (commas, semicolons)
224
- if len(sentence) > 100:
225
- clauses = re.split(r'(?<=[,;:])\s+', sentence)
226
- for clause in clauses:
227
- clause = clause.strip()
228
- if clause:
229
- # If clause is still long, split by length
230
- if len(clause) > 80:
231
- words = clause.split()
232
- current_chunk = []
233
- current_length = 0
234
-
235
- for word in words:
236
- if current_length + len(word) + 1 > 80 and current_chunk:
237
- chunks.append(' '.join(current_chunk))
238
- current_chunk = [word]
239
- current_length = len(word)
240
- else:
241
- current_chunk.append(word)
242
- current_length += len(word) + 1
243
-
244
- if current_chunk:
245
- chunks.append(' '.join(current_chunk))
246
- else:
247
- chunks.append(clause)
248
- else:
249
- chunks.append(sentence)
250
-
251
- # Ensure we have at least one chunk
252
- if not chunks:
253
- chunks = [text]
254
-
255
- logger.info(f"Split text into {len(chunks)} chunks for streaming")
256
- return chunks
257
 
258
- async def generate_chunk_audio(chunk_text: str, ref_codes: Any, reference_text: str, speed: float) -> np.ndarray:
259
- """Generate audio for a single text chunk asynchronously"""
260
  loop = asyncio.get_event_loop()
261
  return await loop.run_in_executor(
262
  tts_executor,
263
- tts_model.infer,
264
- chunk_text, ref_codes, reference_text
265
  )
266
 
267
- async def convert_chunk_to_mp3(audio_chunk: np.ndarray) -> bytes:
268
- """Convert audio chunk to MP3 format asynchronously"""
269
- loop = asyncio.get_event_loop()
270
-
271
- def _convert():
272
- mp3_buffer = io.BytesIO()
273
- sf.write(mp3_buffer, audio_chunk, 24000, format='mp3')
274
- return mp3_buffer.getvalue()
275
-
276
- return await loop.run_in_executor(tts_executor, _convert)
277
-
278
- def generate_silent_mp3_header(duration_ms: int = 100) -> bytes:
279
- """Generate a short silent MP3 header for immediate playback"""
280
- silent_audio = np.zeros(int(24000 * duration_ms / 1000)) # 100ms of silence
281
- mp3_buffer = io.BytesIO()
282
- sf.write(mp3_buffer, silent_audio, 24000, format='mp3')
283
- return mp3_buffer.getvalue()
284
-
285
- async def true_realtime_generator(
286
- text: str,
287
- ref_codes: Any,
288
- reference_text: str,
289
- speed: float = 1.0
290
- ) -> Generator[bytes, None, None]:
291
- """
292
- TRUE real-time streaming generator
293
- Processes text line-by-line and streams MP3 chunks immediately
294
- """
295
- start_time = time.time()
296
-
297
  try:
298
- logger.info("Starting TRUE real-time streaming generation...")
299
-
300
- # Step 1: Send MP3 header for immediate browser playback
301
- header_data = generate_silent_mp3_header()
302
- yield header_data
303
- logger.info("Sent MP3 header for immediate playback")
304
-
305
- # Step 2: Intelligent text chunking
306
- text_chunks = intelligent_text_chunking(text)
307
- total_chunks = len(text_chunks)
308
- logger.info(f"Processing {total_chunks} text chunks sequentially")
309
-
310
- # Step 3: Process each chunk in sequence with immediate streaming
311
- successful_chunks = 0
312
- for chunk_index, chunk_text in enumerate(text_chunks, 1):
313
- if not chunk_text.strip():
314
- continue
315
-
316
- chunk_start_time = time.time()
317
- logger.info(f"Processing chunk {chunk_index}/{total_chunks}: '{chunk_text[:50]}...'")
318
-
319
- try:
320
- # Generate audio for this specific chunk
321
- chunk_audio = await generate_chunk_audio(chunk_text, ref_codes, reference_text, speed)
322
-
323
- # Convert to MP3 immediately
324
- mp3_data = await convert_chunk_to_mp3(chunk_audio)
325
-
326
- # Stream the MP3 chunk immediately
327
- yield mp3_data
328
-
329
- chunk_processing_time = time.time() - chunk_start_time
330
- successful_chunks += 1
331
-
332
- logger.info(f"✓ Streamed chunk {chunk_index}/{total_chunks} in {chunk_processing_time:.2f}s, size: {len(mp3_data)} bytes")
333
-
334
- # Small delay to ensure smooth streaming (optional)
335
- await asyncio.sleep(0.01)
336
-
337
- except Exception as chunk_error:
338
- logger.error(f"✗ Failed to process chunk {chunk_index}: {chunk_error}")
339
- # Continue with next chunk instead of failing entirely
340
- continue
341
-
342
- total_processing_time = time.time() - start_time
343
- logger.info(f"TRUE real-time streaming completed: {successful_chunks}/{total_chunks} chunks in {total_processing_time:.2f}s")
344
-
345
  except Exception as e:
346
- logger.error(f"TRUE real-time streaming generator failed: {e}")
347
- raise
 
 
348
 
349
  @asynccontextmanager
350
  async def lifespan(app: FastAPI):
351
- """Modern lifespan management"""
352
  try:
353
- load_tts_model()
354
- logger.info(f"✅ NeuTTS Air model loaded on {DEVICE}")
355
  except Exception as e:
356
- logger.error(f" Model loading failed: {e}")
357
- raise
358
- yield
359
- # Cleanup
 
 
 
 
 
360
  tts_executor.shutdown(wait=False)
361
- # Clean up temporary files
362
- await cleanup_audio_files()
363
 
 
364
  app = FastAPI(
365
- title="NeuTTS Air API - Enhanced",
366
- description="High-quality on-device Text-to-Speech with instant voice cloning and TRUE real-time streaming",
367
- version="2.1.0",
368
- docs_url="/docs",
369
  lifespan=lifespan
370
  )
371
 
372
- # CORS middleware
373
  app.add_middleware(
374
  CORSMiddleware,
375
  allow_origins=["*"],
376
- allow_credentials=True,
377
  allow_methods=["*"],
378
  allow_headers=["*"],
379
  )
380
 
381
- async def run_tts_async(text: str, ref_codes: Any, reference_text: str, speed: float = 1.0):
382
- """Offload blocking TTS call to thread pool"""
383
- loop = asyncio.get_event_loop()
384
- return await loop.run_in_executor(
385
- tts_executor,
386
- tts_model.infer,
387
- text, ref_codes, reference_text
388
- )
389
-
390
- def encode_reference_async(audio_path: str):
391
- """Encode reference audio in thread pool"""
392
- loop = asyncio.get_event_loop()
393
- return loop.run_in_executor(
394
- tts_executor,
395
- tts_model.encode_reference,
396
- audio_path
397
- )
398
 
399
  @app.get("/")
400
  async def root():
 
 
 
 
 
 
 
 
401
  return {
402
- "message": "Enhanced NeuTTS Air API with TRUE Real-time Streaming!",
403
  "status": "healthy",
404
- "version": "2.1.0",
405
- "features": [
406
- "voice_cloning",
407
- "true_realtime_streaming",
408
- "line_by_line_processing",
409
- "multiple_formats",
410
- "production_ready"
411
- ]
 
 
 
412
  }
413
 
414
- @app.get("/health")
415
- async def health_check():
416
- """Enhanced health check endpoint"""
417
- try:
418
- memory = psutil.virtual_memory()
419
- disk = psutil.disk_usage('/')
420
-
421
- return HealthResponse(
422
- status="healthy",
423
- model_loaded=tts_model is not None,
424
- device=DEVICE,
425
- memory_usage={
426
- "total_gb": round(memory.total / (1024**3), 2),
427
- "available_gb": round(memory.available / (1024**3), 2),
428
- "used_percent": round(memory.percent, 2)
429
- },
430
- disk_usage={
431
- "total_gb": round(disk.total / (1024**3), 2),
432
- "free_gb": round(disk.free / (1024**3), 2),
433
- "used_percent": round(disk.percent, 2)
434
- }
435
- )
436
- except Exception as e:
437
- return HealthResponse(
438
- status="degraded",
439
- model_loaded=tts_model is not None,
440
- device=DEVICE,
441
- memory_usage={"error": str(e)},
442
- disk_usage={"error": str(e)}
443
- )
444
 
445
- @app.post("/synthesize")
446
- async def synthesize_speech(
447
- reference_text: str = Form(..., min_length=1, max_length=1000),
448
- text: str = Form(..., min_length=1, max_length=5000),
449
- reference_audio: UploadFile = File(...),
450
- output_format: str = Form("wav"),
451
- speed: float = Form(1.0)
452
  ):
453
  """
454
- Standard synthesis endpoint with audio validation and multiple output formats
 
455
  """
 
 
 
 
 
456
  start_time = time.time()
457
 
458
- if tts_model is None:
459
- raise HTTPException(status_code=503, detail="Model not loaded yet")
460
-
461
- temp_ref_path = None
462
  try:
463
- # Save uploaded file temporarily
464
- temp_dir = "temp_audio"
465
- os.makedirs(temp_dir, exist_ok=True)
466
-
467
- file_extension = os.path.splitext(reference_audio.filename)[1] or ".wav"
468
- temp_ref_path = os.path.join(temp_dir, f"ref_{int(time.time())}{file_extension}")
469
-
470
- async with aiofiles.open(temp_ref_path, 'wb') as out_file:
471
- content = await reference_audio.read()
472
- await out_file.write(content)
473
-
474
- # Enhanced audio validation
475
- audio_duration = validate_audio_file(temp_ref_path)
476
 
477
- # Perform TTS
478
- logger.info(f"Starting synthesis for text: {text[:50]}...")
 
 
 
 
479
 
480
- # Encode reference and generate speech asynchronously
481
- ref_codes = await encode_reference_async(temp_ref_path)
482
- wav = await run_tts_async(text, ref_codes, reference_text, speed)
 
 
 
 
483
 
484
  processing_time = time.time() - start_time
485
- output_audio_duration = len(wav) / 24000
486
-
487
- logger.info(f"Synthesis completed in {processing_time:.2f}s")
488
-
489
- # Handle different output formats
490
- if output_format.lower() in ["mp3", "flac"]:
491
- audio_buffer = io.BytesIO()
492
- if output_format.lower() == "mp3":
493
- sf.write(audio_buffer, wav, 24000, format='mp3')
494
- media_type = "audio/mpeg"
495
- else:
496
- sf.write(audio_buffer, wav, 24000, format='flac')
497
- media_type = "audio/flac"
498
-
499
- audio_buffer.seek(0)
500
-
501
- return Response(
502
- content=audio_buffer.read(),
503
- media_type=media_type,
504
- headers={
505
- "Content-Disposition": f"attachment; filename=cloned_speech.{output_format}",
506
- "X-Processing-Time": str(round(processing_time, 2)),
507
- "X-Audio-Duration": str(round(output_audio_duration, 2))
508
- }
509
- )
510
- else:
511
- # Default WAV format
512
- output_dir = "generated_audio"
513
- os.makedirs(output_dir, exist_ok=True)
514
- output_filename = f"output_{int(time.time())}.wav"
515
- output_path = os.path.join(output_dir, output_filename)
516
-
517
- sf.write(output_path, wav, 24000)
518
-
519
- return TTSResponse(
520
- success=True,
521
- audio_url=f"/audio/{output_filename}",
522
- message="Speech synthesized successfully",
523
- processing_time=round(processing_time, 2),
524
- audio_duration=round(output_audio_duration, 2)
525
- )
526
-
527
- except ValueError as e:
528
- raise HTTPException(status_code=400, detail=str(e))
529
- except Exception as e:
530
- logger.error(f"Synthesis error: {str(e)}")
531
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
532
-
533
- finally:
534
- # Clean up temporary file
535
- if temp_ref_path and os.path.exists(temp_ref_path):
536
- try:
537
- os.remove(temp_ref_path)
538
- except:
539
- pass
540
-
541
- @app.post("/synthesize/true-realtime")
542
- async def true_realtime_synthesis(request: StreamingRequest):
543
- """
544
- TRUE real-time streaming endpoint - processes text line-by-line and streams immediately
545
- First audio chunk delivered in 2-3 seconds even for long texts
546
- """
547
- if tts_model is None:
548
- raise HTTPException(status_code=503, detail="Model not loaded yet")
549
-
550
- try:
551
- # Validate reference audio exists and meets requirements
552
- if not os.path.exists(request.reference_audio_path):
553
- raise HTTPException(status_code=400, detail="Reference audio path not found")
554
-
555
- validate_audio_file(request.reference_audio_path)
556
-
557
- # Encode reference asynchronously (this happens once at the start)
558
- ref_codes = await encode_reference_async(request.reference_audio_path)
559
-
560
- start_time = time.time()
561
-
562
- return StreamingResponse(
563
- true_realtime_generator(
564
- text=request.text,
565
- ref_codes=ref_codes,
566
- reference_text=request.reference_text,
567
- speed=request.speed
568
- ),
569
- media_type="audio/mpeg",
570
  headers={
571
- "Content-Disposition": "attachment; filename=realtime_speech.mp3",
572
- "Transfer-Encoding": "chunked",
573
- "X-Streaming-Type": "true-realtime-line-by-line",
574
- "X-First-Chunk-ETA": "2-3s",
575
- "Cache-Control": "no-cache",
576
- "X-Start-Time": str(start_time)
577
  }
578
  )
579
-
580
  except Exception as e:
581
- logger.error(f"TRUE real-time streaming error: {e}")
582
- raise HTTPException(status_code=500, detail=f"TRUE real-time streaming failed: {str(e)}")
 
 
 
 
583
 
584
- # Legacy streaming endpoint (fake streaming) for backward compatibility
585
  @app.post("/synthesize/stream")
586
- async def legacy_stream_synthesis(request: StreamingRequest):
 
 
 
 
 
587
  """
588
- Legacy streaming endpoint (fake streaming) - for backward compatibility
589
- Use /synthesize/true-realtime for real streaming
590
  """
591
- if tts_model is None:
592
- raise HTTPException(status_code=503, detail="Model not loaded yet")
 
 
 
593
 
594
- try:
595
- if not os.path.exists(request.reference_audio_path):
596
- raise HTTPException(status_code=400, detail="Reference audio path not found")
597
-
598
- validate_audio_file(request.reference_audio_path)
599
- ref_codes = await encode_reference_async(request.reference_audio_path)
600
-
601
- # Legacy approach: generate complete audio then chunk
602
- def legacy_stream_generator():
603
- wav = tts_model.infer(request.text, ref_codes, request.reference_text)
604
- audio_buffer = io.BytesIO()
605
- sf.write(audio_buffer, wav, 24000, format='mp3')
606
- audio_data = audio_buffer.getvalue()
 
 
 
 
 
607
 
608
- # Stream in chunks
609
- chunk_size = request.chunk_size
610
- for i in range(0, len(audio_data), chunk_size):
611
- yield audio_data[i:i + chunk_size]
612
-
613
- return StreamingResponse(
614
- legacy_stream_generator(),
615
- media_type="audio/mpeg",
616
- headers={
617
- "Content-Disposition": "attachment; filename=legacy_stream.mp3",
618
- "X-Streaming-Type": "legacy-chunked"
619
- }
620
- )
621
-
622
- except Exception as e:
623
- logger.error(f"Legacy streaming error: {e}")
624
- raise HTTPException(status_code=500, detail=f"Legacy streaming failed: {str(e)}")
625
 
626
  @app.get("/audio/{filename}")
627
- async def get_audio_file(filename: str):
628
- """Serve generated audio files"""
629
- file_path = os.path.join("generated_audio", filename)
630
-
631
  if not os.path.exists(file_path):
632
  raise HTTPException(status_code=404, detail="Audio file not found")
633
 
634
- return FileResponse(
635
- file_path,
636
- media_type="audio/wav",
637
- filename=f"cloned_speech_{filename}"
638
- )
639
-
640
- @app.post("/synthesize-with-url")
641
- async def synthesize_with_url(request: TTSRequest):
642
- """
643
- Enhanced synthesis with URL support and multiple formats
644
- """
645
- start_time = time.time()
646
-
647
- if tts_model is None:
648
- raise HTTPException(status_code=503, detail="Model not loaded yet")
649
-
650
- if not request.reference_audio_path or not os.path.exists(request.reference_audio_path):
651
- raise HTTPException(status_code=400, detail="Reference audio path not found")
652
-
653
- try:
654
- validate_audio_file(request.reference_audio_path)
655
-
656
- # Perform TTS asynchronously
657
- logger.info(f"Starting synthesis for text: {request.text[:50]}...")
658
-
659
- ref_codes = await encode_reference_async(request.reference_audio_path)
660
- wav = await run_tts_async(request.text, ref_codes, request.reference_text, request.speed)
661
-
662
- processing_time = time.time() - start_time
663
- audio_duration = len(wav) / 24000
664
-
665
- # Handle output format
666
- if request.output_format.lower() in ["mp3", "flac"]:
667
- audio_buffer = io.BytesIO()
668
- if request.output_format.lower() == "mp3":
669
- sf.write(audio_buffer, wav, 24000, format='mp3')
670
- media_type = "audio/mpeg"
671
- else:
672
- sf.write(audio_buffer, wav, 24000, format='flac')
673
- media_type = "audio/flac"
674
-
675
- audio_buffer.seek(0)
676
-
677
- return Response(
678
- content=audio_buffer.read(),
679
- media_type=media_type,
680
- headers={
681
- "Content-Disposition": f"attachment; filename=cloned_speech.{request.output_format}",
682
- "X-Processing-Time": str(round(processing_time, 2)),
683
- "X-Audio-Duration": str(round(audio_duration, 2))
684
- }
685
- )
686
- else:
687
- # Save as WAV
688
- output_dir = "generated_audio"
689
- os.makedirs(output_dir, exist_ok=True)
690
- output_filename = f"output_{int(time.time())}.wav"
691
- output_path = os.path.join(output_dir, output_filename)
692
-
693
- sf.write(output_path, wav, 24000)
694
-
695
- return TTSResponse(
696
- success=True,
697
- audio_url=f"/audio/{output_filename}",
698
- message="Speech synthesized successfully",
699
- processing_time=round(processing_time, 2),
700
- audio_duration=round(audio_duration, 2)
701
- )
702
-
703
- except Exception as e:
704
- logger.error(f"Synthesis error: {str(e)}")
705
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
706
-
707
- @app.delete("/cleanup")
708
- async def cleanup_audio_files():
709
- """Enhanced cleanup with efficient file management"""
710
- try:
711
- output_dir = "generated_audio"
712
- temp_dir = "temp_audio"
713
-
714
- deleted_count = 0
715
- current_time = time.time()
716
-
717
- # Clean generated audio
718
- if os.path.exists(output_dir):
719
- for filename in os.listdir(output_dir):
720
- file_path = os.path.join(output_dir, filename)
721
- if os.path.isfile(file_path):
722
- file_age = current_time - os.path.getctime(file_path)
723
- if file_age > 3600: # 1 hour
724
- os.remove(file_path)
725
- deleted_count += 1
726
-
727
- # Clean temp audio (shorter retention)
728
- if os.path.exists(temp_dir):
729
- for filename in os.listdir(temp_dir):
730
- file_path = os.path.join(temp_dir, filename)
731
- if os.path.isfile(file_path):
732
- file_age = current_time - os.path.getctime(file_path)
733
- if file_age > 1800: # 30 minutes for temp files
734
- os.remove(file_path)
735
- deleted_count += 1
736
-
737
- # Force garbage collection
738
- gc.collect()
739
-
740
- return {
741
- "message": f"Cleaned up {deleted_count} files",
742
- "memory_cleaned": "true",
743
- "next_cleanup": "in_1_hour"
744
- }
745
-
746
- except Exception as e:
747
- raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
748
-
749
- # GET endpoint for simple synthesis
750
- @app.get("/synthesize")
751
- async def synthesize_speech_get(
752
- text: str = Query(..., min_length=1, max_length=5000),
753
- reference_text: str = Query(..., min_length=1, max_length=1000),
754
- reference_audio_path: str = Query(...),
755
- output_format: str = Query("wav"),
756
- speed: float = Query(1.0)
757
- ):
758
- """GET endpoint for speech synthesis"""
759
- request = TTSRequest(
760
- text=text,
761
- reference_text=reference_text,
762
- reference_audio_path=reference_audio_path,
763
- output_format=output_format,
764
- speed=speed
765
  )
766
- return await synthesize_with_url(request)
767
-
768
- if __name__ == "__main__":
769
- import uvicorn
770
- uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
 
1
  import os
2
+ import io
3
+ import asyncio
4
  import time
5
+ import shutil
 
6
  import numpy as np
 
 
 
 
 
 
 
 
 
 
7
  import psutil
 
8
  import soundfile as sf
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from typing import Optional, Generator
11
  from contextlib import asynccontextmanager
12
+ import logging
13
+ import aiofiles
14
+ import torch
15
+ from fastapi import FastAPI, HTTPException, Response, StreamingResponse, UploadFile, File, Form
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field
18
 
19
+ # Ensure the cloned neutts-air repository is in the path
20
+ import sys
21
+ sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
22
+ from neuttsair.neutts import NeuTTSAir
 
 
 
23
 
24
  # Configure logging
25
  logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger("NeuTTS-API")
27
 
28
+ # --- Configuration & Utility Functions ---
 
 
29
 
30
+ # Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
31
+ DEVICE = "cpu"
32
+ # Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
33
+ MAX_WORKERS = 2
34
  tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
35
+ SAMPLE_RATE = 24000
36
+ CLEANUP_THRESHOLD = 3600 # 1 hour in seconds
37
+ TEMP_AUDIO_DIR = "temp_audio"
38
+ GENERATED_AUDIO_DIR = "generated_audio"
39
+ os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
40
+ os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
41
+
42
+ class TTSRequestModel(BaseModel):
43
+ """Model for non-file inputs to synthesis and streaming."""
44
+ text: str = Field(..., min_length=1, max_length=1000)
45
+ speed: float = Field(default=1.0, ge=0.5, le=2.0)
46
+ output_format: str = Field(default="wav", pattern="^(wav|mp3|flac)$")
47
 
48
+ # --- Model Wrapper and Logic ---
 
 
49
 
50
+ class NeuTTSWrapper:
51
+ def __init__(self, device: str = "cpu"):
52
+ self.tts_model = None
53
+ self.device = device
54
+ self.load_model()
 
 
55
 
56
+ def load_model(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  try:
58
+ logger.info(f"Loading NeuTTSAir model on device: {self.device}")
59
+ # Ensure we respect the CPU configuration
60
+ self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
61
+ logger.info("✅ NeuTTSAir model loaded successfully.")
 
 
62
  except Exception as e:
63
+ logger.error(f" Model loading failed: {e}")
64
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
67
+ """Converts NumPy audio array to streamable bytes in the specified format."""
68
+ audio_buffer = io.BytesIO()
69
+ try:
70
+ sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
71
+ except Exception as e:
72
+ logger.error(f"Failed to write audio data to format {audio_format}: {e}")
73
+ raise
74
+ audio_buffer.seek(0)
75
+ return audio_buffer.read()
76
+
77
+ def _split_text_into_chunks(self, text: str) -> list[str]:
78
+ """Simple sentence splitting for streaming (can be enhanced with regex)."""
79
+ sentences = [s.strip() for s in text.split('.') if s.strip()]
80
+ if not sentences:
81
+ sentences = [text.strip()]
82
+ return sentences
83
+
84
+ def generate_speech_blocking(self, text: str, ref_audio_path: str) -> np.ndarray:
85
+ """Blocking synthesis for standard endpoint."""
86
+ # 1. Load reference
87
+ reference_audio, sr = sf.read(ref_audio_path)
88
+ if sr != SAMPLE_RATE:
89
+ # Simple check/resize logic required if sample rate mismatch occurs
90
+ pass
91
+
92
+ # 2. Encode reference
93
+ ref_s = self.tts_model.encode_reference(reference_audio)
94
+
95
+ # 3. Infer full text
96
+ with torch.no_grad():
97
+ audio = self.tts_model.infer(text, ref_s, speed=1.0)
98
+ return audio.cpu().numpy()
99
+
100
+ def stream_speech_blocking(self, text: str, ref_audio_path: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
101
+ """Sentence-by-Sentence Streaming (Blocking)."""
102
+ logger.info(f"Starting streaming synthesis for text length: {len(text)}")
103
+
104
+ # 1. Load reference audio (ONLY ONCE)
105
+ reference_audio, sr = sf.read(ref_audio_path)
106
+
107
+ # 2. Encode reference (ONLY ONCE)
108
+ ref_s = self.tts_model.encode_reference(reference_audio)
109
+
110
+ # 3. Split text
111
+ sentences = self._split_text_into_chunks(text)
112
+
113
+ # 4. Stream chunks
114
+ for i, sentence in enumerate(sentences):
115
+ if not sentence.strip():
116
+ continue
117
 
118
+ logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # Infer sentence
121
+ with torch.no_grad():
122
+ audio_chunk = self.tts_model.infer(sentence, ref_s, speed=speed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # Convert and yield
125
+ yield self._convert_to_streamable_format(audio_chunk.cpu().numpy(), audio_format)
126
+
127
+ logger.info("Streaming synthesis complete.")
128
+
129
+ # --- Asynchronous Offloading ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ async def run_blocking_task_async(func, *args, **kwargs):
132
+ """Offloads a blocking function call to the ThreadPoolExecutor."""
133
  loop = asyncio.get_event_loop()
134
  return await loop.run_in_executor(
135
  tts_executor,
136
+ lambda: func(*args, **kwargs)
 
137
  )
138
 
139
+ async def save_upload_file_async(upload_file: UploadFile) -> str:
140
+ """Asynchronously saves the UploadFile to disk."""
141
+ temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{upload_file.filename}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  try:
143
+ # Use asyncio to read the file chunks in a non-blocking manner
144
+ async with aiofiles.open(temp_filename, 'wb') as out_file:
145
+ while content := await upload_file.read(1024 * 1024):
146
+ await out_file.write(content)
147
+ return temp_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  except Exception as e:
149
+ logger.error(f"Error saving file: {e}")
150
+ raise HTTPException(status_code=500, detail="Could not save reference audio file")
151
+
152
+ # --- FastAPI Lifespan Manager (Kokoro Feature) ---
153
 
154
  @asynccontextmanager
155
  async def lifespan(app: FastAPI):
156
+ """Modern lifespan management: initialize model on startup, shutdown executor."""
157
  try:
158
+ app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
 
159
  except Exception as e:
160
+ logger.error(f"Fatal startup error: {e}")
161
+ # Terminate the application if the model can't load
162
+ tts_executor.shutdown(wait=False)
163
+ raise RuntimeError("Model initialization failed.")
164
+
165
+ yield # Application serves requests
166
+
167
+ # Shutdown
168
+ logger.info("Shutting down ThreadPoolExecutor.")
169
  tts_executor.shutdown(wait=False)
 
 
170
 
171
+ # --- FastAPI Application Setup ---
172
  app = FastAPI(
173
+ title="NeuTTS Air Instant Cloning API",
174
+ version="2.0.0-PROD-ENHANCED",
175
+ docs_url="/docs",
 
176
  lifespan=lifespan
177
  )
178
 
 
179
  app.add_middleware(
180
  CORSMiddleware,
181
  allow_origins=["*"],
 
182
  allow_methods=["*"],
183
  allow_headers=["*"],
184
  )
185
 
186
+ # --- New Endpoints and Enhancements ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  @app.get("/")
189
  async def root():
190
+ return {"message": "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"}
191
+
192
+ @app.get("/health")
193
+ async def health_check():
194
+ """Enhanced health check (Kokoro Feature + Original Metrics)"""
195
+ mem = psutil.virtual_memory()
196
+ disk = psutil.disk_usage('/')
197
+
198
  return {
 
199
  "status": "healthy",
200
+ "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
201
+ "device": DEVICE,
202
+ "concurrency_limit": MAX_WORKERS,
203
+ "memory_usage": {
204
+ "total_gb": round(mem.total / (1024**3), 2),
205
+ "used_percent": mem.percent
206
+ },
207
+ "disk_usage": {
208
+ "total_gb": round(disk.total / (1024**3), 2),
209
+ "used_percent": disk.percent
210
+ }
211
  }
212
 
213
+ @app.delete("/cleanup")
214
+ async def cleanup_files():
215
+ """Maintenance endpoint to remove old generated and temporary files."""
216
+ await run_blocking_task_async(cleanup_files_blocking)
217
+ return {"message": "Cleanup initiated successfully."}
218
+
219
+ def cleanup_files_blocking():
220
+ """Blocking file cleanup logic (original NeuTTS feature)."""
221
+ now = time.time()
222
+ deleted_count = 0
223
+
224
+ for directory in [GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR]:
225
+ for filename in os.listdir(directory):
226
+ filepath = os.path.join(directory, filename)
227
+ if os.path.isfile(filepath):
228
+ try:
229
+ # Original cleanup logic: delete if older than CLEANUP_THRESHOLD
230
+ if now - os.path.getctime(filepath) > CLEANUP_THRESHOLD:
231
+ os.remove(filepath)
232
+ deleted_count += 1
233
+ except Exception as e:
234
+ logger.warning(f"Failed to delete {filepath}: {e}")
235
+
236
+ logger.info(f"Cleanup completed: {deleted_count} files removed.")
237
+ return deleted_count
238
+
239
+
240
+ # --- Core Synthesis Endpoints ---
 
 
241
 
242
+ @app.post("/synthesize", response_class=Response)
243
+ async def text_to_speech(
244
+ text: str = Form(...),
245
+ speed: float = Form(1.0, ge=0.5, le=2.0),
246
+ output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
247
+ reference_audio: UploadFile = File(...)
 
248
  ):
249
  """
250
+ Standard blocking TTS endpoint with Multi-Format Output (Kokoro Feature).
251
+ Uses ThreadPoolExecutor for non-blocking API responsiveness.
252
  """
253
+ if not hasattr(app.state, 'tts_wrapper'):
254
+ raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
255
+
256
+ # 1. Asynchronously save reference audio
257
+ temp_ref_path = await save_upload_file_async(reference_audio)
258
  start_time = time.time()
259
 
 
 
 
 
260
  try:
261
+ # 2. Offload the ENTIRE blocking process (encode + infer) to a thread
262
+ audio_data = await run_blocking_task_async(
263
+ app.state.tts_wrapper.generate_speech_blocking,
264
+ text,
265
+ temp_ref_path
266
+ )
 
 
 
 
 
 
 
267
 
268
+ # 3. Convert to requested format (Blocking, but usually fast)
269
+ audio_bytes = await run_blocking_task_async(
270
+ app.state.tts_wrapper._convert_to_streamable_format,
271
+ audio_data,
272
+ output_format
273
+ )
274
 
275
+ # 4. Save to disk (Original NeuTTS requirement)
276
+ audio_filename = f"tts_{time.time()}.{output_format}"
277
+ final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
278
+ # We perform the file write operation in a blocking manner inside the thread pool.
279
+ await run_blocking_task_async(
280
+ lambda: open(final_path, 'wb').write(audio_bytes)
281
+ )
282
 
283
  processing_time = time.time() - start_time
284
+ audio_duration = len(audio_data) / SAMPLE_RATE
285
+
286
+ return Response(
287
+ content=audio_bytes,
288
+ media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  headers={
290
+ "Content-Disposition": f"attachment; filename={audio_filename}",
291
+ "X-Processing-Time": f"{processing_time:.2f}s",
292
+ "X-Audio-Duration": f"{audio_duration:.2f}s"
 
 
 
293
  }
294
  )
295
+
296
  except Exception as e:
297
+ logger.error(f"Synthesis error: {e}")
298
+ raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}")
299
+ finally:
300
+ # 5. Clean up the temporary reference file
301
+ if os.path.exists(temp_ref_path):
302
+ os.unlink(temp_ref_path)
303
 
 
304
  @app.post("/synthesize/stream")
305
+ async def stream_text_to_speech_cloning(
306
+ text: str = Form(..., min_length=1, max_length=5000), # Increased limit for streaming
307
+ speed: float = Form(1.0, ge=0.5, le=2.0),
308
+ output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"), # MP3 is best for streaming
309
+ reference_audio: UploadFile = File(...)
310
+ ):
311
  """
312
+ Sentence-by-Sentence Streaming Endpoint (Kokoro Feature adaptation).
313
+ Performs encoding once, then synthesizes and streams chunks.
314
  """
315
+ if not hasattr(app.state, 'tts_wrapper'):
316
+ raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
317
+
318
+ # 1. Asynchronously save reference audio (non-blocking)
319
+ temp_ref_path = await save_upload_file_async(reference_audio)
320
 
321
+ # 2. Define the generator function, which will run in the thread pool implicitly
322
+ def stream_generator():
323
+ try:
324
+ # The entire streaming process runs blocking inside the thread pool
325
+ for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
326
+ text,
327
+ temp_ref_path,
328
+ speed,
329
+ output_format
330
+ ):
331
+ yield chunk_bytes
332
+ except Exception as e:
333
+ logger.error(f"Streaming generator error: {e}")
334
+ # Raise an exception if necessary, though it might break the stream
335
+ finally:
336
+ # 3. Cleanup the temporary reference file after the stream is done
337
+ if os.path.exists(temp_ref_path):
338
+ os.unlink(temp_ref_path)
339
 
340
+ # The StreamingResponse handles the transfer encoding and chunking
341
+ return StreamingResponse(
342
+ stream_generator(),
343
+ media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
344
+ headers={
345
+ "Content-Disposition": "attachment; filename=tts_live_stream.mp3",
346
+ "Transfer-Encoding": "chunked",
347
+ "Cache-Control": "no-cache"
348
+ }
349
+ )
 
 
 
 
 
 
 
350
 
351
  @app.get("/audio/{filename}")
352
+ async def get_audio(filename: str):
353
+ """Original NeuTTS feature to serve generated audio files."""
354
+ file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
 
355
  if not os.path.exists(file_path):
356
  raise HTTPException(status_code=404, detail="Audio file not found")
357
 
358
+ return Response(
359
+ content=open(file_path, "rb").read(),
360
+ media_type=f"audio/{filename.split('.')[-1]}", # Simple media type detection
361
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  )