Rajhuggingface4253 commited on
Commit
ff7d020
·
verified ·
1 Parent(s): 8c1e9c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -172
app.py CHANGED
@@ -1,190 +1,57 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 
 
2
  from fastapi.responses import FileResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
- from typing import Optional
6
- import uvicorn
7
- import tempfile
8
- import os
9
- import time
10
- import logging
11
- from pathlib import Path
12
-
13
- # NeuTTS Air imports
14
  from neuttsair.neutts import NeuTTSAir
15
- import soundfile as sf
16
-
17
- # Configure logging
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
 
21
- app = FastAPI(
22
- title="NeuTTS Air API",
23
- description="Professional Text-to-Speech with Instant Voice Cloning",
24
- version="1.0.0",
25
- docs_url="/docs",
26
- redoc_url="/redoc"
27
- )
28
 
29
- # CORS middleware
30
- app.add_middleware(
31
- CORSMiddleware,
32
- allow_origins=["*"],
33
- allow_credentials=True,
34
- allow_methods=["*"],
35
- allow_headers=["*"],
36
- )
37
 
38
- # Pydantic models for request validation
39
  class TTSRequest(BaseModel):
40
  text: str
41
- ref_text: Optional[str] = ""
42
- language: Optional[str] = "en"
43
-
44
- class HealthResponse(BaseModel):
45
- status: str
46
- model_loaded: bool
47
- timestamp: str
48
-
49
- # Global model instance
50
- tts_model = None
51
-
52
- @app.on_event("startup")
53
- async def startup_event():
54
- """Initialize the TTS model on startup"""
55
- global tts_model
56
- try:
57
- logger.info("Loading NeuTTS Air model...")
58
-
59
- tts_model = NeuTTSAir(
60
- backbone_repo="neuphonic/neutts-air-q4-gguf",
61
- backbone_device="cpu",
62
- codec_repo="neuphonic/neucodec",
63
- codec_device="cpu"
64
- )
65
-
66
- logger.info("✅ NeuTTS Air model loaded successfully")
67
- except Exception as e:
68
- logger.error(f"❌ Failed to load NeuTTS Air model: {e}")
69
- raise
70
-
71
- @app.get("/", include_in_schema=False)
72
- async def root():
73
- return {"message": "NeuTTS Air API", "status": "running"}
74
 
75
- @app.get("/health", response_model=HealthResponse)
76
- async def health_check():
77
- """Health check endpoint"""
78
- return HealthResponse(
79
- status="healthy",
80
- model_loaded=tts_model is not None,
81
- timestamp=time.strftime("%Y-%m-%d %H:%M:%S")
82
- )
83
 
84
- @app.post("/synthesize")
85
- async def synthesize_speech(
86
- text: str = Form(..., description="Text to synthesize"),
87
- ref_audio: UploadFile = File(..., description="Reference audio file (3-15 seconds)"),
88
- ref_text: str = Form("", description="Transcript of reference audio")
89
- ):
90
  """
91
- Synthesize speech from text using a reference audio for voice cloning
92
  """
93
- if tts_model is None:
94
- raise HTTPException(status_code=503, detail="TTS model not loaded")
95
-
96
- # Validate audio file
97
- if not ref_audio.content_type.startswith('audio/'):
98
- raise HTTPException(status_code=400, detail="Invalid audio file format")
99
-
100
- try:
101
- # Save uploaded audio to temporary file
102
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_ref:
103
- content = await ref_audio.read()
104
- temp_ref.write(content)
105
- ref_audio_path = temp_ref.name
106
-
107
- # Generate speech
108
- logger.info(f"Synthesizing: '{text}'")
109
-
110
- ref_codes = tts_model.encode_reference(ref_audio_path)
111
- audio_data = tts_model.infer(text, ref_codes, ref_text)
112
-
113
- # Save output to temporary file
114
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
115
- sf.write(temp_output.name, audio_data, 24000)
116
- output_path = temp_output.name
117
-
118
- # Cleanup input file
119
- os.unlink(ref_audio_path)
120
-
121
- # Return audio file
122
- return FileResponse(
123
- output_path,
124
- media_type='audio/wav',
125
- filename=f"generated_speech_{int(time.time())}.wav",
126
- background=BackgroundTask(lambda: os.unlink(output_path))
127
- )
128
-
129
- except Exception as e:
130
- logger.error(f"Synthesis error: {e}")
131
- raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
132
 
133
- @app.post("/synthesize-from-sample")
134
- async def synthesize_from_sample(request: TTSRequest):
135
- """
136
- Synthesize speech using built-in sample voices
137
- """
138
- if tts_model is None:
139
- raise HTTPException(status_code=503, detail="TTS model not loaded")
140
-
141
  try:
142
- # Use built-in sample (Dave)
143
- sample_path = "samples/dave.wav"
144
- if not os.path.exists(sample_path):
145
- raise HTTPException(status_code=500, detail="Sample audio not found")
146
-
147
- ref_codes = tts_model.encode_reference(sample_path)
148
- audio_data = tts_model.infer(request.text, ref_codes, "My name is Dave and I'm from London.")
149
 
150
- # Save output to temporary file
151
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
152
- sf.write(temp_output.name, audio_data, 24000)
153
- output_path = temp_output.name
154
 
155
- return FileResponse(
156
- output_path,
157
- media_type='audio/wav',
158
- filename=f"sample_speech_{int(time.time())}.wav",
159
- background=BackgroundTask(lambda: os.unlink(output_path))
160
- )
161
-
162
- except Exception as e:
163
- logger.error(f"Sample synthesis error: {e}")
164
- raise HTTPException(status_code=500, detail=f"Sample synthesis failed: {str(e)}")
165
 
166
- @app.get("/voices/samples")
167
- async def get_sample_voices():
168
- """Get available sample voices"""
169
- samples_dir = Path("samples")
170
- samples = []
171
-
172
- if samples_dir.exists():
173
- for file in samples_dir.glob("*.wav"):
174
- samples.append({
175
- "name": file.stem,
176
- "path": str(file),
177
- "size": file.stat().st_size
178
- })
179
-
180
- return {"samples": samples}
181
 
182
- if __name__ == "__main__":
183
- uvicorn.run(
184
- "app:app",
185
- host="0.0.0.0",
186
- port=7860,
187
- reload=False, # Disable reload in production
188
- workers=1, # Single worker for CPU optimization
189
- access_log=True
190
- )
 
1
+ import tempfile
2
+ import soundfile as sf
3
+ from fastapi import FastAPI, HTTPException
4
  from fastapi.responses import FileResponse
 
5
  from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
6
  from neuttsair.neutts import NeuTTSAir
 
 
 
 
 
7
 
8
+ # Initialize FastAPI app
9
+ app = FastAPI(title="NeuTTS-Air API", description="A FastAPI service for the NeuTTS-Air model.")
 
 
 
 
 
10
 
11
+ # Load the NeuTTS-Air model
12
+ # The path is relative to the working directory in the Docker container
13
+ MODEL_PATH = "neutts-air-q4-gguf"
14
+ try:
15
+ tts = NeuTTSAir(backbone_repo=MODEL_PATH, backbone_device="cpu")
16
+ except Exception as e:
17
+ print(f"Error loading model: {e}")
18
+ tts = None
19
 
20
+ # Pydantic model for the request body
21
  class TTSRequest(BaseModel):
22
  text: str
23
+ ref_audio_path: str
24
+ ref_text: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ @app.get("/")
27
+ def read_root():
28
+ """Simple health check endpoint."""
29
+ return {"message": "NeuTTS-Air FastAPI is running."}
 
 
 
 
30
 
31
+ @app.post("/tts", summary="Generate speech from text")
32
+ async def tts_endpoint(request: TTSRequest):
 
 
 
 
33
  """
34
+ Generates a WAV audio file from text using a reference audio and transcript.
35
  """
36
+ if tts is None:
37
+ raise HTTPException(status_code=503, detail="Model is not loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
39
  try:
40
+ # Load the reference audio
41
+ # Note: You must provide a valid path to an audio file
42
+ # The user will need to upload their own reference audios or use pre-uploaded ones
43
+ ref_codes = tts.encode_reference(request.ref_audio_path)
 
 
 
44
 
45
+ # Perform inference
46
+ wav_audio = tts.infer(request.text, ref_codes, request.ref_text)
 
 
47
 
48
+ # Save the audio to a temporary file
49
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
50
+ sf.write(tmp.name, wav_audio, tts.codec.sampling_rate)
51
+ filepath = tmp.name
 
 
 
 
 
 
52
 
53
+ # Return the audio file
54
+ return FileResponse(filepath, media_type="audio/wav", filename="generated_speech.wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ except Exception as e:
57
+ raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")