Hameed13 committed on
Commit
ea51340
·
verified ·
1 Parent(s): 91499fa

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +268 -130
main.py CHANGED
@@ -1,41 +1,47 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
2
- from fastapi.responses import FileResponse, JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel
5
  import os
6
  import sys
7
  import time
8
  import uuid
9
- import base64
10
- import datetime
11
  import logging
12
  import traceback
 
 
 
13
  from typing import Optional
14
  import torch
15
- import soundfile as sf
 
 
 
 
 
 
 
16
 
17
  # Configure logging
18
- logging.basicConfig(level=logging.INFO,
19
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
 
 
 
 
20
  logger = logging.getLogger(__name__)
21
 
22
- # Set environment variable to handle PortAudio issues
23
- os.environ["OUTETTS_NO_PORTAUDIO"] = "1"
24
-
25
- # Import the TextToSpeech class from generate.py
26
- try:
27
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
28
- from generate import TextToSpeech
29
- logger.info("Successfully imported TextToSpeech class from generate.py")
30
- except ImportError as e:
31
- logger.error(f"Failed to import TextToSpeech class: {e}")
32
- traceback.print_exc()
33
- sys.exit(1)
34
 
35
- # Create the FastAPI app
36
- app = FastAPI(title="Nigerian Text-to-Speech API")
 
 
 
 
37
 
38
- # Add CORS middleware to allow cross-origin requests
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
@@ -44,165 +50,297 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
47
- # YarnGPT TTS configuration
48
- MODEL_CONFIG = {
49
- "model_name_or_path": "yarngpt/yarn-tts-demo",
50
- "processor_name_or_path": "yarngpt/yarn-tts-demo"
51
- }
 
 
52
 
53
- # Available voices and languages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  AVAILABLE_VOICES = {
55
  "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
56
  "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
57
  }
 
 
 
 
 
 
 
 
58
  AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
59
 
60
- # Initialize YarnGPT
61
- yarngpt = None
 
 
62
 
63
- class TTSRequest(BaseModel):
64
- text: str
65
- language: str = "english"
66
- voice: str = "idera"
67
- speed: float = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Define the directory for storing generated audio files
70
- AUDIO_DIR = "audio_files"
71
- os.makedirs(AUDIO_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- def initialize_yarngpt():
74
- """Initialize the YarnGPT TTS model with proper error handling."""
75
- global yarngpt
76
 
77
  try:
78
- logger.info("Initializing YarnGPT TTS model...")
 
 
79
 
80
- # Create TextToSpeech instance with option to disable playback
81
- yarngpt = TextToSpeech(
82
- model_name_or_path=MODEL_CONFIG["model_name_or_path"],
83
- processor_name_or_path=MODEL_CONFIG["processor_name_or_path"],
84
- disable_playback=True # Disable playback to avoid PortAudio issues
 
 
 
 
 
 
 
 
 
85
  )
86
 
87
- logger.info("YarnGPT TTS model initialized successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  return True
89
-
90
  except Exception as e:
91
- logger.error(f"Failed to initialize YarnGPT: {str(e)}")
92
- traceback.print_exc()
93
  return False
94
 
95
  def cleanup_old_files(max_age_hours=6):
96
  """Remove audio files older than the specified hours."""
97
  try:
98
- now = time.time()
99
- count = 0
100
-
101
- for filename in os.listdir(AUDIO_DIR):
102
- file_path = os.path.join(AUDIO_DIR, filename)
103
- if os.path.isfile(file_path):
104
- if now - os.path.getmtime(file_path) > max_age_hours * 3600:
105
  os.remove(file_path)
106
- count += 1
107
-
108
- logger.info(f"Cleaned up {count} old audio files")
109
  except Exception as e:
110
- logger.error(f"Error during file cleanup: {str(e)}")
111
 
112
  @app.on_event("startup")
113
  async def startup_event():
114
- """Initialize required components on startup."""
115
- success = initialize_yarngpt()
116
- if not success:
117
- logger.warning("YarnGPT failed to initialize. The API may not function correctly.")
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @app.get("/")
120
- async def health_check():
121
- """API health check endpoint."""
122
- current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
123
- yarngpt_status = "initialized" if yarngpt is not None else "not initialized"
124
-
125
  return {
126
- "status": "online",
127
- "timestamp": current_time,
128
- "yarngpt_status": yarngpt_status,
129
  "available_languages": AVAILABLE_LANGUAGES,
130
- "available_voices": AVAILABLE_VOICES
 
131
  }
132
 
133
- @app.post("/tts")
134
  async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
135
- """Convert text to speech using YarnGPT model."""
136
- if yarngpt is None:
137
- logger.error("YarnGPT model not initialized")
138
- success = initialize_yarngpt()
139
- if not success:
140
- raise HTTPException(status_code=500, detail="YarnGPT model initialization failed. Please check logs.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  try:
143
- # Generate a unique filename
144
  audio_id = str(uuid.uuid4())
145
- output_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
146
 
147
- logger.info(f"Processing TTS request: '{request.text[:50]}...' with voice '{request.voice}' and language '{request.language}'")
 
 
 
 
 
 
148
 
149
- # Create prompt from voice and language
150
- # This adapts to the colab-style API even though we're using a different backend
151
- accent = request.language if request.language in ["nigerian"] else "nigerian"
152
 
153
- # Generate speech
154
- try:
155
- audio_data, sample_rate = yarngpt.tts(
156
- text=request.text,
157
- accent=accent,
158
- save_path=output_path,
159
- speed=request.speed,
160
- get_array=True
161
- )
162
-
163
- # Convert audio to base64
164
- sf.write(output_path, audio_data, sample_rate)
165
- with open(output_path, "rb") as audio_file:
166
- audio_bytes = audio_file.read()
167
- audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
168
-
169
- except Exception as e:
170
- logger.error(f"Error in speech generation: {str(e)}")
171
- traceback.print_exc()
172
- raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
173
 
174
- # Schedule cleanup for old files
175
  background_tasks.add_task(cleanup_old_files)
176
 
177
- # Check if file exists
178
- if not os.path.exists(output_path):
179
- logger.error(f"Output file was not created: {output_path}")
180
- raise HTTPException(status_code=500, detail="Failed to create audio file")
181
-
182
- logger.info(f"Successfully generated audio file: {audio_id}.wav")
183
-
184
- # Return both file URL and base64 data for compatibility with both APIs
185
  return {
186
- "audio_base64": audio_base64,
187
  "audio_url": f"/audio/{audio_id}.wav",
 
188
  "text": request.text,
189
- "voice": request.voice,
190
- "language": request.language
191
  }
192
 
193
  except Exception as e:
194
- logger.error(f"Error in TTS processing: {str(e)}")
195
- traceback.print_exc()
196
- raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
 
 
 
197
 
198
- # File serving endpoint (for backward compatibility)
199
  @app.get("/audio/{filename}")
200
  async def get_audio(filename: str):
201
- file_path = os.path.join(AUDIO_DIR, filename)
 
202
  if not os.path.exists(file_path):
203
- raise HTTPException(status_code=404, detail="Audio file not found")
 
 
 
204
  return FileResponse(file_path, media_type="audio/wav")
205
 
 
 
 
 
 
 
 
 
 
 
206
  if __name__ == "__main__":
207
- import uvicorn
208
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
1
  import os
2
  import sys
3
  import time
4
  import uuid
 
 
5
  import logging
6
  import traceback
7
+ import requests
8
+ from pathlib import Path
9
+ from datetime import datetime, timedelta
10
  from typing import Optional
11
  import torch
12
+ import torchaudio
13
+
14
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
15
+ from fastapi.responses import FileResponse, JSONResponse
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel
18
+ import uvicorn
19
+ from huggingface_hub import hf_hub_download
20
 
21
# Configure logging: every record is written to the console and to app.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(), logging.FileHandler("app.log")],
)
logger = logging.getLogger(__name__)

# Working directories that must exist before anything is downloaded or
# generated: audio output, tokenizer checkpoint, and model files.
REQUIRED_DIRS = ["audio_files", "models", "saheedniyi_YarnGPT2"]
for directory in REQUIRED_DIRS:
    os.makedirs(directory, exist_ok=True)

# The FastAPI application object; routes and handlers attach to it below.
app = FastAPI(
    title="Nigerian Text-to-Speech API",
    description="Convert text to Nigerian-accented speech using YarnGPT",
    version="1.0.0",
)
43
 
44
+ # Add CORS middleware
45
  app.add_middleware(
46
  CORSMiddleware,
47
  allow_origins=["*"],
 
50
  allow_headers=["*"],
51
  )
52
 
53
+ # Input validation models
54
class TTSRequest(BaseModel):
    """Request body accepted by the ``/tts`` endpoint."""

    # Text to synthesize.
    text: str
    # Accent name; the endpoint uses it to choose a default voice when
    # ``voice`` is omitted.
    accent: str = "nigerian"
    # Speaker name. Declared Optional because the default really is None
    # ("derive the voice from the accent"); a bare ``str`` annotation
    # contradicts that default and can reject an explicit JSON null.
    voice: Optional[str] = None
    # Language of the input text.
    language: str = "english"
    # Playback rate forwarded to the engine; 1.0 leaves the audio untouched.
    speed: float = 1.0

    class Config:
        # Example payload shown in the generated API documentation.
        schema_extra = {
            "example": {
                "text": "Welcome to Nigeria, the giant of Africa.",
                "accent": "nigerian",
                "voice": "tayo",
                "language": "english",
                "speed": 1.0,
            }
        }
71
+
72
class TTSResponse(BaseModel):
    """Response body returned by the ``/tts`` endpoint."""

    # Relative URL under /audio/ from which the generated WAV can be fetched.
    audio_url: str
    # Base64-encoded copy of the WAV file. Optional because the default is
    # None; a bare ``str`` annotation contradicts that default.
    audio_base64: Optional[str] = None
    # Echo of the text that was synthesized.
    text: str
    # Voice that was actually used (after accent-based defaulting).
    voice: str
    # Language that was actually used (lower-cased).
    language: str
78
+
79
+ # Define available voices and languages
80
# Speaker names accepted by the API, grouped by gender.
AVAILABLE_VOICES = {
    "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
    "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"],
}

# Voice chosen for a request that names an accent but no explicit voice.
ACCENT_TO_VOICE = {
    "nigerian": "tayo",
    "yoruba": "idera",
    "igbo": "emma",
    "hausa": "umar",
}

# Languages accepted by the API.
AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]

# Model components; they stay None until load_tts_engine() succeeds.
model = audio_tokenizer = tts_engine = None
98
 
99
def download_model_files():
    """Download the required model files from the Hugging Face Hub.

    Each entry is fetched only when its ``output_path`` does not exist yet.
    The file is downloaded into the directory part of ``output_path`` so
    that it ends up exactly where the existence check (and the loader)
    look for it.

    Raises:
        Exception: whatever the download raised, re-raised after logging.
    """
    # NOTE(review): the first repo id says "small ... 320token" while the
    # filename says "large_speech_320" -- confirm this checkpoint is really
    # published in that repository.
    # NOTE(review): load_tts_engine() also references a WavTokenizer YAML
    # config that is not in this list -- confirm it is provided another way.
    files_to_download = [
        {
            "repo_id": "novateur/WavTokenizer-small-speech-320token",
            "filename": "wavtokenizer_large_speech_320_24k.ckpt",
            "output_path": "models/wavtokenizer_large_speech_320_24k.ckpt",
        },
        {
            "repo_id": "saheedniyi/YarnGPT2",
            "filename": "config.json",
            "output_path": "saheedniyi_YarnGPT2/config.json",
        },
        {
            "repo_id": "saheedniyi/YarnGPT2",
            "filename": "tokenizer_config.json",
            "output_path": "saheedniyi_YarnGPT2/tokenizer_config.json",
        },
        {
            "repo_id": "saheedniyi/YarnGPT2",
            "filename": "pytorch_model.bin",
            "output_path": "saheedniyi_YarnGPT2/pytorch_model.bin",
        },
    ]

    for file_info in files_to_download:
        repo_id = file_info["repo_id"]
        filename = file_info["filename"]
        output_path = file_info["output_path"]

        if os.path.exists(output_path):
            logger.info(f"File already exists: {output_path}")
            continue

        try:
            logger.info(f"Downloading {filename} from {repo_id}")
            # hf_hub_download stores the file at <local_dir>/<filename>, so
            # the target directory must be the one output_path points into;
            # using "." would leave the file in the working directory.
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=os.path.dirname(output_path) or ".",
                local_dir_use_symlinks=False,
            )
            logger.info(f"Successfully downloaded {filename}")
        except Exception as e:
            logger.error(f"Error downloading {filename}: {str(e)}")
            raise
140
 
141
def load_tts_engine():
    """Load the audio tokenizer and language model and build the TTS engine.

    On success the module-level ``model``, ``audio_tokenizer`` and
    ``tts_engine`` globals are populated.

    Returns:
        bool: True when the engine was built; False when any step failed
        (the error and its traceback are logged rather than raised).
    """
    global model, audio_tokenizer, tts_engine

    try:
        # Imported here so that a missing package is logged and reported as
        # a failed initialization by the except branch below.
        from transformers import AutoModelForCausalLM
        from yarngpt.audiotokenizer import AudioTokenizerV2

        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Locations of the tokenizer/model directory and WavTokenizer files
        tokenizer_path = "saheedniyi_YarnGPT2"
        wav_tokenizer_config = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        wav_tokenizer_model = "models/wavtokenizer_large_speech_320_24k.ckpt"

        logger.info("Loading audio tokenizer...")
        audio_tokenizer = AudioTokenizerV2(
            tokenizer_path,
            wav_tokenizer_model,
            wav_tokenizer_config
        )

        logger.info("Loading model...")
        # Half precision only on GPU; CPU inference stays in float32.
        model = AutoModelForCausalLM.from_pretrained(
            tokenizer_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)

        class TextToSpeech:
            """Bundles the loaded tokenizer and model behind one call."""

            def __init__(self):
                self.audio_tokenizer = audio_tokenizer
                self.model = model
                self.device = device

            def generate_speech(self, text, language="english", speaker_name="tayo", speed=1.0):
                """Synthesize ``text`` and return the audio as a CPU tensor.

                Raises:
                    ValueError: if ``speed`` is not greater than 0.
                """
                # Reject a stretch rate that cannot work before spending
                # time on generation.
                if not speed > 0:
                    raise ValueError("speed must be greater than 0")

                prompt = self.audio_tokenizer.create_prompt(
                    text,
                    lang=language,
                    speaker_name=speaker_name
                )
                input_ids = self.audio_tokenizer.tokenize_prompt(prompt)

                with torch.no_grad():
                    output = self.model.generate(
                        input_ids=input_ids,
                        temperature=0.1,
                        repetition_penalty=1.1,
                        max_length=4000,
                        do_sample=True,
                        top_k=50,
                        top_p=0.95
                    )

                codes = self.audio_tokenizer.get_codes(output)
                audio = self.audio_tokenizer.get_audio(codes)

                # Tensor.numpy() only works on CPU tensors that do not track
                # gradients, so normalise the tensor before any conversion.
                audio = audio.detach().cpu()

                if speed != 1.0:
                    import librosa
                    stretched = librosa.effects.time_stretch(audio.numpy().squeeze(), rate=speed)
                    # Restore the leading dimension removed by squeeze().
                    audio = torch.from_numpy(stretched).unsqueeze(0)

                return audio

        tts_engine = TextToSpeech()
        logger.info("TTS engine initialized successfully!")
        return True

    except Exception as e:
        logger.error(f"Error initializing TTS engine: {str(e)}")
        logger.error(traceback.format_exc())
        return False
215
 
216
def cleanup_old_files(max_age_hours=6, directory="audio_files"):
    """Remove ``.wav`` files older than ``max_age_hours`` from ``directory``.

    This is best-effort housekeeping: every error is logged and never
    raised. A failure on one file (for example, it was deleted by a
    concurrent sweep) does not stop the remaining files from being checked.

    Args:
        max_age_hours: age threshold in hours, compared against each file's
            modification time.
        directory: folder to sweep; defaults to the audio output folder.
    """
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    try:
        filenames = os.listdir(directory)
    except Exception as e:
        logger.error(f"Error cleaning up files: {str(e)}")
        return

    for filename in filenames:
        if not filename.endswith(".wav"):
            continue
        file_path = os.path.join(directory, filename)
        try:
            modified_at = datetime.fromtimestamp(os.path.getmtime(file_path))
            if modified_at < cutoff:
                os.remove(file_path)
                logger.info(f"Deleted old audio file: {filename}")
        except Exception as e:
            # Keep sweeping: one bad file must not block the others.
            logger.error(f"Error cleaning up files: {str(e)}")
228
 
229
@app.on_event("startup")
async def startup_event():
    """Fetch the model files and build the TTS engine when the app starts.

    Any failure is logged together with its traceback and then re-raised.
    """
    try:
        # Model files first: the engine loads from the downloaded paths.
        download_model_files()

        engine_ready = load_tts_engine()
        if not engine_ready:
            logger.error("Failed to initialize TTS engine")
            raise RuntimeError("TTS engine initialization failed")

    except Exception as e:
        logger.error(f"Startup failed: {str(e)}")
        logger.error(traceback.format_exc())
        raise
246
 
247
@app.get("/")
async def root():
    """Report service status and list the supported languages and voices."""
    if tts_engine is not None:
        status = "ok"
    else:
        status = "model_loading_failed"

    return {
        "status": status,
        "message": "Nigerian TTS API is running",
        "available_languages": AVAILABLE_LANGUAGES,
        "available_voices": AVAILABLE_VOICES,
        "accent_mapping": ACCENT_TO_VOICE,
    }
257
 
258
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Generate speech from text and return it as a URL plus base64 data.

    Raises:
        HTTPException: 503 when the engine is not loaded; 400 for an
            unknown language or voice, empty text, or a non-positive
            speed; 500 when generation or saving fails.
    """
    import base64

    if tts_engine is None:
        raise HTTPException(
            status_code=503,
            detail="TTS engine is not initialized"
        )

    # Validate and process parameters. The voice is lower-cased like the
    # accent and language so that "Tayo" and "tayo" are treated alike.
    voice = (request.voice or ACCENT_TO_VOICE.get(request.accent.lower(), "tayo")).lower()
    language = request.language.lower()

    if language not in AVAILABLE_LANGUAGES:
        raise HTTPException(
            status_code=400,
            detail=f"Language must be one of {AVAILABLE_LANGUAGES}"
        )

    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if voice not in all_voices:
        raise HTTPException(
            status_code=400,
            detail=f"Voice must be one of {all_voices}"
        )

    # Bad input is the caller's error: report it as 400 instead of letting
    # it surface as a 500 from inside the generation step.
    if not request.text.strip():
        raise HTTPException(
            status_code=400,
            detail="Text must not be empty"
        )

    if not request.speed > 0:
        raise HTTPException(
            status_code=400,
            detail="Speed must be greater than 0"
        )

    try:
        # Generate unique filename
        audio_id = str(uuid.uuid4())
        output_path = f"audio_files/{audio_id}.wav"

        # NOTE: this call is synchronous inside an async handler, so other
        # requests wait while audio is being generated.
        audio = tts_engine.generate_speech(
            text=request.text,
            language=language,
            speaker_name=voice,
            speed=request.speed
        )

        # Save from a CPU copy so that writing the file does not depend on
        # the device the model ran on.
        torchaudio.save(output_path, audio.detach().cpu(), sample_rate=24000)

        # Generate base64 representation of the saved file
        with open(output_path, "rb") as audio_file:
            audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')

        # Schedule cleanup of old files after the response is sent
        background_tasks.add_task(cleanup_old_files)

        return {
            "audio_url": f"/audio/{audio_id}.wav",
            "audio_base64": audio_base64,
            "text": request.text,
            "voice": voice,
            "language": language
        }

    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Error generating audio: {str(e)}"
        )
323
 
 
324
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a generated audio file from the ``audio_files`` folder.

    Raises:
        HTTPException: 404 when the name contains a path separator or does
            not refer to a regular file inside ``audio_files``.
    """
    file_path = f"audio_files/{filename}"

    # The filename comes straight from the URL. Refuse path separators so
    # the lookup cannot leave audio_files, and require a regular file so
    # names such as ".." (a directory) are answered with 404.
    has_separator = "/" in filename or "\\" in filename
    if has_separator or not os.path.isfile(file_path):
        raise HTTPException(
            status_code=404,
            detail="Audio file not found"
        )
    return FileResponse(file_path, media_type="audio/wav")
334
 
335
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a JSON 500 response."""
    logger.error(f"Unhandled exception: {str(exc)}")
    # Format the traceback from the exception that was passed in.
    # traceback.format_exc() reads the interpreter's "currently handled"
    # exception and yields "NoneType: None" when none is active.
    logger.error(
        "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    )
    return JSONResponse(
        status_code=500,
        content={"detail": f"An unexpected error occurred: {str(exc)}"}
    )
344
+
345
if __name__ == "__main__":
    # Run the app directly with uvicorn on all interfaces, port 7860,
    # with auto-reload enabled.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )