Rajhuggingface4253 commited on
Commit
d512c0d
·
verified ·
1 Parent(s): b3fe36f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -103
app.py CHANGED
@@ -1,161 +1,314 @@
 
1
  import os
2
- import io
3
- import base64
4
- import json
5
- import asyncio
6
  import logging
7
- from concurrent.futures import ThreadPoolExecutor
8
  from contextlib import asynccontextmanager
 
9
 
10
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
11
- from fastapi.responses import Response, JSONResponse
12
- import soundfile as sf
 
 
13
 
14
- from neutts_wrapper import NeuTTSWrapper
 
 
 
 
 
 
15
 
16
- # --- Configuration & Global Objects ---
17
  logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Read device from environment variable, defaulting to 'cpu'
21
- DEVICE = os.getenv("MODEL_DEVICE", "cpu")
22
- # Use a ThreadPoolExecutor to run blocking ML code in a separate thread
23
- tts_executor = ThreadPoolExecutor(max_workers=1)
 
 
 
24
 
25
- # --- Lifespan Management (Model Loading) ---
26
  @asynccontextmanager
27
  async def lifespan(app: FastAPI):
28
- """
29
- Manages the model's lifecycle. It's loaded at startup and resources are
30
- cleaned up at shutdown.
31
- """
32
- logger.info("Application startup...")
33
  try:
34
- # Load the model wrapper into the application state
35
- app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
36
  except Exception as e:
37
- logger.error(f"FATAL: Model could not be loaded. {e}")
38
- app.state.tts_wrapper = None
39
-
40
- yield # The application is now running
41
-
42
- logger.info("Application shutdown...")
43
- tts_executor.shutdown(wait=True)
44
 
45
- # --- FastAPI App Initialization ---
46
  app = FastAPI(
47
  title="NeuTTS Air Production API",
48
  description="Production-ready Text-to-Speech with Voice Cloning",
49
  version="2.0.0",
 
50
  lifespan=lifespan
51
  )
52
 
53
- # --- Helper function for running blocking code ---
54
- async def run_in_executor(func, *args):
55
- """Runs a blocking function in the thread pool to avoid blocking the server."""
 
 
 
 
 
 
 
56
  loop = asyncio.get_event_loop()
57
- return await loop.run_in_executor(tts_executor, func, *args)
 
 
 
 
58
 
59
- # --- API Endpoints ---
60
  @app.get("/")
61
  async def root():
62
- return {"status": "online", "service": "NeuTTS Air API v2"}
 
 
 
 
 
 
63
 
64
  @app.get("/health")
65
  async def health_check():
66
- model_status = "loaded" if app.state.tts_wrapper else "degraded"
67
- return {"status": "healthy", "model_status": model_status, "device": DEVICE}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  @app.post("/api/v1/synthesize")
70
  async def synthesize_speech(
71
- ref_text: str = Form(...),
72
- gen_text: str = Form(...),
73
- ref_audio: UploadFile = File(...)
 
74
  ):
75
- if not app.state.tts_wrapper:
76
- raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
77
-
 
 
 
78
  try:
79
- ref_audio_bytes = await ref_audio.read()
 
 
80
 
81
- # Run blocking ML code in the thread pool
82
- ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
83
- wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
84
 
85
- # Process audio in-memory
86
- buffer = io.BytesIO()
87
- sf.write(buffer, wav_data, 24000, format='WAV')
88
- buffer.seek(0)
89
 
90
- return Response(content=buffer.read(), media_type="audio/wav")
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
- logger.error(f"Synthesis failed: {e}")
94
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
 
 
 
 
95
 
96
  @app.post("/api/v1/synthesize/b64")
97
  async def synthesize_speech_base64(
98
  ref_text: str = Form(...),
99
- gen_text: str = Form(...),
100
- ref_audio: UploadFile = File(...)
 
101
  ):
102
- if not app.state.tts_wrapper:
103
- raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
104
-
 
 
 
105
  try:
106
- ref_audio_bytes = await ref_audio.read()
 
 
 
 
 
107
 
108
- # Run blocking ML code in the thread pool
109
- ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
110
- wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
 
 
 
 
111
 
112
- # Process audio in-memory
 
 
 
113
  buffer = io.BytesIO()
114
- sf.write(buffer, wav_data, 24000, format='WAV')
115
  buffer.seek(0)
116
 
 
117
  audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
118
 
119
- return JSONResponse({"audio_data": audio_b64, "format": "wav"})
120
-
 
 
 
 
 
121
  except Exception as e:
122
- logger.error(f"Base64 synthesis failed: {e}")
123
- raise HTTPException(status_code=500, detail=f"Base64 synthesis failed: {str(e)}")
 
 
 
124
 
125
- @app.post("/api/v1/batch-synthesize")
126
- async def batch_synthesize(
127
  ref_text: str = Form(...),
 
128
  ref_audio: UploadFile = File(...),
129
- texts: str = Form(...)
130
  ):
131
- if not app.state.tts_wrapper:
132
- raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
 
 
 
133
 
134
  try:
135
- text_list = json.loads(texts)
136
- if not isinstance(text_list, list):
137
- raise ValueError("Texts must be a JSON array of strings.")
138
-
139
- ref_audio_bytes = await ref_audio.read()
140
-
141
- # Encode reference once, in the thread pool
142
- ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
143
-
144
- results = []
145
- for text in text_list:
146
- # Infer for each text
147
- wav_data = await run_in_executor(app.state.tts_wrapper.infer, text, ref_codes, ref_text)
148
-
149
- buffer = io.BytesIO()
150
- sf.write(buffer, wav_data, 24000, format='WAV')
151
- buffer.seek(0)
152
- audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
153
- results.append({"text": text, "audio_data": audio_b64})
154
-
155
- return JSONResponse({"generated_clips": results})
156
-
157
- except json.JSONDecodeError:
158
- raise HTTPException(status_code=400, detail="Invalid JSON in 'texts' field.")
 
 
 
 
 
 
 
 
159
  except Exception as e:
160
- logger.error(f"Batch synthesis failed: {e}")
161
- raise HTTPException(status_code=500, detail=f"Batch synthesis failed: {str(e)}")
 
 
 
 
 
 
 
 
1
+ # [file name]: app.py
2
  import os
3
+ import sys
 
 
 
4
  import logging
5
+ from typing import Optional
6
  from contextlib import asynccontextmanager
7
+ from concurrent.futures import ThreadPoolExecutor
8
 
9
+ # CRITICAL: Set environment variables BEFORE any imports
10
+ os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache'
11
+ os.environ['HF_HOME'] = '/app/cache'
12
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/cache'
13
+ os.environ['HF_HUB_DISABLE_LOCKING'] = '1'
14
 
15
+ # Add neutts-air to Python path
16
+ neutts_path = os.path.join(os.getcwd(), "neutts-air")
17
+ sys.path.insert(0, neutts_path)
18
+
19
+ # Create cache directories
20
+ os.makedirs('/app/cache', exist_ok=True)
21
+ os.makedirs('/tmp/numba_cache', exist_ok=True)
22
 
 
23
  logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger("neutts-production-api")
25
+
26
+ try:
27
+ import numpy as np
28
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
29
+ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
30
+ from fastapi.middleware.cors import CORSMiddleware
31
+ import soundfile as sf
32
+ import io
33
+ import asyncio
34
+ import uuid
35
+
36
+ from neutts_wrapper import NeuTTSWrapper, TTSRequest
37
+
38
+ logger.info("✅ All imports successful")
39
+
40
+ except ImportError as e:
41
+ logger.error(f"❌ Import failed: {e}")
42
+ raise
43
 
44
+ # Device detection and resource management
45
+ def get_best_device():
46
+ return "cuda" if torch.cuda.is_available() else "cpu"
47
+
48
+ DEVICE = get_best_device()
49
+ MAX_WORKERS = 1 if DEVICE == "cpu" else 2
50
+ tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
51
 
 
52
  @asynccontextmanager
53
  async def lifespan(app: FastAPI):
54
+ """Modern lifespan management with proper cleanup"""
 
 
 
 
55
  try:
56
+ app.state.neutts_wrapper = NeuTTSWrapper(device=DEVICE)
57
+ logger.info(f"✅ Model loaded on {DEVICE}")
58
  except Exception as e:
59
+ logger.error(f" Model loading failed: {e}")
60
+ raise
61
+ yield
62
+ # Cleanup
63
+ tts_executor.shutdown(wait=False)
64
+ if hasattr(app.state, 'neutts_wrapper'):
65
+ app.state.neutts_wrapper._cleanup_temp_files()
66
 
 
67
  app = FastAPI(
68
  title="NeuTTS Air Production API",
69
  description="Production-ready Text-to-Speech with Voice Cloning",
70
  version="2.0.0",
71
+ docs_url="/docs",
72
  lifespan=lifespan
73
  )
74
 
75
+ # CORS middleware
76
+ app.add_middleware(
77
+ CORSMiddleware,
78
+ allow_origins=["*"],
79
+ allow_methods=["*"],
80
+ allow_headers=["*"],
81
+ )
82
+
83
+ async def run_tts_async(tts_request: TTSRequest) -> np.ndarray:
84
+ """Offload blocking TTS call to thread pool"""
85
  loop = asyncio.get_event_loop()
86
+ return await loop.run_in_executor(
87
+ tts_executor,
88
+ app.state.neutts_wrapper.generate_speech,
89
+ tts_request
90
+ )
91
 
 
92
  @app.get("/")
93
  async def root():
94
+ return {
95
+ "status": "online",
96
+ "service": "NeuTTS Air Production API",
97
+ "version": "2.0.0",
98
+ "device": DEVICE,
99
+ "model_loaded": hasattr(app.state, 'neutts_wrapper')
100
+ }
101
 
102
  @app.get("/health")
103
  async def health_check():
104
+ """Comprehensive health check with memory monitoring"""
105
+ if not hasattr(app.state, 'neutts_wrapper'):
106
+ raise HTTPException(status_code=503, detail="Service unavailable")
107
+
108
+ try:
109
+ memory_info = app.state.neutts_wrapper.get_memory_usage()
110
+
111
+ return {
112
+ "status": "healthy",
113
+ "model_loaded": True,
114
+ "device": DEVICE,
115
+ "memory_usage": memory_info,
116
+ "endpoints": {
117
+ "synthesize": "/api/v1/synthesize",
118
+ "synthesize_b64": "/api/v1/synthesize/b64",
119
+ "synthesize_stream": "/api/v1/synthesize/stream",
120
+ "system_info": "/api/v1/system"
121
+ }
122
+ }
123
+ except Exception as e:
124
+ logger.error(f"Health check failed: {e}")
125
+ raise HTTPException(status_code=503, detail="Service degraded")
126
+
127
+ @app.get("/api/v1/system")
128
+ async def system_info():
129
+ """System information and resource monitoring"""
130
+ if not hasattr(app.state, 'neutts_wrapper'):
131
+ raise HTTPException(status_code=503, detail="Service unavailable")
132
+
133
+ memory_info = app.state.neutts_wrapper.get_memory_usage()
134
+
135
+ return {
136
+ "device": DEVICE,
137
+ "max_workers": MAX_WORKERS,
138
+ "memory_usage": memory_info,
139
+ "cache_info": {
140
+ "hf_cache": os.environ.get('HF_HOME'),
141
+ "numba_cache": os.environ.get('NUMBA_CACHE_DIR')
142
+ }
143
+ }
144
 
145
  @app.post("/api/v1/synthesize")
146
  async def synthesize_speech(
147
+ ref_text: str = Form(..., description="Reference audio transcript", max_length=1000),
148
+ gen_text: str = Form(..., description="Text to synthesize", max_length=5000),
149
+ ref_audio: UploadFile = File(..., description="Reference audio file (WAV, max 10MB)"),
150
+ use_gpu: bool = Form(True, description="Use GPU if available")
151
  ):
152
+ """Production-grade speech synthesis with voice cloning"""
153
+ if not hasattr(app.state, 'neutts_wrapper'):
154
+ raise HTTPException(status_code=503, detail="Service unavailable")
155
+
156
+ temp_file_path = None
157
+
158
  try:
159
+ # Validate file type
160
+ if not ref_audio.filename or not ref_audio.filename.lower().endswith('.wav'):
161
+ raise HTTPException(400, "Only WAV files are supported as reference audio")
162
 
163
+ # Read and validate file content
164
+ file_content = await ref_audio.read()
 
165
 
166
+ # Save uploaded file to temp location
167
+ temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
 
 
168
 
169
+ # Create TTS request
170
+ tts_request = TTSRequest(
171
+ ref_text=ref_text.strip(),
172
+ gen_text=gen_text.strip(),
173
+ ref_audio_path=temp_file_path,
174
+ use_gpu=use_gpu and torch.cuda.is_available()
175
+ )
176
+
177
+ # Generate speech
178
+ audio_data = await run_tts_async(tts_request)
179
+
180
+ # Create output file
181
+ output_filename = f"synthesized_{uuid.uuid4()}.wav"
182
+ output_path = os.path.join(app.state.neutts_wrapper.temp_dir, output_filename)
183
+ sf.write(output_path, audio_data, 24000)
184
+
185
+ # Return file response with cleanup
186
+ return FileResponse(
187
+ output_path,
188
+ media_type="audio/wav",
189
+ filename=output_filename,
190
+ background=BackgroundTask(app.state.neutts_wrapper.cleanup_file, output_path)
191
+ )
192
+
193
+ except ValueError as e:
194
+ raise HTTPException(status_code=400, detail=str(e))
195
+ except RuntimeError as e:
196
+ raise HTTPException(status_code=500, detail=str(e))
197
  except Exception as e:
198
+ logger.error(f"Synthesis error: {str(e)}")
199
+ raise HTTPException(status_code=500, detail="Internal server error")
200
+ finally:
201
+ # Cleanup uploaded temp file
202
+ if temp_file_path:
203
+ app.state.neutts_wrapper.cleanup_file(temp_file_path)
204
 
205
  @app.post("/api/v1/synthesize/b64")
206
  async def synthesize_speech_base64(
207
  ref_text: str = Form(...),
208
+ gen_text: str = Form(...),
209
+ ref_audio: UploadFile = File(...),
210
+ use_gpu: bool = Form(True)
211
  ):
212
+ """Synthesize speech and return as base64 encoded audio"""
213
+ if not hasattr(app.state, 'neutts_wrapper'):
214
+ raise HTTPException(status_code=503, detail="Service unavailable")
215
+
216
+ temp_file_path = None
217
+
218
  try:
219
+ # Validate and save uploaded file
220
+ if not ref_audio.filename.lower().endswith('.wav'):
221
+ raise HTTPException(400, "Only WAV files are supported")
222
+
223
+ file_content = await ref_audio.read()
224
+ temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
225
 
226
+ # Create TTS request
227
+ tts_request = TTSRequest(
228
+ ref_text=ref_text.strip(),
229
+ gen_text=gen_text.strip(),
230
+ ref_audio_path=temp_file_path,
231
+ use_gpu=use_gpu and torch.cuda.is_available()
232
+ )
233
 
234
+ # Generate speech
235
+ audio_data = await run_tts_async(tts_request)
236
+
237
+ # Convert to base64
238
  buffer = io.BytesIO()
239
+ sf.write(buffer, audio_data, 24000, format='WAV')
240
  buffer.seek(0)
241
 
242
+ import base64
243
  audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
244
 
245
+ return JSONResponse({
246
+ "audio_data": audio_b64,
247
+ "sample_rate": 24000,
248
+ "format": "wav",
249
+ "message": "Synthesis completed successfully"
250
+ })
251
+
252
  except Exception as e:
253
+ logger.error(f"Base64 synthesis error: {str(e)}")
254
+ raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
255
+ finally:
256
+ if temp_file_path:
257
+ app.state.neutts_wrapper.cleanup_file(temp_file_path)
258
 
259
+ @app.post("/api/v1/synthesize/stream")
260
+ async def synthesize_speech_stream(
261
  ref_text: str = Form(...),
262
+ gen_text: str = Form(...),
263
  ref_audio: UploadFile = File(...),
264
+ use_gpu: bool = Form(True)
265
  ):
266
+ """Stream synthesized speech for immediate playback"""
267
+ if not hasattr(app.state, 'neutts_wrapper'):
268
+ raise HTTPException(status_code=503, detail="Service unavailable")
269
+
270
+ temp_file_path = None
271
 
272
  try:
273
+ # Validate and save uploaded file
274
+ file_content = await ref_audio.read()
275
+ temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
276
+
277
+ # Create TTS request
278
+ tts_request = TTSRequest(
279
+ ref_text=ref_text.strip(),
280
+ gen_text=gen_text.strip(),
281
+ ref_audio_path=temp_file_path,
282
+ use_gpu=use_gpu and torch.cuda.is_available()
283
+ )
284
+
285
+ # Generate speech
286
+ audio_data = await run_tts_async(tts_request)
287
+
288
+ # Create streaming response
289
+ buffer = io.BytesIO()
290
+ sf.write(buffer, audio_data, 24000, format='MP3')
291
+ buffer.seek(0)
292
+
293
+ def generate():
294
+ yield buffer.read()
295
+
296
+ return StreamingResponse(
297
+ generate(),
298
+ media_type="audio/mpeg",
299
+ headers={
300
+ "Content-Disposition": "attachment; filename=streamed_speech.mp3",
301
+ "Cache-Control": "no-cache"
302
+ }
303
+ )
304
+
305
  except Exception as e:
306
+ logger.error(f"Streaming synthesis error: {str(e)}")
307
+ raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
308
+ finally:
309
+ if temp_file_path:
310
+ app.state.neutts_wrapper.cleanup_file(temp_file_path)
311
+
312
+ if __name__ == "__main__":
313
+ import uvicorn
314
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)