Rajhuggingface4253 committed on
Commit
8a6294a
·
verified ·
1 Parent(s): 7123400

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -291
app.py CHANGED
@@ -1,314 +1,70 @@
1
- # [file name]: app.py
2
  import os
3
  import sys
 
 
 
 
 
 
 
 
 
4
  import logging
5
- from typing import Optional
6
- from contextlib import asynccontextmanager
7
- from concurrent.futures import ThreadPoolExecutor
8
-
9
- # CRITICAL: Set environment variables BEFORE any imports
10
- os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache'
11
- os.environ['HF_HOME'] = '/app/cache'
12
- os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/cache'
13
- os.environ['HF_HUB_DISABLE_LOCKING'] = '1'
14
-
15
- # Add neutts-air to Python path
16
- neutts_path = os.path.join(os.getcwd(), "neutts-air")
17
- sys.path.insert(0, neutts_path)
18
-
19
- # Create cache directories
20
- os.makedirs('/app/cache', exist_ok=True)
21
- os.makedirs('/tmp/numba_cache', exist_ok=True)
22
 
23
  logging.basicConfig(level=logging.INFO)
24
- logger = logging.getLogger("neutts-production-api")
25
-
26
- try:
27
- import numpy as np
28
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
29
- from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
30
- from fastapi.middleware.cors import CORSMiddleware
31
- import soundfile as sf
32
- import io
33
- import asyncio
34
- import uuid
35
-
36
- from neutts_wrapper import NeuTTSWrapper, TTSRequest
37
-
38
- logger.info("✅ All imports successful")
39
-
40
- except ImportError as e:
41
- logger.error(f"❌ Import failed: {e}")
42
- raise
43
-
44
- # Device detection and resource management
45
- def get_best_device():
46
- return "cuda" if torch.cuda.is_available() else "cpu"
47
-
48
- DEVICE = get_best_device()
49
- MAX_WORKERS = 1 if DEVICE == "cpu" else 2
50
- tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
51
-
52
- @asynccontextmanager
53
- async def lifespan(app: FastAPI):
54
- """Modern lifespan management with proper cleanup"""
55
- try:
56
- app.state.neutts_wrapper = NeuTTSWrapper(device=DEVICE)
57
- logger.info(f"✅ Model loaded on {DEVICE}")
58
- except Exception as e:
59
- logger.error(f"❌ Model loading failed: {e}")
60
- raise
61
- yield
62
- # Cleanup
63
- tts_executor.shutdown(wait=False)
64
- if hasattr(app.state, 'neutts_wrapper'):
65
- app.state.neutts_wrapper._cleanup_temp_files()
66
-
67
- app = FastAPI(
68
- title="NeuTTS Air Production API",
69
- description="Production-ready Text-to-Speech with Voice Cloning",
70
- version="2.0.0",
71
- docs_url="/docs",
72
- lifespan=lifespan
73
- )
74
-
75
- # CORS middleware
76
- app.add_middleware(
77
- CORSMiddleware,
78
- allow_origins=["*"],
79
- allow_methods=["*"],
80
- allow_headers=["*"],
81
  )
82
 
83
- async def run_tts_async(tts_request: TTSRequest) -> np.ndarray:
84
- """Offload blocking TTS call to thread pool"""
85
- loop = asyncio.get_event_loop()
86
- return await loop.run_in_executor(
87
- tts_executor,
88
- app.state.neutts_wrapper.generate_speech,
89
- tts_request
90
- )
91
 
92
- @app.get("/")
93
- async def root():
94
- return {
95
- "status": "online",
96
- "service": "NeuTTS Air Production API",
97
- "version": "2.0.0",
98
- "device": DEVICE,
99
- "model_loaded": hasattr(app.state, 'neutts_wrapper')
100
- }
101
-
102
- @app.get("/health")
103
- async def health_check():
104
- """Comprehensive health check with memory monitoring"""
105
- if not hasattr(app.state, 'neutts_wrapper'):
106
- raise HTTPException(status_code=503, detail="Service unavailable")
107
-
108
  try:
109
- memory_info = app.state.neutts_wrapper.get_memory_usage()
110
-
111
- return {
112
- "status": "healthy",
113
- "model_loaded": True,
114
- "device": DEVICE,
115
- "memory_usage": memory_info,
116
- "endpoints": {
117
- "synthesize": "/api/v1/synthesize",
118
- "synthesize_b64": "/api/v1/synthesize/b64",
119
- "synthesize_stream": "/api/v1/synthesize/stream",
120
- "system_info": "/api/v1/system"
121
- }
122
- }
123
- except Exception as e:
124
- logger.error(f"Health check failed: {e}")
125
- raise HTTPException(status_code=503, detail="Service degraded")
126
-
127
- @app.get("/api/v1/system")
128
- async def system_info():
129
- """System information and resource monitoring"""
130
- if not hasattr(app.state, 'neutts_wrapper'):
131
- raise HTTPException(status_code=503, detail="Service unavailable")
132
-
133
- memory_info = app.state.neutts_wrapper.get_memory_usage()
134
-
135
- return {
136
- "device": DEVICE,
137
- "max_workers": MAX_WORKERS,
138
- "memory_usage": memory_info,
139
- "cache_info": {
140
- "hf_cache": os.environ.get('HF_HOME'),
141
- "numba_cache": os.environ.get('NUMBA_CACHE_DIR')
142
- }
143
- }
144
 
145
- @app.post("/api/v1/synthesize")
146
  async def synthesize_speech(
147
- ref_text: str = Form(..., description="Reference audio transcript", max_length=1000),
148
- gen_text: str = Form(..., description="Text to synthesize", max_length=5000),
149
- ref_audio: UploadFile = File(..., description="Reference audio file (WAV, max 10MB)"),
150
- use_gpu: bool = Form(True, description="Use GPU if available")
151
- ):
152
- """Production-grade speech synthesis with voice cloning"""
153
- if not hasattr(app.state, 'neutts_wrapper'):
154
- raise HTTPException(status_code=503, detail="Service unavailable")
155
-
156
- temp_file_path = None
157
-
158
- try:
159
- # Validate file type
160
- if not ref_audio.filename or not ref_audio.filename.lower().endswith('.wav'):
161
- raise HTTPException(400, "Only WAV files are supported as reference audio")
162
-
163
- # Read and validate file content
164
- file_content = await ref_audio.read()
165
-
166
- # Save uploaded file to temp location
167
- temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
168
-
169
- # Create TTS request
170
- tts_request = TTSRequest(
171
- ref_text=ref_text.strip(),
172
- gen_text=gen_text.strip(),
173
- ref_audio_path=temp_file_path,
174
- use_gpu=use_gpu and torch.cuda.is_available()
175
- )
176
-
177
- # Generate speech
178
- audio_data = await run_tts_async(tts_request)
179
-
180
- # Create output file
181
- output_filename = f"synthesized_{uuid.uuid4()}.wav"
182
- output_path = os.path.join(app.state.neutts_wrapper.temp_dir, output_filename)
183
- sf.write(output_path, audio_data, 24000)
184
-
185
- # Return file response with cleanup
186
- return FileResponse(
187
- output_path,
188
- media_type="audio/wav",
189
- filename=output_filename,
190
- background=BackgroundTask(app.state.neutts_wrapper.cleanup_file, output_path)
191
- )
192
-
193
- except ValueError as e:
194
- raise HTTPException(status_code=400, detail=str(e))
195
- except RuntimeError as e:
196
- raise HTTPException(status_code=500, detail=str(e))
197
- except Exception as e:
198
- logger.error(f"Synthesis error: {str(e)}")
199
- raise HTTPException(status_code=500, detail="Internal server error")
200
- finally:
201
- # Cleanup uploaded temp file
202
- if temp_file_path:
203
- app.state.neutts_wrapper.cleanup_file(temp_file_path)
204
-
205
- @app.post("/api/v1/synthesize/b64")
206
- async def synthesize_speech_base64(
207
  ref_text: str = Form(...),
208
  gen_text: str = Form(...),
209
  ref_audio: UploadFile = File(...),
210
- use_gpu: bool = Form(True)
211
  ):
212
- """Synthesize speech and return as base64 encoded audio"""
213
- if not hasattr(app.state, 'neutts_wrapper'):
214
- raise HTTPException(status_code=503, detail="Service unavailable")
215
-
216
- temp_file_path = None
217
 
218
  try:
219
- # Validate and save uploaded file
220
- if not ref_audio.filename.lower().endswith('.wav'):
221
- raise HTTPException(400, "Only WAV files are supported")
222
-
223
- file_content = await ref_audio.read()
224
- temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
225
-
226
- # Create TTS request
227
- tts_request = TTSRequest(
228
- ref_text=ref_text.strip(),
229
- gen_text=gen_text.strip(),
230
- ref_audio_path=temp_file_path,
231
- use_gpu=use_gpu and torch.cuda.is_available()
232
- )
233
-
234
- # Generate speech
235
- audio_data = await run_tts_async(tts_request)
236
 
237
- # Convert to base64
238
- buffer = io.BytesIO()
239
- sf.write(buffer, audio_data, 24000, format='WAV')
240
- buffer.seek(0)
241
 
242
- import base64
243
- audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
 
244
 
245
- return JSONResponse({
246
- "audio_data": audio_b64,
247
- "sample_rate": 24000,
248
- "format": "wav",
249
- "message": "Synthesis completed successfully"
250
- })
251
 
252
  except Exception as e:
253
- logger.error(f"Base64 synthesis error: {str(e)}")
254
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
255
- finally:
256
- if temp_file_path:
257
- app.state.neutts_wrapper.cleanup_file(temp_file_path)
258
 
259
- @app.post("/api/v1/synthesize/stream")
260
- async def synthesize_speech_stream(
261
- ref_text: str = Form(...),
262
- gen_text: str = Form(...),
263
- ref_audio: UploadFile = File(...),
264
- use_gpu: bool = Form(True)
265
- ):
266
- """Stream synthesized speech for immediate playback"""
267
- if not hasattr(app.state, 'neutts_wrapper'):
268
- raise HTTPException(status_code=503, detail="Service unavailable")
269
-
270
- temp_file_path = None
271
-
272
- try:
273
- # Validate and save uploaded file
274
- file_content = await ref_audio.read()
275
- temp_file_path = app.state.neutts_wrapper.save_uploaded_file(file_content)
276
-
277
- # Create TTS request
278
- tts_request = TTSRequest(
279
- ref_text=ref_text.strip(),
280
- gen_text=gen_text.strip(),
281
- ref_audio_path=temp_file_path,
282
- use_gpu=use_gpu and torch.cuda.is_available()
283
- )
284
-
285
- # Generate speech
286
- audio_data = await run_tts_async(tts_request)
287
-
288
- # Create streaming response
289
- buffer = io.BytesIO()
290
- sf.write(buffer, audio_data, 24000, format='MP3')
291
- buffer.seek(0)
292
-
293
- def generate():
294
- yield buffer.read()
295
-
296
- return StreamingResponse(
297
- generate(),
298
- media_type="audio/mpeg",
299
- headers={
300
- "Content-Disposition": "attachment; filename=streamed_speech.mp3",
301
- "Cache-Control": "no-cache"
302
- }
303
- )
304
-
305
- except Exception as e:
306
- logger.error(f"Streaming synthesis error: {str(e)}")
307
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
308
- finally:
309
- if temp_file_path:
310
- app.state.neutts_wrapper.cleanup_file(temp_file_path)
311
-
312
- if __name__ == "__main__":
313
- import uvicorn
314
- uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
 
 
1
  import os
2
  import sys
3
+ sys.path.insert(0, os.path.join(os.getcwd(), "neutts-air"))
4
+
5
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
6
+ from fastapi.responses import FileResponse, JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import io
11
+ import uuid
12
  import logging
13
+ from neuttsair.neutts import NeuTTSAir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  logging.basicConfig(level=logging.INFO)
16
logger = logging.getLogger(__name__)

# Build the NeuTTS Air model once, at import time; both the backbone and the
# codec are pinned to the CPU.
tts = NeuTTSAir(
    backbone_repo="neuphonic/neutts-air",
    backbone_device="cpu",
    codec_repo="neuphonic/neucodec",
    codec_device="cpu",
)

# HTTP application. CORS accepts every origin, method and header.
app = FastAPI(title="NeuTTS Air API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
 
 
 
 
 
 
28
 
29
def cleanup_file(file_path: str):
    """Delete ``file_path`` on a best-effort basis.

    A missing file is not an error. Any other OS-level failure (permissions,
    the path being a directory, ...) is logged as a warning and swallowed so
    that cleanup can never break the request that scheduled it.

    Args:
        file_path: Path of the file to remove. Falsy values are ignored.

    Returns:
        None.
    """
    if not file_path:
        return
    try:
        # EAFP: attempt the removal directly instead of exists()-then-remove,
        # which races with other deleters.
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone: nothing to do.
        return
    except OSError as exc:
        # Only OS errors are swallowed; KeyboardInterrupt/SystemExit and
        # programming errors are no longer hidden by a bare ``except``.
        logging.getLogger(__name__).warning(
            "Could not remove %s: %s", file_path, exc
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
@app.post("/synthesize")
async def synthesize_speech(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...),
    background_tasks: BackgroundTasks = None,
):
    """Synthesize ``gen_text`` using the voice from an uploaded reference clip.

    The upload is written to a temporary file under ``/tmp``, passed to
    ``tts.encode_reference``, and the resulting codes are given to
    ``tts.infer`` together with ``gen_text`` and ``ref_text``. The audio is
    written as a 24000 Hz WAV file and returned as a ``FileResponse``.

    Args:
        ref_text: Form field passed to ``tts.infer`` as the reference text.
        gen_text: Form field with the text to synthesize.
        ref_audio: Uploaded reference audio file.
        background_tasks: Task queue used to delete the output file once the
            response has been sent. When it is ``None`` the output file is
            left on disk, because it must still exist while it is served.

    Returns:
        A ``FileResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 when ``gen_text`` is blank or the upload is empty;
            500 when saving the upload, encoding or synthesis fails.
    """
    if not gen_text.strip():
        raise HTTPException(400, "gen_text must not be empty")

    reference_bytes = await ref_audio.read()
    if not reference_bytes:
        raise HTTPException(400, "ref_audio must not be empty")

    temp_path = f"/tmp/{uuid.uuid4()}.wav"
    output_path = f"/tmp/{uuid.uuid4()}.wav"

    try:
        # Save uploaded file
        with open(temp_path, "wb") as f:
            f.write(reference_bytes)

        # NOTE(review): these calls run synchronously inside the coroutine, so
        # the event loop is blocked for the duration of inference -- confirm
        # this is acceptable before moving them to a worker thread.
        ref_codes = tts.encode_reference(temp_path)
        wav = tts.infer(gen_text, ref_codes, ref_text)

        sf.write(output_path, wav, 24000)
    except Exception as e:
        # A partially written output file must not be left behind on failure.
        cleanup_file(output_path)
        logger.exception("Synthesis failed")
        raise HTTPException(500, f"Synthesis failed: {str(e)}") from e
    finally:
        # The upload is no longer referenced after this point, so remove it
        # on every path instead of relying on the optional background queue.
        cleanup_file(temp_path)

    if background_tasks is not None:
        # Delete the generated file only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)

    return FileResponse(output_path, media_type="audio/wav")
 
 
 
67
 
68
@app.get("/health")
async def health_check():
    """Report service health as a fixed JSON payload."""
    payload = dict(status="healthy", model_loaded=True)
    return payload