Rajhuggingface4253 committed on
Commit
eb0ee66
·
verified ·
1 Parent(s): c2ab408

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +292 -356
app.py CHANGED
@@ -9,12 +9,12 @@ import asyncio
9
  import subprocess
10
  import io
11
  from contextlib import asynccontextmanager
12
- from typing import Optional, Dict, Any, AsyncGenerator
13
  from uuid import uuid4
14
  from pathlib import Path
15
 
16
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Request
17
- from fastapi.responses import JSONResponse, StreamingResponse, Response
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel, Field
20
  import psutil
@@ -31,168 +31,134 @@ logging.basicConfig(
31
  )
32
  logger = logging.getLogger(__name__)
33
 
34
- # Configuration
35
  class Config:
36
  MAX_TEXT_LENGTH = 1000
37
  MIN_AUDIO_DURATION = 2
38
  MAX_AUDIO_DURATION = 30
39
  SAMPLE_RATE = 24000
40
  REFERENCE_SAMPLE_RATE = 16000
41
- CHUNK_SIZE = 4096 # For streaming
42
- MAX_CONCURRENT_REQUESTS = 3
43
- REQUEST_TIMEOUT = 120
 
 
44
 
45
  config = Config()
46
 
47
- # Global model instance with async support
48
  tts_model = None
49
  model_loading = False
50
  active_requests = 0
51
- request_semaphore = asyncio.Semaphore(config.MAX_CONCURRENT_REQUESTS)
52
 
53
- # In-memory audio cache to avoid disk usage
54
  audio_cache = {}
55
- CACHE_MAX_SIZE = 50 # Max cached audio files
56
- CACHE_CLEANUP_INTERVAL = 300 # 5 minutes
57
 
58
- class AudioCache:
59
- """In-memory audio cache to avoid disk usage"""
60
-
61
- def __init__(self, max_size: int = 50):
62
- self.cache = {}
63
- self.max_size = max_size
64
- self.access_order = []
65
-
66
- async def store_audio(self, audio_id: str, audio_data: np.ndarray, sample_rate: int):
67
- """Store audio in memory"""
68
- if len(self.cache) >= self.max_size:
69
- await self._remove_oldest()
70
-
71
- self.cache[audio_id] = {
72
- 'audio': audio_data,
73
- 'sample_rate': sample_rate,
74
- 'created_at': time.time(),
75
- 'accessed_at': time.time()
76
- }
77
- self.access_order.append(audio_id)
78
-
79
- async def get_audio(self, audio_id: str) -> Optional[Dict]:
80
- """Retrieve audio from memory"""
81
- if audio_id in self.cache:
82
- self.cache[audio_id]['accessed_at'] = time.time()
83
- # Move to end of access order
84
- if audio_id in self.access_order:
85
- self.access_order.remove(audio_id)
86
- self.access_order.append(audio_id)
87
- return self.cache[audio_id]
88
- return None
89
-
90
- async def _remove_oldest(self):
91
- """Remove least recently used audio"""
92
- if self.access_order:
93
- oldest_id = self.access_order.pop(0)
94
- if oldest_id in self.cache:
95
- del self.cache[oldest_id]
96
- logger.debug(f"Removed cached audio: {oldest_id}")
97
-
98
- # Initialize cache
99
- audio_cache = AudioCache(max_size=CACHE_MAX_SIZE)
100
-
101
- class AudioStreamProcessor:
102
- """Process audio in memory without disk usage"""
103
 
104
  @staticmethod
105
- async def convert_audio_to_wav_memory(upload_file: UploadFile) -> tuple[bytes, float]:
106
- """Convert uploaded audio to WAV format in memory"""
 
 
107
  try:
108
- # Read uploaded file into memory
109
  file_content = await upload_file.read()
110
 
111
- # Create temporary in-memory files
112
- input_buffer = io.BytesIO(file_content)
113
- output_buffer = io.BytesIO()
 
 
 
 
 
 
 
 
 
 
114
 
115
- # Save input to temporary file (minimal disk usage for ffmpeg)
116
- temp_input_path = f"/tmp/input_{uuid4().hex}{Path(upload_file.filename).suffix}"
117
- temp_output_path = f"/tmp/output_{uuid4().hex}.wav"
 
 
118
 
 
 
 
119
  try:
120
- # Write input to temp file
121
- async with aiofiles.open(temp_input_path, 'wb') as f:
122
- await f.write(file_content)
123
-
124
- # Convert using ffmpeg
125
- cmd = [
126
- 'ffmpeg', '-i', temp_input_path,
127
- '-ac', '1',
128
- '-ar', str(config.REFERENCE_SAMPLE_RATE),
129
- '-acodec', 'pcm_s16le',
130
- '-y', temp_output_path
131
- ]
132
-
133
- process = await asyncio.create_subprocess_exec(
134
- *cmd,
135
- stdout=asyncio.subprocess.PIPE,
136
- stderr=asyncio.subprocess.PIPE
137
- )
138
-
139
- stdout, stderr = await process.communicate()
140
-
141
- if process.returncode != 0:
142
- raise Exception(f"FFmpeg failed: {stderr.decode()}")
143
-
144
- # Read converted file into memory
145
- async with aiofiles.open(temp_output_path, 'rb') as f:
146
- wav_data = await f.read()
147
-
148
- # Get duration
149
- duration = await AudioStreamProcessor.get_audio_duration_memory(wav_data)
150
-
151
- return wav_data, duration
152
-
153
- finally:
154
- # Cleanup temp files
155
- for temp_file in [temp_input_path, temp_output_path]:
156
- if os.path.exists(temp_file):
157
- try:
158
- os.remove(temp_file)
159
- except:
160
- pass
161
-
162
  except Exception as e:
163
- logger.error(f"Audio conversion failed: {e}")
 
 
 
 
 
 
164
  raise
165
-
 
 
 
 
 
 
 
166
  @staticmethod
167
- async def get_audio_duration_memory(audio_data: bytes) -> float:
168
- """Get audio duration from in-memory WAV data"""
169
- try:
170
- # Use soundfile with BytesIO
171
- with sf.SoundFile(io.BytesIO(audio_data)) as audio_file:
172
- return len(audio_file) / audio_file.samplerate
173
- except Exception as e:
174
- logger.warning(f"SoundFile duration failed: {e}, using librosa")
175
- # Fallback to librosa
176
- import librosa
177
- audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=None)
178
- return len(audio_array) / sr
179
-
 
 
 
 
 
 
 
 
 
 
 
180
  @staticmethod
181
- async def validate_audio_duration(duration: float):
182
- """Validate audio duration"""
183
- if duration < config.MIN_AUDIO_DURATION:
184
- raise HTTPException(
185
- status_code=400,
186
- detail=f"Audio too short: {duration:.1f}s (minimum {config.MIN_AUDIO_DURATION}s)"
187
- )
188
- if duration > config.MAX_AUDIO_DURATION:
189
- raise HTTPException(
190
- status_code=400,
191
- detail=f"Audio too long: {duration:.1f}s (maximum {config.MAX_AUDIO_DURATION}s)"
192
- )
193
 
194
  async def load_tts_model():
195
- """Load TTS model asynchronously"""
196
  global tts_model, model_loading
197
 
198
  if tts_model is not None or model_loading:
@@ -200,21 +166,16 @@ async def load_tts_model():
200
 
201
  model_loading = True
202
  try:
203
- logger.info("Loading NeuTTS Air model...")
204
 
205
  # Clear memory before loading
206
  gc.collect()
207
  if torch.cuda.is_available():
208
  torch.cuda.empty_cache()
209
 
210
- # Import model
211
- try:
212
- from neuttsair.neutts import NeuTTSAir
213
- except ImportError as e:
214
- logger.error(f"Failed to import NeuTTS Air: {e}")
215
- raise
216
 
217
- # Initialize model
218
  tts_model = NeuTTSAir(
219
  backbone_repo="neuphonic/neutts-air",
220
  backbone_device="cpu",
@@ -222,39 +183,39 @@ async def load_tts_model():
222
  codec_device="cpu"
223
  )
224
 
225
- logger.info("NeuTTS Air model loaded successfully!")
226
 
227
  except Exception as e:
228
- logger.error(f"Failed to load model: {str(e)}")
229
  raise e
230
  finally:
231
  model_loading = False
232
 
233
  @asynccontextmanager
234
  async def lifespan(app: FastAPI):
235
- """Lifespan manager with async startup/shutdown"""
236
  # Startup
237
- logger.info("🚀 Starting NeuTTS Air Streaming API")
238
 
239
  # Load model in background
240
  asyncio.create_task(load_tts_model())
241
 
242
- # Start cache cleanup task
243
- asyncio.create_task(cache_cleanup_task())
244
 
245
  yield
246
 
247
- # Shutdown
248
  logger.info("🛑 Shutting down NeuTTS Air API")
249
  global tts_model
250
  if tts_model is not None:
251
  del tts_model
252
  tts_model = None
253
- gc.collect()
254
 
255
  app = FastAPI(
256
- title="NeuTTS Air Streaming API",
257
- description="High-quality on-device TTS with streaming and no disk usage",
258
  version="2.0.0",
259
  lifespan=lifespan
260
  )
@@ -272,7 +233,6 @@ app.add_middleware(
272
  class TTSRequest(BaseModel):
273
  text: str = Field(..., min_length=1, max_length=1000)
274
  reference_text: str = Field(..., min_length=1, max_length=500)
275
- reference_audio_path: Optional[str] = None
276
 
277
  class TTSResponse(BaseModel):
278
  success: bool
@@ -287,191 +247,209 @@ class HealthResponse(BaseModel):
287
  model_loaded: bool
288
  active_requests: int
289
  cache_size: int
290
- memory_usage: Dict[str, float]
291
-
292
- # Async middleware for request limiting
293
- @app.middleware("http")
294
- async def limit_concurrent_requests(request: Request, call_next):
295
- global active_requests
296
-
297
- if active_requests >= config.MAX_CONCURRENT_REQUESTS:
298
- return JSONResponse(
299
- status_code=429,
300
- content={"detail": "Too many concurrent requests"}
301
- )
302
-
303
- async with request_semaphore:
304
- active_requests += 1
305
- try:
306
- start_time = time.time()
307
- response = await call_next(request)
308
- process_time = time.time() - start_time
309
- logger.info(f"{request.method} {request.url.path} completed in {process_time:.2f}s")
310
- return response
311
- finally:
312
- active_requests -= 1
313
 
314
  @app.get("/")
315
  async def root():
316
  return {
317
- "message": "NeuTTS Air Streaming API",
318
  "status": "healthy",
319
- "features": ["streaming", "no_disk_usage", "async", "in_memory_cache"],
320
  "model_loaded": tts_model is not None,
321
- "active_requests": active_requests
 
322
  }
323
 
324
  @app.get("/health")
325
  async def health_check():
326
- """Health check with memory usage"""
327
  try:
328
  memory = psutil.virtual_memory()
 
329
 
330
  return HealthResponse(
331
  status="healthy",
332
  model_loaded=tts_model is not None,
333
  active_requests=active_requests,
334
- cache_size=len(audio_cache.cache),
335
- memory_usage={
336
- "total_gb": round(memory.total / (1024**3), 2),
337
- "available_gb": round(memory.available / (1024**3), 2),
338
- "used_percent": round(memory.percent, 2)
339
- }
340
  )
341
  except Exception as e:
342
  return HealthResponse(
343
  status="degraded",
344
  model_loaded=tts_model is not None,
345
  active_requests=active_requests,
346
- cache_size=len(audio_cache.cache),
347
- memory_usage={"error": str(e)}
 
348
  )
349
 
350
  @app.post("/synthesize", response_model=TTSResponse)
351
  async def synthesize_speech(
 
352
  reference_text: str = Form(...),
353
  text: str = Form(...),
354
  reference_audio: UploadFile = File(...)
355
  ):
356
  """
357
- Synthesize speech with streaming support and no disk usage
358
  """
 
359
  start_time = time.time()
360
  request_id = str(uuid4())[:8]
 
361
 
362
- logger.info(f"[{request_id}] Starting streaming synthesis")
363
-
364
- if tts_model is None:
365
- raise HTTPException(status_code=503, detail="Model not loaded yet")
366
-
367
- # Validate inputs
368
- if not reference_text.strip() or not text.strip():
369
- raise HTTPException(status_code=400, detail="Text fields cannot be empty")
370
 
371
  try:
372
- # Convert audio to WAV in memory
373
- wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
374
- await AudioStreamProcessor.validate_audio_duration(audio_duration)
375
 
376
- logger.info(f"[{request_id}] Audio validated: {audio_duration:.2f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
- # Create temporary file for model processing (minimal disk usage)
379
- temp_ref_path = f"/tmp/ref_{request_id}.wav"
380
- try:
381
- async with aiofiles.open(temp_ref_path, 'wb') as f:
382
- await f.write(wav_data)
383
-
384
- # Perform TTS
385
- logger.info(f"[{request_id}] Synthesizing: '{text[:50]}...'")
386
-
387
- # Encode reference and generate speech
388
- ref_codes = tts_model.encode_reference(temp_ref_path)
389
- wav_output = tts_model.infer(text, ref_codes, reference_text)
390
-
391
- # Generate audio ID for caching
392
- audio_id = f"audio_{request_id}"
393
-
394
- # Store in memory cache
395
- await audio_cache.store_audio(audio_id, wav_output, config.SAMPLE_RATE)
396
-
397
- processing_time = time.time() - start_time
398
- output_duration = len(wav_output) / config.SAMPLE_RATE
399
-
400
- logger.info(f"[{request_id}] Synthesis completed in {processing_time:.2f}s")
401
-
402
- return TTSResponse(
403
- success=True,
404
- audio_id=audio_id,
405
- message="Speech synthesized successfully",
406
- processing_time=round(processing_time, 2),
407
- audio_duration=round(output_duration, 2),
408
- stream_url=f"/stream/{audio_id}"
409
- )
410
-
411
- finally:
412
- # Cleanup temp file
413
- if os.path.exists(temp_ref_path):
414
- try:
415
- os.remove(temp_ref_path)
416
- except:
417
- pass
418
-
419
  except HTTPException:
420
  raise
421
  except Exception as e:
422
  logger.error(f"[{request_id}] Synthesis error: {str(e)}")
423
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  @app.get("/stream/{audio_id}")
426
  async def stream_audio(audio_id: str):
427
  """
428
- Stream audio directly from memory cache
429
  """
430
- # Get audio from cache
431
- cached_audio = await audio_cache.get_audio(audio_id)
432
- if not cached_audio:
433
- raise HTTPException(status_code=404, detail="Audio not found or expired")
434
-
435
- audio_data = cached_audio['audio']
436
- sample_rate = cached_audio['sample_rate']
437
 
438
- # Convert numpy array to WAV bytes in memory
439
  wav_buffer = io.BytesIO()
440
- sf.write(wav_buffer, audio_data, sample_rate, format='WAV')
441
  wav_bytes = wav_buffer.getvalue()
442
 
443
- # Create async generator for streaming
444
- async def generate_audio_stream():
445
  chunk_size = config.CHUNK_SIZE
446
  for i in range(0, len(wav_bytes), chunk_size):
447
  yield wav_bytes[i:i + chunk_size]
448
- await asyncio.sleep(0.001) # Small delay for proper streaming
449
 
450
  return StreamingResponse(
451
- generate_audio_stream(),
452
  media_type="audio/wav",
453
  headers={
454
  "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
455
- "Cache-Control": "no-cache",
456
- "Content-Length": str(len(wav_bytes))
457
  }
458
  )
459
 
460
  @app.get("/download/{audio_id}")
461
  async def download_audio(audio_id: str):
462
  """
463
- Download audio as complete file
464
  """
465
- cached_audio = await audio_cache.get_audio(audio_id)
466
- if not cached_audio:
467
- raise HTTPException(status_code=404, detail="Audio not found or expired")
468
-
469
- audio_data = cached_audio['audio']
470
- sample_rate = cached_audio['sample_rate']
471
 
472
- # Convert to WAV in memory
473
  wav_buffer = io.BytesIO()
474
- sf.write(wav_buffer, audio_data, sample_rate, format='WAV')
475
  wav_bytes = wav_buffer.getvalue()
476
 
477
  return Response(
@@ -483,113 +461,71 @@ async def download_audio(audio_id: str):
483
  }
484
  )
485
 
486
- @app.post("/synthesize-and-stream")
487
- async def synthesize_and_stream(
488
- reference_text: str = Form(...),
489
- text: str = Form(...),
490
- reference_audio: UploadFile = File(...)
491
- ):
492
- """
493
- Real-time synthesis and streaming in one endpoint
494
- """
495
- start_time = time.time()
496
-
497
- if tts_model is None:
498
- raise HTTPException(status_code=503, detail="Model not loaded yet")
499
-
500
- try:
501
- # Convert audio to WAV in memory
502
- wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
503
- await AudioStreamProcessor.validate_audio_duration(audio_duration)
504
-
505
- # Create temporary file for model processing
506
- temp_ref_path = f"/tmp/ref_stream_{uuid4().hex}.wav"
507
- try:
508
- async with aiofiles.open(temp_ref_path, 'wb') as f:
509
- await f.write(wav_data)
510
-
511
- # Perform TTS
512
- ref_codes = tts_model.encode_reference(temp_ref_path)
513
- wav_output = tts_model.infer(text, ref_codes, reference_text)
514
-
515
- processing_time = time.time() - start_time
516
- logger.info(f"Real-time synthesis completed in {processing_time:.2f}s")
517
-
518
- # Convert to WAV bytes
519
- wav_buffer = io.BytesIO()
520
- sf.write(wav_buffer, wav_output, config.SAMPLE_RATE, format='WAV')
521
- wav_bytes = wav_buffer.getvalue()
522
-
523
- # Stream directly
524
- async def generate_stream():
525
- chunk_size = config.CHUNK_SIZE
526
- for i in range(0, len(wav_bytes), chunk_size):
527
- yield wav_bytes[i:i + chunk_size]
528
- await asyncio.sleep(0.001)
529
-
530
- return StreamingResponse(
531
- generate_stream(),
532
- media_type="audio/wav",
533
- headers={
534
- "Content-Disposition": "attachment; filename=speech_stream.wav",
535
- "Cache-Control": "no-cache",
536
- "X-Processing-Time": f"{processing_time:.2f}"
537
- }
538
- )
539
-
540
- finally:
541
- if os.path.exists(temp_ref_path):
542
- try:
543
- os.remove(temp_ref_path)
544
- except:
545
- pass
546
-
547
- except Exception as e:
548
- logger.error(f"Stream synthesis error: {str(e)}")
549
- raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}")
550
-
551
  @app.delete("/cache/{audio_id}")
552
  async def clear_cached_audio(audio_id: str):
553
  """Clear specific audio from cache"""
554
- if audio_id in audio_cache.cache:
555
- del audio_cache.cache[audio_id]
556
- if audio_id in audio_cache.access_order:
557
- audio_cache.access_order.remove(audio_id)
558
  return {"message": f"Audio {audio_id} cleared from cache"}
559
  else:
560
  raise HTTPException(status_code=404, detail="Audio not found in cache")
561
 
562
  @app.delete("/cache")
563
  async def clear_all_cache():
564
- """Clear all audio cache"""
565
- cache_size = len(audio_cache.cache)
566
- audio_cache.cache.clear()
567
- audio_cache.access_order.clear()
 
568
  return {"message": f"Cleared all {cache_size} cached audio files"}
569
 
570
- async def cache_cleanup_task():
571
- """Background task to clean up old cache entries"""
 
 
 
 
 
 
 
 
 
572
  while True:
573
- await asyncio.sleep(CACHE_CLEANUP_INTERVAL)
 
574
  try:
 
575
  current_time = time.time()
576
- expired_ids = []
577
-
578
- for audio_id, data in audio_cache.cache.items():
579
- if current_time - data['accessed_at'] > 3600: # 1 hour
580
- expired_ids.append(audio_id)
581
 
582
  for audio_id in expired_ids:
583
- if audio_id in audio_cache.cache:
584
- del audio_cache.cache[audio_id]
585
- if audio_id in audio_cache.access_order:
586
- audio_cache.access_order.remove(audio_id)
587
 
588
  if expired_ids:
589
- logger.info(f"Cache cleanup removed {len(expired_ids)} expired entries")
590
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  except Exception as e:
592
- logger.error(f"Cache cleanup error: {e}")
593
 
594
  if __name__ == "__main__":
595
  import uvicorn
 
9
import io
import subprocess
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional
from uuid import uuid4

import psutil
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel, Field
 
31
  )
32
  logger = logging.getLogger(__name__)
33
 
34
+ # Configuration - OPTIMIZED FOR MEMORY
35
class Config:
    """Static limits and tuning constants for the service (class-level only)."""

    # Not referenced by the code visible here; TTSRequest hard-codes the same
    # 1000-character limit.
    MAX_TEXT_LENGTH = 1000

    # Accepted reference-clip duration window, in seconds (compared against
    # frames / samplerate of the converted clip).
    MIN_AUDIO_DURATION = 2
    MAX_AUDIO_DURATION = 30

    # Sample rate (Hz) used when writing synthesized audio and when deriving
    # its duration from the sample count.
    SAMPLE_RATE = 24000

    # Sample rate (Hz) passed to ffmpeg (-ar) when converting the reference clip.
    REFERENCE_SAMPLE_RATE = 16000

    # Number of WAV bytes yielded per chunk by the streaming endpoints.
    CHUNK_SIZE = 8192

    # NOTE(review): no visible code enforces this limit — confirm whether a
    # limiter is still expected.
    MAX_CONCURRENT_REQUESTS = 2

    CACHE_MAX_FILES = 5  # Very small cache
    CACHE_MAX_SIZE_MB = 5  # Only 5MB cache

    # Age in seconds after which leftover temp files are swept.
    TEMP_FILE_TIMEOUT = 300  # 5 minutes
46
 
47
# Shared settings object read by the processor and the request handlers.
config = Config()

# Global model instance - SINGLE LOAD
tts_model = None  # stays None until load_tts_model() has assigned the model
model_loading = False  # True while load_tts_model() is running
active_requests = 0  # incremented/decremented by the synthesis endpoints

# Small in-memory cache for recent requests
audio_cache = {}  # audio_id -> {'audio', 'timestamp', 'size_mb'}
cache_access_order = []  # audio ids, least recently used first
 
57
 
58
class MemoryOptimizedProcessor:
    """Handles audio processing with minimal memory footprint.

    Groups three stateless helpers: converting an uploaded reference clip to
    a WAV file on disk, and storing/fetching synthesized audio in the
    module-level ``audio_cache`` / ``cache_access_order`` pair.
    """

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Best-effort delete of *path*; a missing or locked file is ignored."""
        try:
            os.remove(path)
        except OSError:
            pass

    @staticmethod
    async def process_reference_audio(upload_file: UploadFile) -> str:
        """Convert an uploaded clip to mono 16-bit PCM WAV and validate its length.

        The upload is written to a temporary input file, converted with
        ffmpeg to ``config.REFERENCE_SAMPLE_RATE``, and the result is checked
        against ``config.MIN_AUDIO_DURATION`` / ``config.MAX_AUDIO_DURATION``.

        Returns:
            Path of the converted WAV file under /tmp. The caller owns this
            file and must delete it when done.

        Raises:
            ValueError: ffmpeg could not convert the upload, the result is
                not readable as audio, or its duration is out of range.
        """
        # Both paths are decided up front so the cleanup code below can never
        # touch an unassigned name, whichever step fails.
        temp_ref_path = f"/tmp/ref_{uuid4().hex}.wav"
        # filename may be None for some clients; fall back to no suffix.
        temp_input = f"/tmp/in_{uuid4().hex}{Path(upload_file.filename or '').suffix}"

        try:
            # Read file content and stage it for ffmpeg
            file_content = await upload_file.read()
            async with aiofiles.open(temp_input, 'wb') as f:
                await f.write(file_content)

            # Convert to WAV using ffmpeg
            cmd = [
                'ffmpeg', '-i', temp_input,
                '-ac', '1',
                '-ar', str(config.REFERENCE_SAMPLE_RATE),
                '-acodec', 'pcm_s16le',
                '-y', temp_ref_path
            ]
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            _, stderr = await process.communicate()

            # Report a failed conversion directly instead of letting it show
            # up later as an unreadable output file.
            if process.returncode != 0:
                stderr_lines = stderr.decode(errors="replace").strip().splitlines()
                reason = stderr_lines[-1] if stderr_lines else f"exit code {process.returncode}"
                raise ValueError(f"Invalid audio file: ffmpeg could not convert it ({reason})")

            # Only the file read is wrapped, so the duration errors below keep
            # their own messages.
            try:
                with sf.SoundFile(temp_ref_path) as audio_file:
                    duration = len(audio_file) / audio_file.samplerate
            except Exception as e:
                raise ValueError(f"Invalid audio file: {str(e)}") from e

            if duration < config.MIN_AUDIO_DURATION:
                raise ValueError(f"Audio too short: {duration:.1f}s")
            if duration > config.MAX_AUDIO_DURATION:
                raise ValueError(f"Audio too long: {duration:.1f}s")

            return temp_ref_path

        except BaseException:
            # On any failure (including cancellation) the converted file is
            # useless to the caller, so remove it before re-raising.
            MemoryOptimizedProcessor._remove_quietly(temp_ref_path)
            raise
        finally:
            # The staged upload is never needed after conversion.
            MemoryOptimizedProcessor._remove_quietly(temp_input)

    @staticmethod
    def add_to_cache(audio_id: str, audio_data: np.ndarray):
        """Store *audio_data* under *audio_id*, evicting old entries first.

        Entries are evicted least-recently-used first while the cache holds
        ``config.CACHE_MAX_FILES`` or more items, or while the stored audio
        exceeds ``config.CACHE_MAX_SIZE_MB``.
        """
        # Calculate approximate size
        audio_size_mb = audio_data.nbytes / 1024 / 1024

        # Re-storing an id replaces the old entry instead of leaving a
        # duplicate in the access-order list.
        audio_cache.pop(audio_id, None)
        if audio_id in cache_access_order:
            cache_access_order.remove(audio_id)

        # Remove oldest items if cache too large. The `audio_cache and` guard
        # guarantees termination even with a non-positive file limit.
        while audio_cache and (
            len(audio_cache) >= config.CACHE_MAX_FILES
            or sum((data['audio'].nbytes / 1024 / 1024) for data in audio_cache.values())
            > config.CACHE_MAX_SIZE_MB
        ):
            if cache_access_order:
                oldest_id = cache_access_order.pop(0)
            else:
                # Order list out of sync with the dict: fall back to the
                # dict's insertion order rather than looping forever.
                oldest_id = next(iter(audio_cache))
            audio_cache.pop(oldest_id, None)

        # Add to cache
        audio_cache[audio_id] = {
            'audio': audio_data,
            'timestamp': time.time(),
            'size_mb': audio_size_mb
        }
        cache_access_order.append(audio_id)

        logger.info(f"Cache: {len(audio_cache)} files, "
                    f"{sum(d['size_mb'] for d in audio_cache.values()):.2f}MB")

    @staticmethod
    def get_from_cache(audio_id: str) -> Optional[np.ndarray]:
        """Return the cached audio for *audio_id*, or None if it is not cached.

        A hit marks the entry as most recently used and refreshes its
        timestamp.
        """
        if audio_id not in audio_cache:
            return None

        # Move to end of access order (most recently used)
        if audio_id in cache_access_order:
            cache_access_order.remove(audio_id)
        cache_access_order.append(audio_id)

        audio_cache[audio_id]['timestamp'] = time.time()
        return audio_cache[audio_id]['audio']
 
159
 
160
  async def load_tts_model():
161
+ """Load TTS model once with memory optimization"""
162
  global tts_model, model_loading
163
 
164
  if tts_model is not None or model_loading:
 
166
 
167
  model_loading = True
168
  try:
169
+ logger.info("🔄 Loading NeuTTS Air model...")
170
 
171
  # Clear memory before loading
172
  gc.collect()
173
  if torch.cuda.is_available():
174
  torch.cuda.empty_cache()
175
 
176
+ # Import and initialize model
177
+ from neuttsair.neutts import NeuTTSAir
 
 
 
 
178
 
 
179
  tts_model = NeuTTSAir(
180
  backbone_repo="neuphonic/neutts-air",
181
  backbone_device="cpu",
 
183
  codec_device="cpu"
184
  )
185
 
186
+ logger.info("NeuTTS Air model loaded successfully!")
187
 
188
  except Exception as e:
189
+ logger.error(f"Failed to load model: {str(e)}")
190
  raise e
191
  finally:
192
  model_loading = False
193
 
194
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan manager with efficient startup/shutdown.

    Startup schedules the model load and the periodic cleanup loop as
    background tasks so the app can begin serving immediately. Shutdown
    cancels whatever is still running and drops the model reference.
    """
    global tts_model

    # Startup
    logger.info("🚀 Starting NeuTTS Air API")

    # The event loop keeps only weak references to tasks, so hold strong
    # references here to stop them being garbage-collected mid-run.
    background_jobs = [
        asyncio.create_task(load_tts_model()),   # Load model in background
        asyncio.create_task(background_cleanup()),  # Periodic cache/temp sweep
    ]

    try:
        yield
    finally:
        # Shutdown - cleanup
        logger.info("🛑 Shutting down NeuTTS Air API")

        # background_cleanup() loops forever; cancel it (and a still-running
        # load) and wait so nothing is left pending when the loop closes.
        for job in background_jobs:
            job.cancel()
        await asyncio.gather(*background_jobs, return_exceptions=True)

        if tts_model is not None:
            tts_model = None  # drop the module-level reference to the model
            gc.collect()
215
 
216
# ASGI application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="NeuTTS Air - Optimized API",
    description="Memory-efficient TTS with streaming",
    version="2.0.0",
    lifespan=lifespan
)
 
233
class TTSRequest(BaseModel):
    """Validated text inputs for a synthesis request.

    NOTE(review): the endpoints visible in this file take these values as
    form fields rather than through this model — confirm it is still used.
    """

    # Text to synthesize; required, 1-1000 characters.
    text: str = Field(..., min_length=1, max_length=1000)
    # Transcript accompanying the reference audio; required, 1-500 characters.
    reference_text: str = Field(..., min_length=1, max_length=500)
 
236
 
237
  class TTSResponse(BaseModel):
238
  success: bool
 
247
  model_loaded: bool
248
  active_requests: int
249
  cache_size: int
250
+ cache_memory_mb: float
251
+ system_memory_gb: float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
@app.get("/")
async def root():
    """Return a short status summary: model readiness and cache entry count."""
    return {
        "message": "NeuTTS Air - Memory Optimized API",
        "status": "healthy",
        "model_loaded": tts_model is not None,
        "cache_size": len(audio_cache),
        "memory_optimized": True
    }
262
 
263
@app.get("/health")
async def health_check():
    """Health check with memory monitoring.

    Reports model readiness, in-flight request count, cache entry count,
    cache size in MB and used system memory in GB. If collecting the memory
    figures fails, the failure is logged and a "degraded" response with
    zeroed memory fields is returned instead of an error.
    """
    try:
        memory = psutil.virtual_memory()
        cache_memory = sum(data['size_mb'] for data in audio_cache.values())

        return HealthResponse(
            status="healthy",
            model_loaded=tts_model is not None,
            active_requests=active_requests,
            cache_size=len(audio_cache),
            cache_memory_mb=round(cache_memory, 2),
            # memory.used, i.e. memory currently in use, not total capacity
            system_memory_gb=round(memory.used / (1024**3), 2)
        )
    except Exception:
        # Record why the check degraded; the response itself carries no detail.
        logger.exception("Health check failed while collecting memory stats")
        return HealthResponse(
            status="degraded",
            model_loaded=tts_model is not None,
            active_requests=active_requests,
            cache_size=len(audio_cache),
            cache_memory_mb=0,
            system_memory_gb=0
        )
287
 
288
@app.post("/synthesize", response_model=TTSResponse)
async def synthesize_speech(
    background_tasks: BackgroundTasks,
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """
    Efficient synthesis with streaming and minimal memory usage

    Converts the uploaded reference clip, runs the model, stores the result
    in the in-memory cache and returns its id plus a ``/stream/{audio_id}``
    URL.

    ``background_tasks`` is kept in the signature for compatibility but is no
    longer used: the temp file is removed directly, because tasks queued on
    it are not run when the handler raises.

    Raises:
        HTTPException 503: the model has not finished loading.
        HTTPException 400: a text field is blank or the reference audio is
            rejected (unreadable, too short, too long).
        HTTPException 500: any other failure during synthesis.
    """
    global active_requests
    start_time = time.time()
    request_id = str(uuid4())[:8]
    temp_ref_path = None

    active_requests += 1

    try:
        if tts_model is None:
            raise HTTPException(status_code=503, detail="Model loading, please wait")

        # Validate inputs
        if not reference_text.strip() or not text.strip():
            raise HTTPException(status_code=400, detail="Text fields cannot be empty")

        logger.info(f"[{request_id}] Starting synthesis")

        # Process reference audio - creates temp file. A ValueError here
        # means the client's audio is bad, so report 400, not 500.
        try:
            temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Perform TTS (this is where most memory is used)
        # NOTE(review): both calls run synchronously inside this coroutine and
        # block the event loop; consider a worker thread once the model's
        # thread-safety is confirmed.
        ref_codes = tts_model.encode_reference(temp_ref_path)
        wav_output = tts_model.infer(text, ref_codes, reference_text)

        # Generate audio ID and add to small cache
        audio_id = f"audio_{request_id}"
        MemoryOptimizedProcessor.add_to_cache(audio_id, wav_output)

        processing_time = time.time() - start_time
        audio_duration = len(wav_output) / config.SAMPLE_RATE

        logger.info(f"[{request_id}] Synthesis completed: {processing_time:.2f}s")

        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(audio_duration, 2),
            stream_url=f"/stream/{audio_id}"
        )

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"[{request_id}] Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
    finally:
        active_requests -= 1
        # The model is done with the reference file by now, so delete it
        # immediately on both the success and the error path.
        if temp_ref_path is not None:
            try:
                os.remove(temp_ref_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not delete temp file {temp_ref_path}: {cleanup_error}")
350
+
351
@app.post("/synthesize-and-stream")
async def synthesize_and_stream(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """
    Direct synthesis and streaming - no caching, minimal memory

    Converts the uploaded reference clip, runs the model, encodes the result
    as WAV in memory and streams it back in ``config.CHUNK_SIZE``-byte chunks.

    Raises:
        HTTPException 503: the model has not finished loading.
        HTTPException 400: a text field is blank or the reference audio is
            rejected (unreadable, too short, too long).
        HTTPException 500: any other failure during synthesis.
    """
    global active_requests
    start_time = time.time()
    temp_ref_path = None

    active_requests += 1

    try:
        if tts_model is None:
            raise HTTPException(status_code=503, detail="Model loading, please wait")

        # Same input validation as /synthesize
        if not reference_text.strip() or not text.strip():
            raise HTTPException(status_code=400, detail="Text fields cannot be empty")

        # Process reference audio; a ValueError means the client's audio is
        # bad, so report 400, not 500.
        try:
            temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Perform TTS
        # NOTE(review): both calls run synchronously inside this coroutine and
        # block the event loop; consider a worker thread once the model's
        # thread-safety is confirmed.
        ref_codes = tts_model.encode_reference(temp_ref_path)
        wav_output = tts_model.infer(text, ref_codes, reference_text)

        # Convert to WAV bytes in memory
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, wav_output, config.SAMPLE_RATE, format='WAV')
        wav_bytes = wav_buffer.getvalue()

        processing_time = time.time() - start_time

        logger.info(f"Stream synthesis completed: {processing_time:.2f}s")

        # Stream directly without storing
        async def generate_stream():
            chunk_size = config.CHUNK_SIZE
            for i in range(0, len(wav_bytes), chunk_size):
                yield wav_bytes[i:i + chunk_size]
                await asyncio.sleep(0.001)  # Small delay for smooth streaming

        return StreamingResponse(
            generate_stream(),
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=speech_stream.wav",
                "X-Processing-Time": f"{processing_time:.2f}",
                "Cache-Control": "no-store"  # Prevent caching
            }
        )

    except HTTPException:
        # Pass deliberate HTTP errors (503/400) through unchanged instead of
        # letting the generic handler below turn them into 500s.
        raise
    except Exception as e:
        logger.exception(f"Stream synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}") from e
    finally:
        active_requests -= 1
        # The WAV bytes are already in memory, so the reference file can be
        # deleted right away rather than via an unreferenced task.
        if temp_ref_path is not None:
            try:
                os.remove(temp_ref_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(f"Could not delete temp file {temp_ref_path}: {cleanup_error}")
410
 
411
@app.get("/stream/{audio_id}")
async def stream_audio(audio_id: str):
    """
    Stream audio from small cache

    Looks up *audio_id* in the in-memory cache, encodes the samples as WAV
    at ``config.SAMPLE_RATE`` and streams the bytes in
    ``config.CHUNK_SIZE``-byte chunks.

    Raises:
        HTTPException 404: no cache entry exists for *audio_id*.
    """
    # Get from cache
    audio_data = MemoryOptimizedProcessor.get_from_cache(audio_id)
    if audio_data is None:
        raise HTTPException(status_code=404, detail="Audio not found in cache")

    # Convert to WAV bytes
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_data, config.SAMPLE_RATE, format='WAV')
    wav_bytes = wav_buffer.getvalue()

    # Stream with chunks
    async def generate_stream():
        chunk_size = config.CHUNK_SIZE
        for i in range(0, len(wav_bytes), chunk_size):
            yield wav_bytes[i:i + chunk_size]
            await asyncio.sleep(0.001)

    return StreamingResponse(
        generate_stream(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
            "Cache-Control": "no-store"
        }
    )
441
 
442
  @app.get("/download/{audio_id}")
443
  async def download_audio(audio_id: str):
444
  """
445
+ Download audio directly
446
  """
447
+ audio_data = MemoryOptimizedProcessor.get_from_cache(audio_id)
448
+ if audio_data is None:
449
+ raise HTTPException(status_code=404, detail="Audio not found in cache")
 
 
 
450
 
 
451
  wav_buffer = io.BytesIO()
452
+ sf.write(wav_buffer, audio_data, config.SAMPLE_RATE, format='WAV')
453
  wav_bytes = wav_buffer.getvalue()
454
 
455
  return Response(
 
461
  }
462
  )
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
@app.delete("/cache/{audio_id}")
async def clear_cached_audio(audio_id: str):
    """Clear specific audio from cache.

    Raises:
        HTTPException 404: *audio_id* is not in the cache.
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found in cache")

    del audio_cache[audio_id]
    # Keep the access-order list in step with the dict.
    if audio_id in cache_access_order:
        cache_access_order.remove(audio_id)
    return {"message": f"Audio {audio_id} cleared from cache"}
474
 
475
@app.delete("/cache")
async def clear_all_cache():
    """Clear all cache and report how many entries were removed."""
    cache_size = len(audio_cache)  # captured before clearing, for the message
    audio_cache.clear()
    cache_access_order.clear()
    gc.collect()
    return {"message": f"Cleared all {cache_size} cached audio files"}
483
 
484
async def cleanup_temp_file(file_path: str):
    """Cleanup temporary file.

    Waits one second, then deletes *file_path*. A file that is already gone
    is ignored; any other failure is logged as a warning and not raised.
    """
    try:
        await asyncio.sleep(1)  # Small delay to ensure file is not in use
        # Remove directly rather than exists()-then-remove, which can race
        # with the periodic /tmp sweep.
        os.remove(file_path)
    except FileNotFoundError:
        pass
    except Exception as e:
        logger.warning(f"Could not delete temp file {file_path}: {e}")
492
+
493
async def background_cleanup():
    """Background task to clean up old cache entries and temp files.

    Runs forever, waking every 300 seconds to:
      1. drop cache entries whose timestamp is more than an hour old, and
      2. delete files in /tmp with a ``ref_``, ``conv_`` or ``in_`` prefix
         that are older than ``config.TEMP_FILE_TIMEOUT`` seconds.

    Errors in a pass are logged and the loop continues.
    """
    while True:
        await asyncio.sleep(300)  # Run every 5 minutes

        try:
            # Clean old cache entries (older than 1 hour)
            current_time = time.time()
            expired_ids = [
                audio_id for audio_id, data in audio_cache.items()
                if current_time - data['timestamp'] > 3600
            ]

            for audio_id in expired_ids:
                audio_cache.pop(audio_id, None)
                if audio_id in cache_access_order:
                    cache_access_order.remove(audio_id)

            if expired_ids:
                logger.info(f"Background cleanup: removed {len(expired_ids)} cache entries")

            # Clean old temp files in /tmp
            # NOTE(review): matching by prefix in a shared /tmp can delete
            # files created by other programs — confirm this is acceptable.
            for filename in os.listdir('/tmp'):
                if not filename.startswith(('ref_', 'conv_', 'in_')):
                    continue
                file_path = os.path.join('/tmp', filename)
                try:
                    if os.path.isfile(file_path):
                        file_age = time.time() - os.path.getctime(file_path)
                        if file_age > config.TEMP_FILE_TIMEOUT:
                            os.remove(file_path)
                except OSError:
                    # The file vanished or is not removable; skip it. Other
                    # exception types reach the outer handler and get logged.
                    continue

        except Exception as e:
            logger.error(f"Background cleanup error: {e}")
529
 
530
  if __name__ == "__main__":
531
  import uvicorn