Rajhuggingface4253 commited on
Commit
37a7804
·
verified ·
1 Parent(s): 2197f32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -40
app.py CHANGED
@@ -1,57 +1,344 @@
1
- import tempfile
2
- import soundfile as sf
3
- from fastapi import FastAPI, HTTPException
4
- from fastapi.responses import FileResponse
 
 
 
 
 
 
5
  from pydantic import BaseModel
6
- from neuttsair.neutts import NeuTTSAir
 
 
7
 
8
- # Initialize FastAPI app
9
- app = FastAPI(title="NeuTTS-Air API", description="A FastAPI service for the NeuTTS-Air model.")
10
 
11
- # Load the NeuTTS-Air model
12
- # The path is relative to the working directory in the Docker container
13
- MODEL_PATH = "neutts-air-q4-gguf"
14
- try:
15
- tts = NeuTTSAir(backbone_repo=MODEL_PATH, backbone_device="cpu")
16
- except Exception as e:
17
- print(f"Error loading model: {e}")
18
- tts = None
19
 
20
- # Pydantic model for the request body
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class TTSRequest(BaseModel):
22
  text: str
23
- ref_audio_path: str
24
- ref_text: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  @app.get("/")
27
- def read_root():
28
- """Simple health check endpoint."""
29
- return {"message": "NeuTTS-Air FastAPI is running."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- @app.post("/tts", summary="Generate speech from text")
32
- async def tts_endpoint(request: TTSRequest):
 
 
 
 
33
  """
34
- Generates a WAV audio file from text using a reference audio and transcript.
35
  """
36
- if tts is None:
37
- raise HTTPException(status_code=503, detail="Model is not loaded.")
38
-
 
 
 
 
 
 
 
 
 
 
39
  try:
40
- # Load the reference audio
41
- # Note: You must provide a valid path to an audio file
42
- # The user will need to upload their own reference audios or use pre-uploaded ones
43
- ref_codes = tts.encode_reference(request.ref_audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Perform inference
46
- wav_audio = tts.infer(request.text, ref_codes, request.ref_text)
47
 
48
- # Save the audio to a temporary file
49
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
50
- sf.write(tmp.name, wav_audio, tts.codec.sampling_rate)
51
- filepath = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Return the audio file
54
- return FileResponse(filepath, media_type="audio/wav", filename="generated_speech.wav")
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  except Exception as e:
57
- raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import gc
5
+ import torch
6
+ import numpy as np
7
+ import aiofiles
8
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
9
+ from fastapi.responses import JSONResponse, FileResponse
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
+ from typing import Optional, Dict, Any
13
+ import psutil
14
+ import logging
15
 
16
+ # Add NeuTTS Air to path
17
+ sys.path.append("neutts-air")
18
 
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
 
 
 
 
 
22
 
23
+ app = FastAPI(
24
+ title="NeuTTS Air API",
25
+ description="High-quality on-device Text-to-Speech with instant voice cloning",
26
+ version="1.0.0"
27
+ )
28
+
29
+ # CORS middleware
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=["*"],
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+ # Global model instance
39
+ tts_model = None
40
+ model_loading = False
41
+
42
+ # Pydantic models
43
  class TTSRequest(BaseModel):
44
  text: str
45
+ reference_text: str
46
+ reference_audio_path: Optional[str] = None
47
+
48
+ class TTSResponse(BaseModel):
49
+ success: bool
50
+ audio_url: Optional[str] = None
51
+ message: Optional[str] = None
52
+ processing_time: Optional[float] = None
53
+ audio_duration: Optional[float] = None
54
+
55
+ class HealthResponse(BaseModel):
56
+ status: str
57
+ model_loaded: bool
58
+ memory_usage: Dict[str, float]
59
+ disk_usage: Dict[str, float]
60
+
61
+ def load_tts_model():
62
+ global tts_model, model_loading
63
+
64
+ if tts_model is not None or model_loading:
65
+ return
66
+
67
+ model_loading = True
68
+ try:
69
+ logger.info("Loading NeuTTS Air model...")
70
+
71
+ # Try to import with fallbacks
72
+ try:
73
+ from neuttsair.neutts import NeuTTSAir
74
+ except ImportError as e:
75
+ logger.error(f"Failed to import NeuTTS Air: {e}")
76
+ # Try alternative import path
77
+ sys.path.insert(0, "/app/neutts-air")
78
+ from neuttsair.neutts import NeuTTSAir
79
+
80
+ # Use CPU for Hugging Face free tier with fallback models
81
+ try:
82
+ tts_model = NeuTTSAir(
83
+ backbone_repo="neuphonic/neutts-air-q4-gguf",
84
+ backbone_device="cpu",
85
+ codec_repo="neuphonic/neucodec",
86
+ codec_device="cpu"
87
+ )
88
+ except Exception as model_error:
89
+ logger.warning(f"Q4 model failed, trying default: {model_error}")
90
+ # Fallback to default model
91
+ tts_model = NeuTTSAir(
92
+ backbone_repo="neuphonic/neutts-air",
93
+ backbone_device="cpu",
94
+ codec_repo="neuphonic/neucodec",
95
+ codec_device="cpu"
96
+ )
97
+
98
+ logger.info("NeuTTS Air model loaded successfully!")
99
+
100
+ except Exception as e:
101
+ logger.error(f"Failed to load model: {str(e)}")
102
+ model_loading = False
103
+ raise e
104
+
105
+ model_loading = False
106
+
107
+ @app.on_event("startup")
108
+ async def startup_event():
109
+ """Load model on startup with error handling"""
110
+ try:
111
+ load_tts_model()
112
+ except Exception as e:
113
+ logger.error(f"Startup model loading failed: {e}")
114
 
115
  @app.get("/")
116
+ async def root():
117
+ return {"message": "NeuTTS Air API is running!", "status": "healthy"}
118
+
119
+ @app.get("/health")
120
+ async def health_check():
121
+ """Health check endpoint"""
122
+ try:
123
+ memory = psutil.virtual_memory()
124
+ disk = psutil.disk_usage('/')
125
+
126
+ return HealthResponse(
127
+ status="healthy",
128
+ model_loaded=tts_model is not None,
129
+ memory_usage={
130
+ "total_gb": round(memory.total / (1024**3), 2),
131
+ "available_gb": round(memory.available / (1024**3), 2),
132
+ "used_percent": round(memory.percent, 2)
133
+ },
134
+ disk_usage={
135
+ "total_gb": round(disk.total / (1024**3), 2),
136
+ "free_gb": round(disk.free / (1024**3), 2),
137
+ "used_percent": round(disk.percent, 2)
138
+ }
139
+ )
140
+ except Exception as e:
141
+ return HealthResponse(
142
+ status="degraded",
143
+ model_loaded=tts_model is not None,
144
+ memory_usage={"error": str(e)},
145
+ disk_usage={"error": str(e)}
146
+ )
147
 
148
+ @app.post("/synthesize")
149
+ async def synthesize_speech(
150
+ reference_text: str = Form(...),
151
+ text: str = Form(...),
152
+ reference_audio: UploadFile = File(...)
153
+ ):
154
  """
155
+ Synthesize speech using reference audio and text
156
  """
157
+ start_time = time.time()
158
+
159
+ if tts_model is None:
160
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
161
+
162
+ # Validate inputs
163
+ if not reference_text.strip() or not text.strip():
164
+ raise HTTPException(status_code=400, detail="Text fields cannot be empty")
165
+
166
+ if len(text) > 1000:
167
+ raise HTTPException(status_code=400, detail="Text too long. Maximum 1000 characters allowed.")
168
+
169
+ temp_ref_path = None
170
  try:
171
+ # Save uploaded file temporarily
172
+ temp_dir = "temp_audio"
173
+ os.makedirs(temp_dir, exist_ok=True)
174
+
175
+ file_extension = os.path.splitext(reference_audio.filename)[1] or ".wav"
176
+ temp_ref_path = os.path.join(temp_dir, f"ref_{int(time.time())}{file_extension}")
177
+
178
+ async with aiofiles.open(temp_ref_path, 'wb') as out_file:
179
+ content = await reference_audio.read()
180
+ await out_file.write(content)
181
+
182
+ # Validate audio file
183
+ try:
184
+ import librosa
185
+ audio_duration = librosa.get_duration(filename=temp_ref_path)
186
+ if audio_duration < 2 or audio_duration > 30:
187
+ raise HTTPException(
188
+ status_code=400,
189
+ detail=f"Audio duration ({audio_duration:.1f}s) should be between 3-15 seconds"
190
+ )
191
+ except Exception as e:
192
+ raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
193
+
194
+ # Perform TTS
195
+ logger.info(f"Starting synthesis for text: {text[:50]}...")
196
+
197
+ # Encode reference
198
+ ref_codes = tts_model.encode_reference(temp_ref_path)
199
 
200
+ # Generate speech
201
+ wav = tts_model.infer(text, ref_codes, reference_text)
202
 
203
+ # Save output
204
+ output_dir = "generated_audio"
205
+ os.makedirs(output_dir, exist_ok=True)
206
+ output_filename = f"output_{int(time.time())}.wav"
207
+ output_path = os.path.join(output_dir, output_filename)
208
+
209
+ import soundfile as sf
210
+ sf.write(output_path, wav, 24000)
211
+
212
+ processing_time = time.time() - start_time
213
+ audio_duration = len(wav) / 24000
214
+
215
+ logger.info(f"Synthesis completed in {processing_time:.2f}s")
216
+
217
+ return TTSResponse(
218
+ success=True,
219
+ audio_url=f"/audio/{output_filename}",
220
+ message="Speech synthesized successfully",
221
+ processing_time=round(processing_time, 2),
222
+ audio_duration=round(audio_duration, 2)
223
+ )
224
+
225
+ except Exception as e:
226
+ logger.error(f"Synthesis error: {str(e)}")
227
+ raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
228
+
229
+ finally:
230
+ # Clean up temporary file
231
+ if temp_ref_path and os.path.exists(temp_ref_path):
232
+ try:
233
+ os.remove(temp_ref_path)
234
+ except:
235
+ pass
236
 
237
+ @app.get("/audio/{filename}")
238
+ async def get_audio_file(filename: str):
239
+ """Serve generated audio files"""
240
+ file_path = os.path.join("generated_audio", filename)
241
+
242
+ if not os.path.exists(file_path):
243
+ raise HTTPException(status_code=404, detail="Audio file not found")
244
+
245
+ return FileResponse(
246
+ file_path,
247
+ media_type="audio/wav",
248
+ filename=f"generated_speech_{filename}"
249
+ )
250
 
251
+ @app.post("/synthesize-with-url")
252
+ async def synthesize_with_url(request: TTSRequest):
253
+ """
254
+ Synthesize speech using a pre-uploaded reference audio file path
255
+ """
256
+ start_time = time.time()
257
+
258
+ if tts_model is None:
259
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
260
+
261
+ if not request.reference_audio_path or not os.path.exists(request.reference_audio_path):
262
+ raise HTTPException(status_code=400, detail="Reference audio path not found")
263
+
264
+ try:
265
+ # Validate audio file
266
+ import librosa
267
+ audio_duration = librosa.get_duration(filename=request.reference_audio_path)
268
+ if audio_duration < 2 or audio_duration > 30:
269
+ raise HTTPException(
270
+ status_code=400,
271
+ detail=f"Audio duration ({audio_duration:.1f}s) should be between 3-15 seconds"
272
+ )
273
+
274
+ # Perform TTS
275
+ logger.info(f"Starting synthesis for text: {request.text[:50]}...")
276
+
277
+ # Encode reference
278
+ ref_codes = tts_model.encode_reference(request.reference_audio_path)
279
+
280
+ # Generate speech
281
+ wav = tts_model.infer(request.text, ref_codes, request.reference_text)
282
+
283
+ # Save output
284
+ output_dir = "generated_audio"
285
+ os.makedirs(output_dir, exist_ok=True)
286
+ output_filename = f"output_{int(time.time())}.wav"
287
+ output_path = os.path.join(output_dir, output_filename)
288
+
289
+ import soundfile as sf
290
+ sf.write(output_path, wav, 24000)
291
+
292
+ processing_time = time.time() - start_time
293
+ audio_duration = len(wav) / 24000
294
+
295
+ return TTSResponse(
296
+ success=True,
297
+ audio_url=f"/audio/{output_filename}",
298
+ message="Speech synthesized successfully",
299
+ processing_time=round(processing_time, 2),
300
+ audio_duration=round(audio_duration, 2)
301
+ )
302
+
303
+ except Exception as e:
304
+ logger.error(f"Synthesis error: {str(e)}")
305
+ raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
306
+
307
+ @app.delete("/cleanup")
308
+ async def cleanup_audio_files():
309
+ """Clean up generated audio files older than 1 hour"""
310
+ try:
311
+ output_dir = "generated_audio"
312
+ temp_dir = "temp_audio"
313
+
314
+ deleted_count = 0
315
+ current_time = time.time()
316
+
317
+ # Clean generated audio
318
+ if os.path.exists(output_dir):
319
+ for filename in os.listdir(output_dir):
320
+ file_path = os.path.join(output_dir, filename)
321
+ if os.path.isfile(file_path):
322
+ file_age = current_time - os.path.getctime(file_path)
323
+ if file_age > 3600: # 1 hour
324
+ os.remove(file_path)
325
+ deleted_count += 1
326
+
327
+ # Clean temp audio
328
+ if os.path.exists(temp_dir):
329
+ for filename in os.listdir(temp_dir):
330
+ file_path = os.path.join(temp_dir, filename)
331
+ if os.path.isfile(file_path):
332
+ file_age = current_time - os.path.getctime(file_path)
333
+ if file_age > 3600: # 1 hour
334
+ os.remove(file_path)
335
+ deleted_count += 1
336
+
337
+ return {"message": f"Cleaned up {deleted_count} files"}
338
+
339
  except Exception as e:
340
+ raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
341
+
342
+ if __name__ == "__main__":
343
+ import uvicorn
344
+ uvicorn.run(app, host="0.0.0.0", port=7860)