Hameed13 committed on
Commit
809bcb3
·
verified ·
1 Parent(s): 6b625c9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +92 -132
main.py CHANGED
@@ -1,47 +1,40 @@
1
  from fastapi import FastAPI, HTTPException, BackgroundTasks
2
- from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  import os
6
- import sys
7
- import time
8
  import torch
9
  import torchaudio
10
  import base64
11
  from transformers import AutoModelForCausalLM
12
- from huggingface_hub import hf_hub_download
13
- import logging
14
  from datetime import datetime, timedelta
15
- import uuid
16
- from typing import Optional
17
 
18
- # Configure logging
19
- logging.basicConfig(
20
- level=logging.INFO,
21
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
22
- )
23
- logger = logging.getLogger(__name__)
24
 
25
- # Initialize FastAPI app
26
- app = FastAPI(
27
- title="Nigerian TTS API",
28
- description="API for Nigerian Text-to-Speech using YarnGPT",
29
- version="1.0.0"
30
- )
31
-
32
- # Add CORS middleware
33
  app.add_middleware(
34
  CORSMiddleware,
35
- allow_origins=["*"],
36
  allow_credentials=True,
37
- allow_methods=["*"],
38
- allow_headers=["*"],
39
  )
40
 
41
- # Constants
42
- MODEL_ID = "saheedniyi/YarnGPT2"
43
- AUDIO_DIR = "audio_files"
44
- os.makedirs(AUDIO_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
45
 
46
  # Available voices and languages
47
  AVAILABLE_VOICES = {
@@ -50,150 +43,117 @@ AVAILABLE_VOICES = {
50
  }
51
  AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
52
 
53
- # Model initialization
54
- def initialize_model():
55
- try:
56
- logger.info("Loading YarnGPT model and tokenizer...")
57
-
58
- # Download necessary files from HuggingFace Hub
59
- wav_tokenizer_config = hf_hub_download(
60
- repo_id=MODEL_ID,
61
- filename="wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
62
- )
63
- wav_tokenizer_model = hf_hub_download(
64
- repo_id=MODEL_ID,
65
- filename="wavtokenizer_large_speech_320_24k.ckpt"
66
- )
67
-
68
- # Import AudioTokenizer here to ensure files are downloaded first
69
- from yarngpt.audiotokenizer import AudioTokenizerV2
70
-
71
- audio_tokenizer = AudioTokenizerV2(
72
- MODEL_ID,
73
- wav_tokenizer_model,
74
- wav_tokenizer_config
75
- )
76
-
77
- model = AutoModelForCausalLM.from_pretrained(
78
- MODEL_ID,
79
- torch_dtype="auto"
80
- ).to(audio_tokenizer.device)
81
-
82
- logger.info("Model loaded successfully!")
83
- return audio_tokenizer, model
84
-
85
- except Exception as e:
86
- logger.error(f"Error initializing model: {str(e)}")
87
- raise
88
-
89
- # Initialize model at startup
90
- audio_tokenizer, model = initialize_model()
91
-
92
- # Pydantic models
93
  class TTSRequest(BaseModel):
94
  text: str
95
  language: str = "english"
96
  voice: str = "idera"
97
 
 
98
  class TTSResponse(BaseModel):
99
- audio_base64: str
 
100
  text: str
101
  voice: str
102
  language: str
103
 
104
- # Cleanup function
105
- def cleanup_old_files(max_age_hours: int = 6):
106
- """Delete audio files older than specified hours"""
107
- try:
108
- now = datetime.now()
109
- for filename in os.listdir(AUDIO_DIR):
110
- if not filename.endswith(".wav"):
111
- continue
112
-
113
- file_path = os.path.join(AUDIO_DIR, filename)
114
- file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
115
-
116
- if now - file_mod_time > timedelta(hours=max_age_hours):
117
- os.remove(file_path)
118
- logger.info(f"Deleted old audio file: {filename}")
119
- except Exception as e:
120
- logger.error(f"Error cleaning up files: {str(e)}")
121
-
122
- # API endpoints
123
  @app.get("/")
124
  async def root():
125
- """Health check endpoint"""
126
  return {
127
- "status": "healthy",
128
- "model": MODEL_ID,
129
  "available_languages": AVAILABLE_LANGUAGES,
130
  "available_voices": AVAILABLE_VOICES
131
  }
132
 
133
  @app.post("/tts", response_model=TTSResponse)
134
  async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
135
- """Generate Nigerian-accented speech from text"""
136
-
137
  # Validate inputs
138
  if request.language not in AVAILABLE_LANGUAGES:
139
- raise HTTPException(
140
- status_code=400,
141
- detail=f"Language must be one of {AVAILABLE_LANGUAGES}"
142
- )
143
-
144
  all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
145
  if request.voice not in all_voices:
146
- raise HTTPException(
147
- status_code=400,
148
- detail=f"Voice must be one of {all_voices}"
149
- )
150
-
 
 
151
  try:
152
- # Generate unique filename
153
- audio_id = str(uuid.uuid4())
154
- output_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
155
-
156
- # Generate audio
157
- prompt = audio_tokenizer.create_prompt(
158
- request.text,
159
- lang=request.language,
160
- speaker_name=request.voice
161
- )
162
  input_ids = audio_tokenizer.tokenize_prompt(prompt)
163
-
164
- with torch.no_grad():
165
- output = model.generate(
166
- input_ids=input_ids,
167
- temperature=0.1,
168
- repetition_penalty=1.1,
169
- max_length=4000,
170
- )
171
-
172
  codes = audio_tokenizer.get_codes(output)
173
  audio = audio_tokenizer.get_audio(codes)
174
-
175
  # Save audio file
176
  torchaudio.save(output_path, audio, sample_rate=24000)
177
-
178
- # Read and encode as base64
179
  with open(output_path, "rb") as audio_file:
180
  audio_bytes = audio_file.read()
181
  audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
182
-
183
- # Schedule cleanup
184
  background_tasks.add_task(cleanup_old_files)
185
-
186
  return TTSResponse(
187
  audio_base64=audio_base64,
 
188
  text=request.text,
189
  voice=request.voice,
190
  language=request.language
191
  )
192
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  except Exception as e:
194
- logger.error(f"Error generating audio: {str(e)}")
195
- raise HTTPException(status_code=500, detail=str(e))
196
 
 
197
  if __name__ == "__main__":
198
- import uvicorn
199
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, HTTPException, BackgroundTasks
2
+ from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  import os
6
+ import uuid
 
7
  import torch
8
  import torchaudio
9
  import base64
10
  from transformers import AutoModelForCausalLM
11
+ from yarngpt.audiotokenizer import AudioTokenizerV2
12
+ import uvicorn
13
  from datetime import datetime, timedelta
 
 
14
 
15
+ app = FastAPI(title="Nigerian TTS API")
 
 
 
 
 
16
 
17
+ # Add CORS middleware to allow requests from any origin
 
 
 
 
 
 
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
+ allow_origins=["*"], # Allows all origins
21
  allow_credentials=True,
22
+ allow_methods=["*"], # Allows all methods
23
+ allow_headers=["*"], # Allows all headers
24
  )
25
 
26
+ # Model configuration paths
27
+ tokenizer_path = "saheedniyi/YarnGPT2"
28
+ wav_tokenizer_config_path = "./wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
29
+ wav_tokenizer_model_path = "./wavtokenizer_large_speech_320_24k.ckpt"
30
+
31
+ # Initialize model (only once when the API starts)
32
+ print("Loading YarnGPT model and tokenizer...")
33
+ audio_tokenizer = AudioTokenizerV2(
34
+ tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
35
+ )
36
+ model = AutoModelForCausalLM.from_pretrained(tokenizer_path, torch_dtype="auto").to(audio_tokenizer.device)
37
+ print("Model loaded successfully!")
38
 
39
  # Available voices and languages
40
  AVAILABLE_VOICES = {
 
43
  }
44
  AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
45
 
46
+ # Input validation model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class TTSRequest(BaseModel):
48
  text: str
49
  language: str = "english"
50
  voice: str = "idera"
51
 
52
+ # Output model with base64-encoded audio
53
  class TTSResponse(BaseModel):
54
+ audio_base64: str # Base64-encoded audio data
55
+ audio_url: str # Keep for backward compatibility
56
  text: str
57
  voice: str
58
  language: str
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @app.get("/")
61
  async def root():
62
+ """API health check and info"""
63
  return {
64
+ "status": "ok",
65
+ "message": "Nigerian TTS API is running",
66
  "available_languages": AVAILABLE_LANGUAGES,
67
  "available_voices": AVAILABLE_VOICES
68
  }
69
 
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Convert text to Nigerian-accented speech.

    The request is validated, the prompt is run through the model, the
    resulting waveform is written to ``audio_files/<uuid>.wav`` and the same
    bytes are returned base64-encoded in the response.

    Args:
        request: Text to synthesise plus the desired language and voice.
        background_tasks: FastAPI task queue; used to schedule the sweep of
            old audio files after the response has been sent.

    Returns:
        A ``TTSResponse`` carrying the base64 audio, the ``/audio/...`` URL of
        the saved file and an echo of the request fields.

    Raises:
        HTTPException: 400 for empty text or an unsupported language/voice,
            500 if prompt creation, generation or encoding fails.
    """
    # Validate inputs before doing any expensive work.
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")

    if request.language not in AVAILABLE_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")

    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if request.voice not in all_voices:
        raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")

    # Generate unique filename
    audio_id = str(uuid.uuid4())
    output_path = f"audio_files/{audio_id}.wav"
    os.makedirs("audio_files", exist_ok=True)

    # NOTE(review): generation is synchronous and runs inside an ``async def``
    # handler, so it occupies the event loop for its whole duration. Confirm
    # that this is acceptable for the expected request volume.
    try:
        # Create prompt and generate audio
        prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice)
        input_ids = audio_tokenizer.tokenize_prompt(prompt)

        # Inference only: disable autograd so no graph is retained.
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                temperature=0.1,
                repetition_penalty=1.1,
                max_length=4000,
            )
            codes = audio_tokenizer.get_codes(output)
            audio = audio_tokenizer.get_audio(codes)

        # Save audio file
        torchaudio.save(output_path, audio, sample_rate=24000)

        # Read the file back and encode as base64
        with open(output_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        # Sweep old files once the response has been sent
        background_tasks.add_task(cleanup_old_files)

        return TTSResponse(
            audio_base64=audio_base64,
            audio_url=f"/audio/{audio_id}.wav",  # Keep for compatibility
            text=request.text,
            voice=request.voice,
            language=request.language
        )

    except Exception as e:
        # Do not leave a partially written file behind for /audio to serve.
        try:
            if os.path.isfile(output_path):
                os.remove(output_path)
        except OSError:
            pass  # best effort; the periodic sweep will retry later
        raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") from e
123
+
124
+ # File serving endpoint (keep for backward compatibility)
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a previously generated WAV file from ``audio_files/``.

    Args:
        filename: Bare file name (no directory components) of the audio file.

    Returns:
        A ``FileResponse`` streaming the file with media type ``audio/wav``.

    Raises:
        HTTPException: 404 if the name contains path components, or does not
            refer to an existing regular file inside ``audio_files/``.
    """
    # The name comes straight from the URL: accept only a plain file name so
    # the lookup can never leave the audio directory.
    is_plain_name = (
        filename not in ("", ".", "..")
        and os.path.basename(filename) == filename
        and "\\" not in filename
    )
    if not is_plain_name:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = f"audio_files/{filename}"
    # isfile (not exists): a directory must not be handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/wav")
131
+
132
+ # Cleanup function to remove old files
def cleanup_old_files(max_age_hours: float = 6, audio_dir: str = "audio_files") -> None:
    """Delete ``.wav`` files in ``audio_dir`` older than ``max_age_hours``.

    Age is measured from each file's modification time. Files that do not end
    in ``.wav`` are never touched. The function is best-effort: problems are
    printed rather than raised, and a failure on one file does not stop the
    remaining files from being processed.

    Args:
        max_age_hours: Files modified longer ago than this are removed.
        audio_dir: Directory to sweep; a missing directory is a no-op.
    """
    if not os.path.isdir(audio_dir):
        return

    try:
        filenames = os.listdir(audio_dir)
    except OSError as e:
        print(f"Error cleaning up old files: {e}")
        return

    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    for filename in filenames:
        if not filename.endswith(".wav"):
            continue

        file_path = os.path.join(audio_dir, filename)
        try:
            file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mod_time < cutoff:
                os.remove(file_path)
                print(f"Deleted old audio file: {filename}")
        except FileNotFoundError:
            # Already removed, e.g. by a cleanup task scheduled by another request.
            continue
        except OSError as e:
            # Report and keep going so one bad file cannot block the sweep.
            print(f"Error cleaning up old files: {e}")
 
155
 
156
+ # For Hugging Face Spaces, we'll use the default port 7860
157
  if __name__ == "__main__":
158
+ print("Starting Nigerian TTS API server...")
159
  uvicorn.run(app, host="0.0.0.0", port=7860)