Hameed13 committed on
Commit
ced0e79
·
verified ·
1 Parent(s): 6ee3346

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +340 -262
main.py CHANGED
@@ -6,331 +6,409 @@ import logging
6
  import traceback
7
  import requests
8
  import subprocess
9
- from datetime import datetime
10
  from pathlib import Path
11
-
12
- from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.responses import FileResponse, JSONResponse
 
14
  from pydantic import BaseModel
15
  import uvicorn
16
- import torch
17
- import torchaudio
18
- from transformers import AutoModelForCausalLM, AutoTokenizer
19
 
20
  # Configure logging
21
  logging.basicConfig(
22
  level=logging.INFO,
23
- format='[%(asctime)s] %(levelname)s - %(message)s',
24
- datefmt='%Y-%m-%d %H:%M:%S'
25
  )
 
26
 
27
- # Print startup time
28
- logging.info(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
29
 
30
- # Initialize the app
31
- app = FastAPI()
32
 
33
- # Set environment variable to disable PortAudio requirement
34
- os.environ["OUTETTS_NO_PORTAUDIO"] = "1"
35
 
36
- # Define the paths for required model files
37
- tokenizer_path = "saheedniyi/YarnGPT2"
38
- wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
39
- wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"
 
 
 
 
40
 
41
# Patch torch.load so that callers which do not pass ``weights_only`` get
# ``weights_only=False``. This is necessary for compatibility with
# PyTorch 2.6+.
#
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
# objects, which can execute code embedded in a checkpoint. Only load
# checkpoint files that come from a trusted source.
original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    An explicit ``weights_only`` keyword supplied by the caller is respected.
    """
    kwargs.setdefault('weights_only', False)
    return original_torch_load(*args, **kwargs)


# Replace the original function with our patched version
torch.load = patched_torch_load
 
 
 
52
 
53
- # Function to download files with proper error handling
54
def download_file(url, destination):
    """Download *url* to *destination*, trying requests, then wget, then curl.

    Args:
        url: Address of the file to fetch.
        destination: Local path the file is written to.

    Returns:
        True as soon as one method succeeds, False when every method failed.
        On failure any partially written *destination* is removed, so that a
        later "does the file exist?" check does not mistake it for a
        completed download.
    """
    logging.info(f"Downloading file from {url} to {destination}")
    try:
        # Try to download using requests. The timeout keeps a stalled
        # connection from hanging forever.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(destination, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        logging.info(f"Successfully downloaded file to {destination}")
        return True
    except Exception as e:
        # Best-effort boundary: whatever went wrong, fall through to the
        # command-line tools below.
        logging.error(f"Failed to download file using requests: {str(e)}")

    # (tool name, command line, log line announcing the attempt)
    fallbacks = (
        (
            "wget",
            ['wget', url, '-O', destination],
            "Trying alternate download method with wget...",
        ),
        (
            "curl",
            # --fail makes curl exit non-zero on HTTP errors instead of
            # saving the server's error page as if it were the file.
            ['curl', '-L', '--fail', url, '--output', destination],
            "Trying final download method with curl...",
        ),
    )
    for tool, command, announcement in fallbacks:
        try:
            logging.info(announcement)
            subprocess.run(command, check=True, capture_output=True, text=True)
            logging.info(f"{tool} download successful")
            return True
        except subprocess.CalledProcessError as e:
            logging.error(f"{tool} download failed: {e.stderr}")
        except OSError as e:
            # Raised (as FileNotFoundError) when the tool is not installed;
            # keep going so the next fallback still gets its chance.
            logging.error(f"{tool} download failed: {str(e)}")

    # Every method failed: drop whatever partial output was left behind.
    try:
        if os.path.exists(destination):
            os.remove(destination)
    except OSError as e:
        logging.error(f"Could not remove incomplete download {destination}: {str(e)}")
    return False
94
-
95
- # Download required model files if they don't exist
96
def download_required_files():
    """Fetch the WavTokenizer config and checkpoint if they are not on disk.

    Both files are downloaded with ``download_file``. The checkpoint has an
    extra Google Drive fallback through ``gdown``.

    Raises:
        RuntimeError: If the config file cannot be downloaded, or if the
            checkpoint cannot be obtained from any source.
    """
    # URLs for model files
    wav_tokenizer_config_url = "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
    wav_tokenizer_model_url = "https://huggingface.co/novateur/WavTokenizer-small-speech-320token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"

    def is_missing(path):
        # A zero-byte file is what an aborted download leaves behind, so it
        # counts as missing.
        return not os.path.exists(path) or os.path.getsize(path) == 0

    # Download config file if it doesn't exist
    if is_missing(wav_tokenizer_config_path):
        if not download_file(wav_tokenizer_config_url, wav_tokenizer_config_path):
            raise RuntimeError(f"Failed to download config file from {wav_tokenizer_config_url}")

    # Download model file if it doesn't exist
    if not is_missing(wav_tokenizer_model_path):
        return
    if download_file(wav_tokenizer_model_url, wav_tokenizer_model_path):
        return

    # Try alternate source for the model file (from Google Drive)
    try:
        try:
            import gdown
        except ImportError:
            # Install on demand only; skip pip when gdown is already there.
            logging.info("Installing gdown package...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"])
            import gdown
            logging.info("gdown installed successfully")

        gdrive_url = "https://drive.google.com/uc?id=1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt"
        logging.info(f"Trying alternate source for model file: {gdrive_url}")

        gdown.download(gdrive_url, wav_tokenizer_model_path, quiet=False)
        if is_missing(wav_tokenizer_model_path):
            raise RuntimeError("File not found after gdown download")
        logging.info("Successfully downloaded model file using gdown")
    except Exception as e:
        logging.error(f"Failed to download model file using gdown: {str(e)}")
        # Chain the cause so the original failure stays visible in tracebacks.
        raise RuntimeError("Failed to download model file from any source") from e
 
 
 
 
 
 
 
 
 
 
129
 
130
- # Function to verify if required files exist
131
def verify_required_files():
    """Check that the tokenizer config and checkpoint files are on disk.

    Raises:
        FileNotFoundError: For the first of the two files that is absent
            (the config file is checked before the model file).
    """
    required = (
        ("Config", wav_tokenizer_config_path),
        ("Model", wav_tokenizer_model_path),
    )
    for label, path in required:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found at {path}")

    logging.info("All required files verified")
139
-
140
- # Define TextToSpeech class based on the working Colab code
141
class TextToSpeech:
    """YarnGPT-based text-to-speech engine.

    Construction downloads and verifies the model files, then loads the audio
    tokenizer and the causal language model. Any failure is raised as an
    ordinary exception so that the caller can decide whether the application
    should keep running without a TTS engine.
    """

    # Accent name -> YarnGPT speaker name; "default" is the fallback entry.
    _SPEAKER_BY_ACCENT = {
        "nigerian": "tayo",
        "yoruba": "idera",
        "igbo": "chidi",
        "hausa": "aminu",
        "default": "tayo",
    }

    def __init__(self):
        logging.info("Initializing TextToSpeech class...")

        try:
            # Import the AudioTokenizerV2 class from yarngpt
            from yarngpt.audiotokenizer import AudioTokenizerV2
            logging.info("Successfully imported AudioTokenizerV2 class")
        except ImportError as e:
            logging.error(f"Failed to import AudioTokenizerV2 class: {str(e)}")
            # Re-raise instead of sys.exit(1): SystemExit is not an
            # ``Exception`` subclass, so it would escape a caller's
            # ``except Exception`` and terminate the whole process.
            raise

        # Download required files
        download_required_files()

        # Verify files exist
        verify_required_files()

        # Detect device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {self.device}")

        # Initialize audio tokenizer
        try:
            self.audio_tokenizer = AudioTokenizerV2(
                tokenizer_path,
                wav_tokenizer_model_path,
                wav_tokenizer_config_path,
            )
            logging.info("Audio tokenizer initialized successfully")
        except Exception as e:
            logging.error(f"Failed to initialize audio tokenizer: {str(e)}")
            raise

        # Load model onto the same device the audio tokenizer uses
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                tokenizer_path,
                torch_dtype="auto",
            ).to(self.audio_tokenizer.device)
            logging.info("Model loaded successfully")
        except Exception as e:
            logging.error(f"Failed to load model: {str(e)}")
            raise

    def tts(self, text, output_file, accent="nigerian", speed=1.0):
        """
        Generate Nigerian-accented speech from text

        Args:
            text: Text to convert to speech
            output_file: Path to save the audio file
            accent: Accent to use (maps to a specific speaker); unknown or
                empty accents fall back to the default speaker
            speed: Speed multiplier (not currently implemented)

        Returns:
            Path to generated audio file
        """
        logging.info(f"Generating speech for text: '{text[:50]}...'")

        # Map accent to speaker name; ``accent or "default"`` guards against
        # None / empty values, which would otherwise fail on ``.lower()``.
        speaker = self._SPEAKER_BY_ACCENT.get(
            (accent or "default").lower(),
            self._SPEAKER_BY_ACCENT["default"],
        )
        logging.info(f"Using speaker: {speaker}")

        try:
            # Create prompt
            prompt = self.audio_tokenizer.create_prompt(text, lang="english", speaker_name=speaker)
            input_ids = self.audio_tokenizer.tokenize_prompt(prompt)

            # Generate output
            output = self.model.generate(
                input_ids=input_ids,
                temperature=0.1,
                repetition_penalty=1.1,
                max_length=4000,
            )

            # Convert to audio
            codes = self.audio_tokenizer.get_codes(output)
            audio = self.audio_tokenizer.get_audio(codes)

            # Save audio file
            torchaudio.save(output_file, audio, sample_rate=24000)
            logging.info(f"Audio saved to {output_file}")

            return output_file
        except Exception as e:
            logging.error(f"Error in TTS generation: {str(e)}")
            # Send the traceback through logging so it lands with the
            # other log records instead of being printed separately.
            logging.error(traceback.format_exc())
            raise
239
-
240
- # Try to initialize TTS engine, but allow app to start even if it fails
241
- tts_engine = None
242
- try:
243
- logging.info("Starting TTS engine initialization...")
244
- tts_engine = TextToSpeech()
245
- logging.info("TTS engine initialized successfully")
246
- except Exception as e:
247
- logging.error(f"Failed to initialize TTS engine: {str(e)}")
248
- print(traceback.format_exc())
249
-
250
- # Create output directory if it doesn't exist
251
- output_dir = Path("./output")
252
- output_dir.mkdir(exist_ok=True)
253
-
254
- # Model for the TTS request
255
class TTSRequest(BaseModel):
    """Request body accepted by the ``/tts`` endpoint."""

    # Text to convert to speech.
    text: str
    # Accent name, forwarded to the TTS engine by the endpoint.
    accent: str = "nigerian"  # Default accent
    # Speed multiplier, forwarded to the TTS engine by the endpoint.
    speed: float = 1.0  # Default speed
259
 
260
  # Health check endpoint
261
@app.get("/")
def health_check():
    """Return service status, whether the TTS engine loaded, and a timestamp."""
    if tts_engine:
        device = tts_engine.device
    else:
        device = "not available"

    report = {
        "status": "ok",
        "tts_engine_loaded": tts_engine is not None,
        "device": device,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    return report
269
 
270
- # Text-to-speech endpoint
271
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Synthesize speech for the request text and return it as a WAV file.

    Raises:
        HTTPException: 500 when the TTS engine is not initialized, when no
            output file was produced, or when generation fails.
    """
    if tts_engine is None:
        logging.error("TTS engine not initialized")
        raise HTTPException(status_code=500, detail="TTS engine not initialized")

    try:
        # Generate a unique filename
        filename = f"{uuid.uuid4()}.wav"
        output_path = output_dir / filename

        # Log the request
        logging.info(f"Processing TTS request: text='{request.text[:50]}...', accent={request.accent}")

        # Generate speech
        # NOTE(review): this call is synchronous, so the event loop is blocked
        # for as long as generation takes.
        tts_engine.tts(
            text=request.text,
            output_file=str(output_path),
            accent=request.accent,
            speed=request.speed,
        )

        # Check if file was created
        if not output_path.exists():
            logging.error(f"Output file was not created: {output_path}")
            raise HTTPException(status_code=500, detail="Failed to generate audio file")

        # Return the audio file
        logging.info(f"Successfully generated audio: {output_path}")
        return FileResponse(
            path=output_path,
            media_type="audio/wav",
            filename=filename,
        )
    except HTTPException:
        # Already a well-formed HTTP error (raised just above); pass it
        # through instead of re-wrapping it in the generic handler below.
        raise
    except Exception as e:
        logging.error(f"Error in text_to_speech: {str(e)}")
        logging.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") from e
310
 
311
- # Cleanup old files (run this periodically)
312
@app.on_event("startup")
async def cleanup_old_files():
    """Delete ``.wav`` files in the output directory that are over an hour old.

    Runs once, when the application starts. It is not re-run while the server
    is up, so files created afterwards stay until the next start.
    """
    try:
        current_time = time.time()
        candidates = list(output_dir.glob("*.wav"))
    except Exception as e:
        logging.error(f"Error during cleanup: {str(e)}")
        return

    for file_path in candidates:
        # Handle each file on its own so that one that cannot be inspected
        # or removed does not stop the rest of the sweep.
        try:
            if current_time - file_path.stat().st_mtime > 3600:  # 1 hour
                file_path.unlink()
                logging.info(f"Deleted old file: {file_path}")
        except OSError as e:
            logging.error(f"Error during cleanup: {str(e)}")
323
 
324
- # Custom exception handler
325
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with a JSON 500 body."""
    logging.error(f"Unhandled exception: {str(exc)}")
    # Format the traceback of ``exc`` itself: traceback.format_exc() reports
    # whatever exception is "currently being handled", which is only right if
    # this handler happens to be called from inside an ``except`` block.
    logging.error("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))
    return JSONResponse(
        status_code=500,
        content={"detail": f"Internal server error: {str(exc)}"},
    )
333
 
334
- # Start server if running as a script
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  if __name__ == "__main__":
336
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
6
import subprocess
import traceback
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
 
 
 
16
 
17
  # Configure logging
18
  logging.basicConfig(
19
  level=logging.INFO,
20
+ format="%(asctime)s | %(levelname)s | %(message)s",
21
+ handlers=[logging.StreamHandler()]
22
  )
23
+ logger = logging.getLogger(__name__)
24
 
25
+ # Create start-up log entry
26
+ logger.info(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
27
 
28
+ # Create output directory for audio files
29
+ os.makedirs("audio_files", exist_ok=True)
30
 
31
+ # Initialize FastAPI app
32
+ app = FastAPI(title="Nigerian Text-to-Speech API")
33
 
34
# Add CORS middleware to allow cross-origin requests (for Streamlit/cURL).
# Credentials are disabled: no endpoint visible in this file reads cookies or
# auth headers, and a wildcard origin must not be combined with
# allow_credentials=True (that would let any website issue credentialed
# requests to this API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
42
 
43
+ # Input validation models
44
class TTSRequest(BaseModel):
    """Request body for ``POST /tts``."""

    # Text to synthesize.
    text: str
    # Accent name, kept for backward compatibility; it is mapped to a voice
    # through ACCENT_TO_VOICE when ``voice`` is not given.
    accent: str = "nigerian"
    # Explicit voice name (overrides ``accent`` if provided). Declared
    # Optional so that an explicit JSON ``null`` is accepted as "not set".
    voice: Optional[str] = None
    # Language of the text; validated against AVAILABLE_LANGUAGES.
    language: str = "english"
49
 
50
class TTSResponse(BaseModel):
    """Response body describing a generated audio clip."""

    # URL path under which the generated file is served.
    audio_url: str
    # Base64-encoded audio (optional); Optional so that None is a valid value.
    audio_base64: Optional[str] = None
    # The text that was synthesized.
    text: str
    # Voice used for synthesis.
    voice: str
    # Language used for synthesis.
    language: str
56
 
57
+ # Define available voices and mapping
58
+ AVAILABLE_VOICES = {
59
+ "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
60
+ "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
61
+ }
62
 
63
+ ACCENT_TO_VOICE = {
64
+ "nigerian": "tayo",
65
+ "yoruba": "idera",
66
+ "igbo": "emma",
67
+ "hausa": "umar"
68
+ }
69
+
70
+ AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
71
+
72
+ # Initialize global variables for model components
73
+ model = None
74
+ audio_tokenizer = None
75
+ tts_engine = None
76
+
77
def _has_content(path):
    """Return True if *path* is an existing, non-empty file."""
    return os.path.isfile(path) and os.path.getsize(path) > 0


def _fetch_with_requests(url, tmp_path):
    """Stream *url* into *tmp_path* with requests; raise on HTTP errors."""
    with requests.get(url, stream=True, timeout=30) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _fetch_with_wget(url, tmp_path):
    """Download *url* into *tmp_path* with the wget command-line tool."""
    subprocess.run(["wget", url, "-O", tmp_path], check=True)


def _fetch_with_curl(url, tmp_path):
    """Download *url* into *tmp_path* with the curl command-line tool."""
    # --fail: exit non-zero on HTTP errors instead of saving the error page.
    subprocess.run(["curl", "-L", "--fail", url, "-o", tmp_path], check=True)


def _fetch_with_gdown(gdrive_id, tmp_path):
    """Download a Google Drive file by id, installing gdown only if missing."""
    try:
        import gdown
    except ImportError:
        import importlib

        subprocess.run([sys.executable, "-m", "pip", "install", "gdown", "--quiet"], check=True)
        # Make the freshly installed package visible to the import system.
        importlib.invalidate_caches()
        import gdown
    gdown.download(id=gdrive_id, output=tmp_path, quiet=False)


def download_required_files():
    """
    Download model files from multiple sources with fallback mechanisms.

    For every required file that is not already present (and non-empty) the
    methods are tried in order: requests, wget, curl and, when a Google Drive
    id is configured, gdown. Each attempt writes to a temporary ``.part`` file
    that is moved into place only after it finished with content, so an
    interrupted download is never mistaken for a complete file later on.

    Raises:
        FileNotFoundError: If a file could not be obtained by any method, or
            is missing/empty after the download pass.
    """
    files_to_download = [
        {
            "url": "https://huggingface.co/novateur/WavTokenizer-small-speech-320token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt",
            "output_path": "wavtokenizer_large_speech_320_24k.ckpt",
            "gdrive_id": "1-6uQcVGonAdmAiazJ8YEQBHoGzbKXrsW",  # Backup Google Drive ID
        },
        {
            "url": "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/config.json",
            "output_path": "saheedniyi_YarnGPT2/config.json",
            "gdrive_id": None,
        },
        {
            "url": "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/tokenizer_config.json",
            "output_path": "saheedniyi_YarnGPT2/tokenizer_config.json",
            "gdrive_id": None,
        },
        {
            "url": "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/pytorch_model.bin",
            "output_path": "saheedniyi_YarnGPT2/pytorch_model.bin",
            "gdrive_id": "1-3KU78OGUyPxtjYPSITx6N3vj46aOeFu",  # Backup Google Drive ID
        },
        {
            "url": "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
            "output_path": "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
            "gdrive_id": None,
        },
    ]

    # Prepare directory for model files
    os.makedirs("saheedniyi_YarnGPT2", exist_ok=True)

    for file_info in files_to_download:
        output_path = file_info["output_path"]

        # Skip if file already exists
        if _has_content(output_path):
            logger.info(f"File already exists: {output_path}")
            continue

        logger.info(f"Downloading file: {output_path}")

        url = file_info["url"]
        # (label, "trying" log prefix, "failed" log prefix, fetcher, source)
        attempts = [
            ("requests", "Trying direct download with requests", "Failed to download with requests", _fetch_with_requests, url),
            ("wget", "Trying download with wget", "Failed to download with wget", _fetch_with_wget, url),
            ("curl", "Trying download with curl", "Failed to download with curl", _fetch_with_curl, url),
        ]
        if file_info["gdrive_id"]:
            attempts.append(
                ("gdown", "Trying download from Google Drive", "Failed to install or use gdown", _fetch_with_gdown, file_info["gdrive_id"])
            )

        tmp_path = output_path + ".part"
        success = False
        for label, trying_msg, failed_msg, fetch, source in attempts:
            logger.info(f"{trying_msg}: {source}")
            try:
                fetch(source, tmp_path)
                if _has_content(tmp_path):
                    # Publish the file only once it is complete.
                    os.replace(tmp_path, output_path)
                    logger.info(f"Successfully downloaded via {label}: {output_path}")
                    success = True
                    break
                logger.error(f"{failed_msg}: no data was written")
            except Exception as e:
                # Best-effort: any failure just moves on to the next method.
                logger.error(f"{failed_msg}: {str(e)}")

        # Drop whatever a failed attempt left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

        if not success:
            logger.error(f"All download methods failed for: {output_path}")
            raise FileNotFoundError(f"Failed to download required file: {output_path}")

    # Verify all files were downloaded
    for file_info in files_to_download:
        if not _has_content(file_info["output_path"]):
            raise FileNotFoundError(f"Required file missing or empty: {file_info['output_path']}")

    logger.info("All required files downloaded successfully!")
188
 
189
def _torch_load_defaults_to_weights_only(version):
    """Return True when *version* names PyTorch 2.6 or newer.

    Those releases need the ``weights_only`` patch applied by
    ``load_tts_engine``. Unparseable version strings yield False.
    """
    import re

    match = re.match(r"(\d+)\.(\d+)", str(version))
    if match is None:
        return False
    return (int(match.group(1)), int(match.group(2))) >= (2, 6)


def load_tts_engine():
    """
    Load the TTS engine and models with explicit PyTorch version handling.

    On success the module-level ``audio_tokenizer``, ``model`` and
    ``tts_engine`` are set together; on failure they are left untouched.

    Returns:
        True if the engine was initialized, False otherwise (the reason is
        logged).
    """
    global model, audio_tokenizer, tts_engine

    try:
        # Only import these modules when needed to avoid startup errors
        import torch
        import torchaudio  # noqa: F401 - imported here only to fail early if it is missing

        # Apply monkey patch for PyTorch 2.6+ compatibility. The marker
        # attribute keeps repeated calls from stacking wrappers.
        torch_version = getattr(torch, "__version__", "")
        already_patched = getattr(torch.load, "_weights_only_patched", False)
        if _torch_load_defaults_to_weights_only(torch_version) and not already_patched:
            logger.info(f"Detected PyTorch {torch_version}, applying load function patch for weights_only")
            original_torch_load = torch.load

            def patched_torch_load(*args, **kwargs):
                # Add weights_only=False if not explicitly specified.
                # SECURITY NOTE: this allows arbitrary unpickling; only load
                # checkpoints from trusted sources.
                kwargs.setdefault('weights_only', False)
                return original_torch_load(*args, **kwargs)

            patched_torch_load._weights_only_patched = True
            torch.load = patched_torch_load

        # Now import other dependencies
        from transformers import AutoModelForCausalLM

        # Keep this try narrow: only a missing yarngpt package should produce
        # the "not installed" message; other errors go to the handler below.
        try:
            from yarngpt.audiotokenizer import AudioTokenizerV2
        except ImportError as e:
            logger.error("Failed to import yarngpt modules. Make sure the yarngpt package is installed.")
            logger.error(f"Import error details: {str(e)}")
            return False

        # Model configuration
        tokenizer_path = "saheedniyi_YarnGPT2"
        wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"

        logger.info("Loading YarnGPT model and tokenizer...")
        loaded_tokenizer = AudioTokenizerV2(
            tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
        )
        loaded_model = AutoModelForCausalLM.from_pretrained(
            tokenizer_path,
            torch_dtype="auto",
        ).to(loaded_tokenizer.device)
        logger.info("YarnGPT model loaded successfully!")

        class TextToSpeech:
            """Thin wrapper that turns text into audio with the loaded model."""

            def __init__(self, tokenizer, language_model):
                self.audio_tokenizer = tokenizer
                self.model = language_model

            def generate_speech(self, text, language="english", speaker_name="tayo"):
                """Return the audio generated for *text* in the given voice."""
                # Create prompt and generate audio
                prompt = self.audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
                input_ids = self.audio_tokenizer.tokenize_prompt(prompt)

                output = self.model.generate(
                    input_ids=input_ids,
                    temperature=0.1,
                    repetition_penalty=1.1,
                    max_length=4000,
                )

                codes = self.audio_tokenizer.get_codes(output)
                audio = self.audio_tokenizer.get_audio(codes)
                return audio

        # Initialize TTS engine, then publish all globals together so a
        # failure above never leaves them half-set.
        engine = TextToSpeech(loaded_tokenizer, loaded_model)
        audio_tokenizer = loaded_tokenizer
        model = loaded_model
        tts_engine = engine
        logger.info("TTS engine initialized successfully!")
        return True

    except Exception as e:
        logger.error(f"Error initializing TTS engine: {str(e)}")
        logger.error(traceback.format_exc())
        return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  # Health check endpoint
271
@app.get("/")
async def root():
    """Report API health plus the supported languages, voices and accents."""
    if tts_engine is None:
        engine_state = "model_loading_failed"
    else:
        engine_state = "ok"

    info = {
        "status": engine_state,
        "message": "Nigerian TTS API is running",
        "available_languages": AVAILABLE_LANGUAGES,
        "available_voices": AVAILABLE_VOICES,
        "accent_mapping": ACCENT_TO_VOICE,
    }
    return info
283
 
284
+ # TTS endpoint
285
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Convert text to Nigerian-accented speech.

    Returns a JSON body with the URL of the generated WAV file, the same
    audio base64-encoded, and the text/voice/language that were used.

    Raises:
        HTTPException: 503 if the TTS engine is not loaded, 400 for empty
            text or an unknown language/voice, 500 if generation fails.
    """
    # Check if TTS engine is loaded
    if tts_engine is None:
        raise HTTPException(status_code=503, detail="TTS engine is not initialized. Please try again later.")

    # Reject empty input before spending time on generation
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")

    # Determine voice based on accent or explicitly provided voice
    voice = request.voice
    if voice is None:
        accent = request.accent.lower() if request.accent else "nigerian"
        voice = ACCENT_TO_VOICE.get(accent, "tayo")  # Default to tayo if accent not recognized
    else:
        # Match voices case-insensitively, like accents and languages
        voice = voice.strip().lower()

    # Validate language
    language = request.language.lower()
    if language not in AVAILABLE_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")

    # Validate voice - combine all available voices
    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if voice not in all_voices:
        raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")

    # Generate unique filename
    audio_id = str(uuid.uuid4())
    output_path = f"audio_files/{audio_id}.wav"

    try:
        # Generate audio using the TTS engine
        # NOTE(review): this call is synchronous, so the event loop is blocked
        # for as long as generation takes.
        audio = tts_engine.generate_speech(request.text, language=language, speaker_name=voice)

        # Import torchaudio here to avoid startup issues
        import torchaudio

        # Save audio file
        torchaudio.save(output_path, audio, sample_rate=24000)

        # Generate base64 representation for direct embedding
        import base64
        with open(output_path, "rb") as audio_file:
            audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        logger.error(traceback.format_exc())
        # Do not leave a partially written file to be served later.
        try:
            if os.path.exists(output_path):
                os.remove(output_path)
        except OSError as cleanup_error:
            logger.error(f"Could not remove incomplete audio file {output_path}: {cleanup_error}")
        raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") from e

    # Add task to clean up old files
    background_tasks.add_task(cleanup_old_files)

    return {
        "audio_url": f"/audio/{audio_id}.wav",
        "audio_base64": audio_base64,
        "text": request.text,
        "voice": voice,
        "language": language,
    }
343
 
344
+ # Serve audio files
345
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a generated audio file from the ``audio_files`` directory.

    Raises:
        HTTPException: 404 if *filename* is not a plain ``.wav`` file name or
            no such file exists.
    """
    # ``filename`` comes straight from the URL: accept only a bare "*.wav"
    # name so that it can never point outside the audio directory.
    is_plain_name = (
        "/" not in filename
        and "\\" not in filename
        and filename not in (".", "..")
        and os.path.basename(filename) == filename
    )
    if not is_plain_name or not filename.endswith(".wav"):
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = f"audio_files/{filename}"
    # isfile (not exists): a directory must not be handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/wav")
352
+
353
+ # Cleanup function to remove old files
354
def cleanup_old_files(audio_dir="audio_files", max_age=timedelta(hours=6)):
    """Delete audio files older than *max_age* to manage disk space.

    Args:
        audio_dir: Directory to sweep; only ``.wav`` entries are considered.
        max_age: Files whose modification time is further in the past than
            this are removed (default: 6 hours).

    Errors are logged, never raised. Each file is handled on its own, so a
    file that cannot be inspected or removed does not stop the sweep.
    """
    try:
        filenames = os.listdir(audio_dir)
    except FileNotFoundError:
        # Nothing to clean up if the directory does not exist.
        return
    except OSError as e:
        logger.error(f"Error cleaning up old files: {e}")
        return

    now = datetime.now()
    for filename in filenames:
        if not filename.endswith(".wav"):
            continue

        file_path = os.path.join(audio_dir, filename)
        try:
            file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))

            # Delete files older than the configured age
            if now - file_mod_time > max_age:
                os.remove(file_path)
                logger.info(f"Deleted old audio file: {filename}")
        except FileNotFoundError:
            # The file vanished between listing and removal (for example
            # because another cleanup run deleted it); nothing left to do.
            continue
        except (OSError, ValueError, OverflowError) as e:
            logger.error(f"Error cleaning up old files: {e}")
376
 
377
+ # Custom exception handler for better error responses
378
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with a JSON 500 body."""
    logger.error(f"Unhandled exception: {str(exc)}")
    # Format the traceback of ``exc`` itself: traceback.format_exc() reports
    # whatever exception is "currently being handled", which is only right if
    # this handler happens to be called from inside an ``except`` block.
    logger.error("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))
    return JSONResponse(
        status_code=500,
        content={"detail": f"An unexpected error occurred: {str(exc)}"},
    )
386
 
387
+ # Initialize on startup
388
@app.on_event("startup")
async def startup_event():
    """Download the model files and initialize the TTS engine at startup.

    Failures are logged rather than raised, so the application still starts
    without a TTS engine.
    """
    # Download required files first
    try:
        download_required_files()
    except Exception as e:
        logger.error(f"Failed to download required files: {str(e)}")
        logger.error(traceback.format_exc())
        return

    # Then try to load the TTS engine
    try:
        import importlib.util

        # Install yarngpt only when it is missing, so that a normal restart
        # does not run pip (and need network access) every time.
        if importlib.util.find_spec("yarngpt") is None:
            subprocess.run([sys.executable, "-m", "pip", "install", "git+https://github.com/saheedniyi02/yarngpt.git", "--quiet"], check=True)
            # Make the freshly installed package visible to the import system.
            importlib.invalidate_caches()
            logger.info("Successfully installed yarngpt package")

        # Load TTS engine
        success = load_tts_engine()
        if not success:
            logger.error("Failed to initialize TTS engine")
    except Exception as e:
        logger.error(f"Failed to initialize app: {str(e)}")
        logger.error(traceback.format_exc())
411
+
412
+ # Main entry point
413
  if __name__ == "__main__":
414
  uvicorn.run(app, host="0.0.0.0", port=7860)