Rajhuggingface4253 commited on
Commit
f8d6527
·
verified ·
1 Parent(s): dc2764b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -27
app.py CHANGED
@@ -13,7 +13,7 @@ from typing import Optional, Dict, Any
13
  from uuid import uuid4
14
  from pathlib import Path
15
 
16
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
17
  from fastapi.responses import JSONResponse, StreamingResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel, Field
@@ -21,7 +21,7 @@ import psutil
21
  import logging
22
  import soundfile as sf
23
 
24
- # Add NeuTTS Air to path
25
  sys.path.insert(0, "/app/neutts-air")
26
 
27
  # Configure logging
@@ -31,18 +31,19 @@ logging.basicConfig(
31
  )
32
  logger = logging.getLogger(__name__)
33
 
34
- # Configuration - OPTIMIZED FOR MEMORY
35
  class Config:
36
  MAX_TEXT_LENGTH = 1000
37
  MIN_AUDIO_DURATION = 2
38
  MAX_AUDIO_DURATION = 30
39
- SAMPLE_RATE = 24000
40
  REFERENCE_SAMPLE_RATE = 16000
41
  CHUNK_SIZE = 8192
42
  MAX_CONCURRENT_REQUESTS = 2
43
  CACHE_MAX_FILES = 5 # Very small cache
44
  CACHE_MAX_SIZE_MB = 5 # Only 5MB cache
45
  TEMP_FILE_TIMEOUT = 300 # 5 minutes
 
46
 
47
  config = Config()
48
 
@@ -51,6 +52,9 @@ tts_model = None
51
  model_loading = False
52
  active_requests = 0
53
 
 
 
 
54
  # Small in-memory cache for recent requests
55
  audio_cache = {}
56
  cache_access_order = []
@@ -62,6 +66,7 @@ class MemoryOptimizedProcessor:
62
  async def process_reference_audio(upload_file: UploadFile) -> str:
63
  """Process reference audio and return temp file path - CLEANED AFTER USE"""
64
  temp_ref_path = f"/tmp/ref_{uuid4().hex}.wav"
 
65
 
66
  try:
67
  # Read file content
@@ -72,7 +77,7 @@ class MemoryOptimizedProcessor:
72
  async with aiofiles.open(temp_input, 'wb') as f:
73
  await f.write(file_content)
74
 
75
- # Convert to WAV using ffmpeg
76
  cmd = [
77
  'ffmpeg', '-i', temp_input,
78
  '-ac', '1',
@@ -87,7 +92,10 @@ class MemoryOptimizedProcessor:
87
  stderr=asyncio.subprocess.PIPE
88
  )
89
 
90
- await process.communicate()
 
 
 
91
 
92
  # Validate audio duration
93
  try:
@@ -102,18 +110,22 @@ class MemoryOptimizedProcessor:
102
 
103
  return temp_ref_path
104
 
105
- except Exception as e:
106
  # Cleanup on error
107
- for temp_file in [temp_input, temp_ref_path]:
108
- if os.path.exists(temp_file):
109
- try:
110
- os.remove(temp_file)
111
- except:
112
- pass
 
 
 
 
113
  raise
114
  finally:
115
  # Always cleanup input temp file
116
- if 'temp_input' in locals() and os.path.exists(temp_input):
117
  try:
118
  os.remove(temp_input)
119
  except:
@@ -176,6 +188,7 @@ async def load_tts_model():
176
  # Import and initialize model
177
  from neuttsair.neutts import NeuTTSAir
178
 
 
179
  tts_model = NeuTTSAir(
180
  backbone_repo="neuphonic/neutts-air",
181
  backbone_device="cpu",
@@ -183,6 +196,20 @@ async def load_tts_model():
183
  codec_device="cpu"
184
  )
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  logger.info("✅ NeuTTS Air model loaded successfully!")
187
 
188
  except Exception as e:
@@ -209,7 +236,10 @@ async def lifespan(app: FastAPI):
209
  logger.info("🛑 Shutting down NeuTTS Air API")
210
  global tts_model
211
  if tts_model is not None:
212
- del tts_model
 
 
 
213
  tts_model = None
214
  gc.collect()
215
 
@@ -285,6 +315,20 @@ async def health_check():
285
  system_memory_gb=0
286
  )
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  @app.post("/synthesize", response_model=TTSResponse)
289
  async def synthesize_speech(
290
  background_tasks: BackgroundTasks,
@@ -293,15 +337,22 @@ async def synthesize_speech(
293
  reference_audio: UploadFile = File(...)
294
  ):
295
  """
296
- Efficient synthesis with streaming and minimal memory usage
 
297
  """
298
  global active_requests
299
  start_time = time.time()
300
  request_id = str(uuid4())[:8]
301
  temp_ref_path = None
302
-
 
 
 
 
 
 
303
  active_requests += 1
304
-
305
  try:
306
  if tts_model is None:
307
  raise HTTPException(status_code=503, detail="Model loading, please wait")
@@ -315,9 +366,15 @@ async def synthesize_speech(
315
  # Process reference audio - creates temp file
316
  temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
317
 
318
- # Perform TTS (this is where most memory is used)
319
- ref_codes = tts_model.encode_reference(temp_ref_path)
320
- wav_output = tts_model.infer(text, ref_codes, reference_text)
 
 
 
 
 
 
321
 
322
  # Generate audio ID and add to small cache
323
  audio_id = f"audio_{request_id}"
@@ -344,6 +401,7 @@ async def synthesize_speech(
344
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
345
  finally:
346
  active_requests -= 1
 
347
  # Schedule cleanup of temp reference file
348
  if temp_ref_path and os.path.exists(temp_ref_path):
349
  background_tasks.add_task(cleanup_temp_file, temp_ref_path)
@@ -355,12 +413,18 @@ async def synthesize_and_stream(
355
  reference_audio: UploadFile = File(...)
356
  ):
357
  """
358
- Direct synthesis and streaming - no caching, minimal memory
 
359
  """
360
  global active_requests
361
  start_time = time.time()
362
  temp_ref_path = None
363
-
 
 
 
 
 
364
  active_requests += 1
365
 
366
  try:
@@ -370,9 +434,12 @@ async def synthesize_and_stream(
370
  # Process reference audio
371
  temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
372
 
373
- # Perform TTS
374
- ref_codes = tts_model.encode_reference(temp_ref_path)
375
- wav_output = tts_model.infer(text, ref_codes, reference_text)
 
 
 
376
 
377
  # Convert to WAV bytes in memory
378
  wav_buffer = io.BytesIO()
@@ -383,7 +450,7 @@ async def synthesize_and_stream(
383
 
384
  logger.info(f"Stream synthesis completed: {processing_time:.2f}s")
385
 
386
- # Stream directly without storing
387
  async def generate_stream():
388
  chunk_size = config.CHUNK_SIZE
389
  for i in range(0, len(wav_bytes), chunk_size):
@@ -400,11 +467,14 @@ async def synthesize_and_stream(
400
  }
401
  )
402
 
 
 
403
  except Exception as e:
404
  logger.error(f"Stream synthesis error: {str(e)}")
405
  raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}")
406
  finally:
407
  active_requests -= 1
 
408
  if temp_ref_path and os.path.exists(temp_ref_path):
409
  asyncio.create_task(cleanup_temp_file(temp_ref_path))
410
 
 
13
  from uuid import uuid4
14
  from pathlib import Path
15
 
16
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Response
17
  from fastapi.responses import JSONResponse, StreamingResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel, Field
 
21
  import logging
22
  import soundfile as sf
23
 
24
+ # Add NeuTTS Air to path (adjust if needed)
25
  sys.path.insert(0, "/app/neutts-air")
26
 
27
  # Configure logging
 
31
  )
32
  logger = logging.getLogger(__name__)
33
 
34
+ # Configuration - OPTIMIZED FOR MEMORY & CPU
35
  class Config:
36
  MAX_TEXT_LENGTH = 1000
37
  MIN_AUDIO_DURATION = 2
38
  MAX_AUDIO_DURATION = 30
39
+ SAMPLE_RATE = 24000 # Consider 16000 if you need more speed/less CPU
40
  REFERENCE_SAMPLE_RATE = 16000
41
  CHUNK_SIZE = 8192
42
  MAX_CONCURRENT_REQUESTS = 2
43
  CACHE_MAX_FILES = 5 # Very small cache
44
  CACHE_MAX_SIZE_MB = 5 # Only 5MB cache
45
  TEMP_FILE_TIMEOUT = 300 # 5 minutes
46
+ ENABLE_QUANTIZATION = True # Best-effort dynamic quantization
47
 
48
  config = Config()
49
 
 
52
  model_loading = False
53
  active_requests = 0
54
 
55
+ # Concurrency control
56
+ _infer_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENT_REQUESTS)
57
+
58
  # Small in-memory cache for recent requests
59
  audio_cache = {}
60
  cache_access_order = []
 
66
  async def process_reference_audio(upload_file: UploadFile) -> str:
67
  """Process reference audio and return temp file path - CLEANED AFTER USE"""
68
  temp_ref_path = f"/tmp/ref_{uuid4().hex}.wav"
69
+ temp_input = None
70
 
71
  try:
72
  # Read file content
 
77
  async with aiofiles.open(temp_input, 'wb') as f:
78
  await f.write(file_content)
79
 
80
+ # Convert to WAV using ffmpeg (installed in image)
81
  cmd = [
82
  'ffmpeg', '-i', temp_input,
83
  '-ac', '1',
 
92
  stderr=asyncio.subprocess.PIPE
93
  )
94
 
95
+ _, stderr = await process.communicate()
96
+ if process.returncode != 0:
97
+ logger.warning(f"ffmpeg conversion failed: {stderr.decode('utf-8', errors='ignore')}")
98
+ raise ValueError("ffmpeg failed to convert reference audio")
99
 
100
  # Validate audio duration
101
  try:
 
110
 
111
  return temp_ref_path
112
 
113
+ except Exception:
114
  # Cleanup on error
115
+ if temp_input and os.path.exists(temp_input):
116
+ try:
117
+ os.remove(temp_input)
118
+ except:
119
+ pass
120
+ if os.path.exists(temp_ref_path):
121
+ try:
122
+ os.remove(temp_ref_path)
123
+ except:
124
+ pass
125
  raise
126
  finally:
127
  # Always cleanup input temp file
128
+ if temp_input and os.path.exists(temp_input):
129
  try:
130
  os.remove(temp_input)
131
  except:
 
188
  # Import and initialize model
189
  from neuttsair.neutts import NeuTTSAir
190
 
191
+ # Force CPU devices so Hugging Face free tier works
192
  tts_model = NeuTTSAir(
193
  backbone_repo="neuphonic/neutts-air",
194
  backbone_device="cpu",
 
196
  codec_device="cpu"
197
  )
198
 
199
+ # Best-effort: dynamic quantization to speed up CPU inference
200
+ if config.ENABLE_QUANTIZATION:
201
+ try:
202
+ # quantize_dynamic is safe for models with Linear/RNN modules; skip if not compatible
203
+ if isinstance(tts_model, torch.nn.Module):
204
+ tts_model = torch.quantization.quantize_dynamic(
205
+ tts_model, {torch.nn.Linear}, dtype=torch.qint8
206
+ )
207
+ logger.info("✅ Applied dynamic quantization to model (best-effort).")
208
+ else:
209
+ logger.info("Model is not an nn.Module; skipping dynamic quantization.")
210
+ except Exception as e:
211
+ logger.warning(f"Dynamic quantization failed (continuing without it): {e}")
212
+
213
  logger.info("✅ NeuTTS Air model loaded successfully!")
214
 
215
  except Exception as e:
 
236
  logger.info("🛑 Shutting down NeuTTS Air API")
237
  global tts_model
238
  if tts_model is not None:
239
+ try:
240
+ del tts_model
241
+ except:
242
+ pass
243
  tts_model = None
244
  gc.collect()
245
 
 
315
  system_memory_gb=0
316
  )
317
 
318
+ async def _encode_reference_async(temp_ref_path: str):
319
+ """Wrap encode_reference to run off the event loop"""
320
+ def _encode():
321
+ with torch.inference_mode(), torch.no_grad():
322
+ return tts_model.encode_reference(temp_ref_path)
323
+ return await asyncio.to_thread(_encode)
324
+
325
+ async def _infer_async(text: str, ref_codes, reference_text: str):
326
+ """Wrap infer to run off the event loop"""
327
+ def _infer():
328
+ with torch.inference_mode(), torch.no_grad():
329
+ return tts_model.infer(text, ref_codes, reference_text)
330
+ return await asyncio.to_thread(_infer)
331
+
332
  @app.post("/synthesize", response_model=TTSResponse)
333
  async def synthesize_speech(
334
  background_tasks: BackgroundTasks,
 
337
  reference_audio: UploadFile = File(...)
338
  ):
339
  """
340
+ Efficient synthesis with small-cache and minimal memory usage.
341
+ Uses a semaphore to limit concurrent CPU-bound inferences.
342
  """
343
  global active_requests
344
  start_time = time.time()
345
  request_id = str(uuid4())[:8]
346
  temp_ref_path = None
347
+
348
+ # Quick concurrency check
349
+ if _infer_semaphore.locked():
350
+ # If queue is full, respond quickly
351
+ raise HTTPException(status_code=503, detail="Server busy - try again shortly")
352
+
353
+ await _infer_semaphore.acquire()
354
  active_requests += 1
355
+
356
  try:
357
  if tts_model is None:
358
  raise HTTPException(status_code=503, detail="Model loading, please wait")
 
366
  # Process reference audio - creates temp file
367
  temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
368
 
369
+ # Encode reference (run in thread)
370
+ ref_codes = await _encode_reference_async(temp_ref_path)
371
+
372
+ # Perform TTS (run in thread)
373
+ wav_output = await _infer_async(text, ref_codes, reference_text)
374
+
375
+ if not isinstance(wav_output, np.ndarray):
376
+ # Defensive conversion if needed
377
+ wav_output = np.asarray(wav_output, dtype=np.float32)
378
 
379
  # Generate audio ID and add to small cache
380
  audio_id = f"audio_{request_id}"
 
401
  raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
402
  finally:
403
  active_requests -= 1
404
+ _infer_semaphore.release()
405
  # Schedule cleanup of temp reference file
406
  if temp_ref_path and os.path.exists(temp_ref_path):
407
  background_tasks.add_task(cleanup_temp_file, temp_ref_path)
 
413
  reference_audio: UploadFile = File(...)
414
  ):
415
  """
416
+ Direct synthesis and streaming - inference runs off the event loop
417
+ and the resulting audio is streamed as chunks once ready.
418
  """
419
  global active_requests
420
  start_time = time.time()
421
  temp_ref_path = None
422
+
423
+ # Quick concurrency check
424
+ if _infer_semaphore.locked():
425
+ raise HTTPException(status_code=503, detail="Server busy - try again shortly")
426
+
427
+ await _infer_semaphore.acquire()
428
  active_requests += 1
429
 
430
  try:
 
434
  # Process reference audio
435
  temp_ref_path = await MemoryOptimizedProcessor.process_reference_audio(reference_audio)
436
 
437
+ # Encode & infer in background (off event loop)
438
+ ref_codes = await _encode_reference_async(temp_ref_path)
439
+ wav_output = await _infer_async(text, ref_codes, reference_text)
440
+
441
+ if not isinstance(wav_output, np.ndarray):
442
+ wav_output = np.asarray(wav_output, dtype=np.float32)
443
 
444
  # Convert to WAV bytes in memory
445
  wav_buffer = io.BytesIO()
 
450
 
451
  logger.info(f"Stream synthesis completed: {processing_time:.2f}s")
452
 
453
+ # Stream directly without storing on disk
454
  async def generate_stream():
455
  chunk_size = config.CHUNK_SIZE
456
  for i in range(0, len(wav_bytes), chunk_size):
 
467
  }
468
  )
469
 
470
+ except HTTPException:
471
+ raise
472
  except Exception as e:
473
  logger.error(f"Stream synthesis error: {str(e)}")
474
  raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}")
475
  finally:
476
  active_requests -= 1
477
+ _infer_semaphore.release()
478
  if temp_ref_path and os.path.exists(temp_ref_path):
479
  asyncio.create_task(cleanup_temp_file(temp_ref_path))
480