Hameed13 commited on
Commit
6b625c9
·
verified ·
1 Parent(s): ea51340

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +106 -253
main.py CHANGED
@@ -1,43 +1,31 @@
 
 
 
 
1
  import os
2
  import sys
3
  import time
4
- import uuid
5
- import logging
6
- import traceback
7
- import requests
8
- from pathlib import Path
9
- from datetime import datetime, timedelta
10
- from typing import Optional
11
  import torch
12
  import torchaudio
13
-
14
- from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
15
- from fastapi.responses import FileResponse, JSONResponse
16
- from fastapi.middleware.cors import CORSMiddleware
17
- from pydantic import BaseModel
18
- import uvicorn
19
  from huggingface_hub import hf_hub_download
 
 
 
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
23
  level=logging.INFO,
24
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
25
- handlers=[
26
- logging.StreamHandler(),
27
- logging.FileHandler("app.log")
28
- ]
29
  )
30
  logger = logging.getLogger(__name__)
31
 
32
- # Create necessary directories
33
- REQUIRED_DIRS = ["audio_files", "models", "saheedniyi_YarnGPT2"]
34
- for directory in REQUIRED_DIRS:
35
- os.makedirs(directory, exist_ok=True)
36
-
37
  # Initialize FastAPI app
38
  app = FastAPI(
39
- title="Nigerian Text-to-Speech API",
40
- description="Convert text to Nigerian-accented speech using YarnGPT",
41
  version="1.0.0"
42
  )
43
 
@@ -50,232 +38,111 @@ app.add_middleware(
50
  allow_headers=["*"],
51
  )
52
 
53
- # Input validation models
54
- class TTSRequest(BaseModel):
55
- text: str
56
- accent: str = "nigerian"
57
- voice: str = None
58
- language: str = "english"
59
- speed: float = 1.0
60
-
61
- class Config:
62
- schema_extra = {
63
- "example": {
64
- "text": "Welcome to Nigeria, the giant of Africa.",
65
- "accent": "nigerian",
66
- "voice": "tayo",
67
- "language": "english",
68
- "speed": 1.0
69
- }
70
- }
71
 
72
- class TTSResponse(BaseModel):
73
- audio_url: str
74
- audio_base64: str = None
75
- text: str
76
- voice: str
77
- language: str
78
-
79
- # Define available voices and languages
80
  AVAILABLE_VOICES = {
81
  "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
82
  "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
83
  }
84
-
85
- ACCENT_TO_VOICE = {
86
- "nigerian": "tayo",
87
- "yoruba": "idera",
88
- "igbo": "emma",
89
- "hausa": "umar"
90
- }
91
-
92
  AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
93
 
94
- # Initialize global variables for model components
95
- model = None
96
- audio_tokenizer = None
97
- tts_engine = None
98
-
99
- def download_model_files():
100
- """Download required model files from Hugging Face Hub."""
101
- files_to_download = [
102
- {
103
- "repo_id": "novateur/WavTokenizer-small-speech-320token",
104
- "filename": "wavtokenizer_large_speech_320_24k.ckpt",
105
- "output_path": "models/wavtokenizer_large_speech_320_24k.ckpt"
106
- },
107
- {
108
- "repo_id": "saheedniyi/YarnGPT2",
109
- "filename": "config.json",
110
- "output_path": "saheedniyi_YarnGPT2/config.json"
111
- },
112
- {
113
- "repo_id": "saheedniyi/YarnGPT2",
114
- "filename": "tokenizer_config.json",
115
- "output_path": "saheedniyi_YarnGPT2/tokenizer_config.json"
116
- },
117
- {
118
- "repo_id": "saheedniyi/YarnGPT2",
119
- "filename": "pytorch_model.bin",
120
- "output_path": "saheedniyi_YarnGPT2/pytorch_model.bin"
121
- }
122
- ]
123
-
124
- for file_info in files_to_download:
125
- try:
126
- if not os.path.exists(file_info["output_path"]):
127
- logger.info(f"Downloading {file_info['filename']} from {file_info['repo_id']}")
128
- hf_hub_download(
129
- repo_id=file_info["repo_id"],
130
- filename=file_info["filename"],
131
- local_dir=".",
132
- local_dir_use_symlinks=False
133
- )
134
- logger.info(f"Successfully downloaded {file_info['filename']}")
135
- else:
136
- logger.info(f"File already exists: {file_info['output_path']}")
137
- except Exception as e:
138
- logger.error(f"Error downloading {file_info['filename']}: {str(e)}")
139
- raise
140
-
141
- def load_tts_engine():
142
- """Initialize the TTS engine with proper error handling."""
143
- global model, audio_tokenizer, tts_engine
144
-
145
  try:
146
- # Import required modules
147
- from transformers import AutoModelForCausalLM
148
- from yarngpt.audiotokenizer import AudioTokenizerV2
149
 
150
- # Set device
151
- device = "cuda" if torch.cuda.is_available() else "cpu"
152
- logger.info(f"Using device: {device}")
 
 
 
 
 
 
153
 
154
- # Load tokenizer and model
155
- tokenizer_path = "saheedniyi_YarnGPT2"
156
- wav_tokenizer_config = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
157
- wav_tokenizer_model = "models/wavtokenizer_large_speech_320_24k.ckpt"
158
 
159
- logger.info("Loading audio tokenizer...")
160
  audio_tokenizer = AudioTokenizerV2(
161
- tokenizer_path,
162
  wav_tokenizer_model,
163
  wav_tokenizer_config
164
  )
165
 
166
- logger.info("Loading model...")
167
  model = AutoModelForCausalLM.from_pretrained(
168
- tokenizer_path,
169
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
170
- ).to(device)
171
 
172
- class TextToSpeech:
173
- def __init__(self):
174
- self.audio_tokenizer = audio_tokenizer
175
- self.model = model
176
- self.device = device
177
-
178
- def generate_speech(self, text, language="english", speaker_name="tayo", speed=1.0):
179
- prompt = self.audio_tokenizer.create_prompt(
180
- text,
181
- lang=language,
182
- speaker_name=speaker_name
183
- )
184
- input_ids = self.audio_tokenizer.tokenize_prompt(prompt)
185
-
186
- with torch.no_grad():
187
- output = self.model.generate(
188
- input_ids=input_ids,
189
- temperature=0.1,
190
- repetition_penalty=1.1,
191
- max_length=4000,
192
- do_sample=True,
193
- top_k=50,
194
- top_p=0.95
195
- )
196
-
197
- codes = self.audio_tokenizer.get_codes(output)
198
- audio = self.audio_tokenizer.get_audio(codes)
199
-
200
- if speed != 1.0:
201
- import librosa
202
- audio = librosa.effects.time_stretch(audio.numpy().squeeze(), rate=speed)
203
- audio = torch.from_numpy(audio).unsqueeze(0)
204
-
205
- return audio
206
-
207
- tts_engine = TextToSpeech()
208
- logger.info("TTS engine initialized successfully!")
209
- return True
210
 
211
  except Exception as e:
212
- logger.error(f"Error initializing TTS engine: {str(e)}")
213
- logger.error(traceback.format_exc())
214
- return False
215
 
216
- def cleanup_old_files(max_age_hours=6):
217
- """Remove audio files older than the specified hours."""
218
- try:
219
- now = datetime.now()
220
- for filename in os.listdir("audio_files"):
221
- if filename.endswith(".wav"):
222
- file_path = os.path.join("audio_files", filename)
223
- if now - datetime.fromtimestamp(os.path.getmtime(file_path)) > timedelta(hours=max_age_hours):
224
- os.remove(file_path)
225
- logger.info(f"Deleted old audio file: {filename}")
226
- except Exception as e:
227
- logger.error(f"Error cleaning up files: {str(e)}")
228
 
229
- @app.on_event("startup")
230
- async def startup_event():
231
- """Initialize the application on startup."""
 
 
 
 
 
 
232
  try:
233
- # Download model files
234
- download_model_files()
235
-
236
- # Initialize TTS engine
237
- success = load_tts_engine()
238
- if not success:
239
- logger.error("Failed to initialize TTS engine")
240
- raise RuntimeError("TTS engine initialization failed")
241
 
 
 
 
 
 
 
242
  except Exception as e:
243
- logger.error(f"Startup failed: {str(e)}")
244
- logger.error(traceback.format_exc())
245
- raise
246
 
 
247
  @app.get("/")
248
  async def root():
249
- """API health check and info endpoint."""
250
  return {
251
- "status": "ok" if tts_engine is not None else "model_loading_failed",
252
- "message": "Nigerian TTS API is running",
253
  "available_languages": AVAILABLE_LANGUAGES,
254
- "available_voices": AVAILABLE_VOICES,
255
- "accent_mapping": ACCENT_TO_VOICE
256
  }
257
 
258
  @app.post("/tts", response_model=TTSResponse)
259
  async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
260
- """Generate speech from text."""
261
- if tts_engine is None:
262
- raise HTTPException(
263
- status_code=503,
264
- detail="TTS engine is not initialized"
265
- )
266
 
267
- # Validate and process parameters
268
- voice = request.voice or ACCENT_TO_VOICE.get(request.accent.lower(), "tayo")
269
- language = request.language.lower()
270
-
271
- if language not in AVAILABLE_LANGUAGES:
272
  raise HTTPException(
273
  status_code=400,
274
  detail=f"Language must be one of {AVAILABLE_LANGUAGES}"
275
  )
276
 
277
  all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
278
- if voice not in all_voices:
279
  raise HTTPException(
280
  status_code=400,
281
  detail=f"Voice must be one of {all_voices}"
@@ -284,63 +151,49 @@ async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks)
284
  try:
285
  # Generate unique filename
286
  audio_id = str(uuid.uuid4())
287
- output_path = f"audio_files/{audio_id}.wav"
288
 
289
  # Generate audio
290
- audio = tts_engine.generate_speech(
291
- text=request.text,
292
- language=language,
293
- speaker_name=voice,
294
- speed=request.speed
295
  )
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  # Save audio file
298
  torchaudio.save(output_path, audio, sample_rate=24000)
299
 
300
- # Generate base64 representation
301
- import base64
302
  with open(output_path, "rb") as audio_file:
303
- audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
 
304
 
305
  # Schedule cleanup
306
  background_tasks.add_task(cleanup_old_files)
307
 
308
- return {
309
- "audio_url": f"/audio/{audio_id}.wav",
310
- "audio_base64": audio_base64,
311
- "text": request.text,
312
- "voice": voice,
313
- "language": language
314
- }
315
 
316
  except Exception as e:
317
  logger.error(f"Error generating audio: {str(e)}")
318
- logger.error(traceback.format_exc())
319
- raise HTTPException(
320
- status_code=500,
321
- detail=f"Error generating audio: {str(e)}"
322
- )
323
-
324
- @app.get("/audio/{filename}")
325
- async def get_audio(filename: str):
326
- """Serve generated audio files."""
327
- file_path = f"audio_files/{filename}"
328
- if not os.path.exists(file_path):
329
- raise HTTPException(
330
- status_code=404,
331
- detail="Audio file not found"
332
- )
333
- return FileResponse(file_path, media_type="audio/wav")
334
-
335
- @app.exception_handler(Exception)
336
- async def global_exception_handler(request: Request, exc: Exception):
337
- """Global exception handler."""
338
- logger.error(f"Unhandled exception: {str(exc)}")
339
- logger.error(traceback.format_exc())
340
- return JSONResponse(
341
- status_code=500,
342
- content={"detail": f"An unexpected error occurred: {str(exc)}"}
343
- )
344
 
345
  if __name__ == "__main__":
346
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
2
+ from fastapi.responses import FileResponse, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
  import os
6
  import sys
7
  import time
 
 
 
 
 
 
 
8
  import torch
9
  import torchaudio
10
+ import base64
11
+ from transformers import AutoModelForCausalLM
 
 
 
 
12
  from huggingface_hub import hf_hub_download
13
+ import logging
14
+ from datetime import datetime, timedelta
15
+ import uuid
16
+ from typing import Optional
17
 
18
  # Configure logging
19
  logging.basicConfig(
20
  level=logging.INFO,
21
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 
 
 
 
22
  )
23
  logger = logging.getLogger(__name__)
24
 
 
 
 
 
 
25
  # Initialize FastAPI app
26
  app = FastAPI(
27
+ title="Nigerian TTS API",
28
+ description="API for Nigerian Text-to-Speech using YarnGPT",
29
  version="1.0.0"
30
  )
31
 
 
38
  allow_headers=["*"],
39
  )
40
 
41
+ # Constants
42
+ MODEL_ID = "saheedniyi/YarnGPT2"
43
+ AUDIO_DIR = "audio_files"
44
+ os.makedirs(AUDIO_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # Available voices and languages
 
 
 
 
 
 
 
47
  AVAILABLE_VOICES = {
48
  "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
49
  "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
50
  }
 
 
 
 
 
 
 
 
51
  AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
52
 
53
+ # Model initialization
54
+ def initialize_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  try:
56
+ logger.info("Loading YarnGPT model and tokenizer...")
 
 
57
 
58
+ # Download necessary files from HuggingFace Hub
59
+ wav_tokenizer_config = hf_hub_download(
60
+ repo_id=MODEL_ID,
61
+ filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
62
+ )
63
+ wav_tokenizer_model = hf_hub_download(
64
+ repo_id=MODEL_ID,
65
+ filename="wavtokenizer_large_speech_320_24k.ckpt"
66
+ )
67
 
68
+ # Import AudioTokenizer here to ensure files are downloaded first
69
+ from yarngpt.audiotokenizer import AudioTokenizerV2
 
 
70
 
 
71
  audio_tokenizer = AudioTokenizerV2(
72
+ MODEL_ID,
73
  wav_tokenizer_model,
74
  wav_tokenizer_config
75
  )
76
 
 
77
  model = AutoModelForCausalLM.from_pretrained(
78
+ MODEL_ID,
79
+ torch_dtype="auto"
80
+ ).to(audio_tokenizer.device)
81
 
82
+ logger.info("Model loaded successfully!")
83
+ return audio_tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  except Exception as e:
86
+ logger.error(f"Error initializing model: {str(e)}")
87
+ raise
 
88
 
89
+ # Initialize model at startup
90
+ audio_tokenizer, model = initialize_model()
91
+
92
+ # Pydantic models
93
+ class TTSRequest(BaseModel):
94
+ text: str
95
+ language: str = "english"
96
+ voice: str = "idera"
 
 
 
 
97
 
98
+ class TTSResponse(BaseModel):
99
+ audio_base64: str
100
+ text: str
101
+ voice: str
102
+ language: str
103
+
104
+ # Cleanup function
105
+ def cleanup_old_files(max_age_hours: int = 6):
106
+ """Delete audio files older than specified hours"""
107
  try:
108
+ now = datetime.now()
109
+ for filename in os.listdir(AUDIO_DIR):
110
+ if not filename.endswith(".wav"):
111
+ continue
 
 
 
 
112
 
113
+ file_path = os.path.join(AUDIO_DIR, filename)
114
+ file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
115
+
116
+ if now - file_mod_time > timedelta(hours=max_age_hours):
117
+ os.remove(file_path)
118
+ logger.info(f"Deleted old audio file: {filename}")
119
  except Exception as e:
120
+ logger.error(f"Error cleaning up files: {str(e)}")
 
 
121
 
122
+ # API endpoints
123
  @app.get("/")
124
  async def root():
125
+ """Health check endpoint"""
126
  return {
127
+ "status": "healthy",
128
+ "model": MODEL_ID,
129
  "available_languages": AVAILABLE_LANGUAGES,
130
+ "available_voices": AVAILABLE_VOICES
 
131
  }
132
 
133
  @app.post("/tts", response_model=TTSResponse)
134
  async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
135
+ """Generate Nigerian-accented speech from text"""
 
 
 
 
 
136
 
137
+ # Validate inputs
138
+ if request.language not in AVAILABLE_LANGUAGES:
 
 
 
139
  raise HTTPException(
140
  status_code=400,
141
  detail=f"Language must be one of {AVAILABLE_LANGUAGES}"
142
  )
143
 
144
  all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
145
+ if request.voice not in all_voices:
146
  raise HTTPException(
147
  status_code=400,
148
  detail=f"Voice must be one of {all_voices}"
 
151
  try:
152
  # Generate unique filename
153
  audio_id = str(uuid.uuid4())
154
+ output_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
155
 
156
  # Generate audio
157
+ prompt = audio_tokenizer.create_prompt(
158
+ request.text,
159
+ lang=request.language,
160
+ speaker_name=request.voice
 
161
  )
162
+ input_ids = audio_tokenizer.tokenize_prompt(prompt)
163
+
164
+ with torch.no_grad():
165
+ output = model.generate(
166
+ input_ids=input_ids,
167
+ temperature=0.1,
168
+ repetition_penalty=1.1,
169
+ max_length=4000,
170
+ )
171
+
172
+ codes = audio_tokenizer.get_codes(output)
173
+ audio = audio_tokenizer.get_audio(codes)
174
 
175
  # Save audio file
176
  torchaudio.save(output_path, audio, sample_rate=24000)
177
 
178
+ # Read and encode as base64
 
179
  with open(output_path, "rb") as audio_file:
180
+ audio_bytes = audio_file.read()
181
+ audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
182
 
183
  # Schedule cleanup
184
  background_tasks.add_task(cleanup_old_files)
185
 
186
+ return TTSResponse(
187
+ audio_base64=audio_base64,
188
+ text=request.text,
189
+ voice=request.voice,
190
+ language=request.language
191
+ )
 
192
 
193
  except Exception as e:
194
  logger.error(f"Error generating audio: {str(e)}")
195
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  if __name__ == "__main__":
198
+ import uvicorn
199
+ uvicorn.run(app, host="0.0.0.0", port=7860)