Hameed13 committed on
Commit
71f917b
·
verified ·
1 Parent(s): 109c3b2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +135 -173
main.py CHANGED
@@ -1,105 +1,29 @@
1
- import os
2
- from fastapi import FastAPI, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
- import sys
6
- import time
7
- import numpy as np
8
  import torch
 
 
 
 
 
9
  import logging
10
- from pathlib import Path
11
- from huggingface_hub import hf_hub_download
12
- from datetime import datetime
13
 
14
  # Setup logging
15
  logging.basicConfig(level=logging.INFO,
16
  format='%(asctime)s - %(levelname)s - %(message)s')
17
  logger = logging.getLogger(__name__)
18
 
19
- # Current timestamp and user
20
- CURRENT_TIMESTAMP = "2025-05-21 01:58:55"
21
  CURRENT_USER = "Abdulhameed556"
22
 
23
- # Set cache directory to a writable location
24
- os.environ['HF_HOME'] = '/code/cache'
25
- os.environ['TRANSFORMERS_CACHE'] = '/code/cache/transformers'
26
-
27
- # Define all required directories
28
- CACHE_DIR = '/code/cache'
29
- MODEL_DIR = '/code/models'
30
- AUDIO_DIR = '/code/audio_files'
31
-
32
- # Create directories if they don't exist
33
- for directory in [CACHE_DIR, MODEL_DIR, AUDIO_DIR]:
34
- os.makedirs(directory, exist_ok=True)
35
-
36
- # Model configuration
37
- MODEL_REPO = "Hameed13/News_Podcast_Model"
38
- MODEL_FILENAME = "model.ckpt"
39
-
40
- # Model config
41
- MODEL_CONFIG = {
42
- "model_type": "speech_to_text_2",
43
- "architectures": ["Speech2Text2ForConditionalGeneration"],
44
- "activation_dropout": 0.1,
45
- "activation_function": "relu",
46
- "attention_dropout": 0.1,
47
- "d_model": 512,
48
- "decoder_attention_heads": 8,
49
- "decoder_ffn_dim": 2048,
50
- "decoder_layers": 6,
51
- "dropout": 0.1,
52
- "encoder_attention_heads": 8,
53
- "encoder_ffn_dim": 2048,
54
- "encoder_layers": 6,
55
- "init_std": 0.02,
56
- "max_speech_positions": 4000,
57
- "max_text_positions": 1024,
58
- "num_conv_layers": 2,
59
- "num_hidden_layers": 12,
60
- "speech_vocab_size": 4096,
61
- "vocab_size": 50265,
62
- "use_cache": True,
63
- "tie_word_embeddings": True,
64
- "is_encoder_decoder": True,
65
- "pad_token_id": 1,
66
- "bos_token_id": 0,
67
- "eos_token_id": 2,
68
- "_name_or_path": MODEL_REPO,
69
- "model_creation_date": CURRENT_TIMESTAMP,
70
- "model_creator": CURRENT_USER
71
- }
72
-
73
- # Download model from Hub
74
- try:
75
- logger.info("Downloading model from Hugging Face Hub")
76
- MODEL_CHECKPOINT = hf_hub_download(
77
- repo_id=MODEL_REPO,
78
- filename=MODEL_FILENAME,
79
- token=os.getenv('HF_TOKEN'),
80
- cache_dir=CACHE_DIR
81
- )
82
- logger.info(f"Model downloaded successfully to: {MODEL_CHECKPOINT}")
83
- except Exception as e:
84
- logger.error(f"Failed to download model: {e}")
85
- raise
86
-
87
- # Import TTS modules
88
- try:
89
- from yarngpt.generate import TextToSpeech
90
- logger.info("Successfully imported yarngpt modules")
91
- except ImportError as e:
92
- logger.error(f"Failed to import YarnGPT modules: {e}")
93
- raise
94
-
95
- # Create FastAPI app
96
- app = FastAPI(
97
- title="Nigerian Text-to-Speech API",
98
- description="A text-to-speech API for Nigerian English",
99
- version="1.0.0"
100
- )
101
 
102
- # Configure CORS
103
  app.add_middleware(
104
  CORSMiddleware,
105
  allow_origins=["*"],
@@ -108,107 +32,145 @@ app.add_middleware(
108
  allow_headers=["*"],
109
  )
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  class TTSRequest(BaseModel):
112
  text: str
113
- temperature: float = 0.2
114
- top_p: float = 0.7
115
- top_k: int = 50
116
- speed: float = 1.0
117
- seed: int = 42
118
 
119
  class TTSResponse(BaseModel):
120
- audio: str
121
- generation_time: float
122
- timestamp: str
123
- user: str
 
124
 
125
  @app.get("/")
126
- def read_root():
127
- """Root endpoint returning API status"""
128
  return {
129
- "message": "Nigerian Text-to-Speech API is running",
130
- "status": "active",
 
 
 
131
  "timestamp": CURRENT_TIMESTAMP,
132
- "user": CURRENT_USER,
133
- "model": MODEL_REPO,
134
- "version": "1.0.0"
135
  }
136
 
137
  @app.post("/tts", response_model=TTSResponse)
138
- async def text_to_speech(request: TTSRequest):
139
- """Generate speech from text"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  try:
141
- logger.info(f"Processing TTS request: {request.text[:50]}...")
142
-
143
- # Set random seed if provided
144
- if request.seed is not None:
145
- torch.manual_seed(request.seed)
146
- np.random.seed(request.seed)
147
-
148
- # Initialize TTS with model
149
- tts = TextToSpeech(
150
- MODEL_REPO,
151
- processor_name_or_path=MODEL_REPO
152
- )
153
-
154
- # Generate audio
155
- start_time = time.time()
156
- audio = tts.tts(
157
- request.text,
158
- speed=request.speed
159
  )
160
-
161
- # Convert audio to base64
162
- import base64
163
- import io
164
- audio_io = io.BytesIO()
165
- import scipy.io.wavfile as wav
166
- wav.write(audio_io, 24000, audio)
167
- audio_io.seek(0)
168
- audio_base64 = base64.b64encode(audio_io.read()).decode('utf-8')
169
-
170
- generation_time = time.time() - start_time
171
- logger.info(f"Generated audio in {generation_time:.2f} seconds")
172
-
 
 
173
  return TTSResponse(
174
- audio=audio_base64,
175
- generation_time=generation_time,
176
- timestamp=CURRENT_TIMESTAMP,
177
- user=CURRENT_USER
 
178
  )
179
-
180
  except Exception as e:
181
- logger.error(f"Error generating speech: {str(e)}", exc_info=True)
182
- raise HTTPException(
183
- status_code=500,
184
- detail=f"Error generating speech: {str(e)}"
185
- )
 
 
 
 
 
 
 
 
 
 
186
 
187
- @app.get("/health")
188
- def health_check():
189
- """Health check endpoint"""
190
- return {
191
- "status": "ok",
192
- "models_loaded": True,
193
- "timestamp": CURRENT_TIMESTAMP,
194
- "user": CURRENT_USER,
195
- "model": {
196
- "repo": MODEL_REPO,
197
- "checkpoint": MODEL_CHECKPOINT,
198
- "cache_dir": CACHE_DIR
199
- },
200
- "system": {
201
- "cuda_available": torch.cuda.is_available(),
202
- "device": "cuda" if torch.cuda.is_available() else "cpu",
203
- "python_version": sys.version
204
- }
205
- }
206
 
207
  if __name__ == "__main__":
208
  import uvicorn
209
- uvicorn.run(
210
- app,
211
- host="0.0.0.0",
212
- port=7860,
213
- log_level="info"
214
- )
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
2
+ from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
+ import os
6
+ import uuid
 
7
  import torch
8
+ import torchaudio
9
+ import base64
10
+ from transformers import AutoModelForCausalLM
11
+ from yarngpt.audiotokenizer import AudioTokenizerV2
12
+ from datetime import datetime, timedelta
13
  import logging
 
 
 
14
 
15
  # Setup logging
16
  logging.basicConfig(level=logging.INFO,
17
  format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Constants
21
+ CURRENT_TIMESTAMP = "2025-05-21 02:39:34"
22
  CURRENT_USER = "Abdulhameed556"
23
 
24
+ app = FastAPI(title="Nigerian TTS API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Add CORS middleware
27
  app.add_middleware(
28
  CORSMiddleware,
29
  allow_origins=["*"],
 
32
  allow_headers=["*"],
33
  )
34
 
35
+ # Model configuration - Using your Hugging Face model
36
+ model_path = "Hameed13/News_Podcast_Model"
37
+ tokenizer_path = "saheedniyi/YarnGPT2"
38
+ wav_tokenizer_config_path = "/code/models/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
39
+ wav_tokenizer_model_path = "/code/models/wavtokenizer_large_speech_320_24k.ckpt"
40
+
41
+ # Initialize model
42
+ logger.info("Loading YarnGPT model and tokenizer...")
43
+ try:
44
+ audio_tokenizer = AudioTokenizerV2(
45
+ tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
46
+ )
47
+ model = AutoModelForCausalLM.from_pretrained(
48
+ model_path,
49
+ torch_dtype="auto",
50
+ token=os.getenv('HF_TOKEN') # In case the model requires authentication
51
+ ).to(audio_tokenizer.device)
52
+ logger.info("Model loaded successfully!")
53
+ except Exception as e:
54
+ logger.error(f"Error loading model: {e}")
55
+ raise
56
+
57
+ # Available voices and languages
58
+ AVAILABLE_VOICES = {
59
+ "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
60
+ "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
61
+ }
62
+ AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
63
+
64
  class TTSRequest(BaseModel):
65
  text: str
66
+ language: str = "english"
67
+ voice: str = "idera"
 
 
 
68
 
69
  class TTSResponse(BaseModel):
70
+ audio_base64: str
71
+ audio_url: str
72
+ text: str
73
+ voice: str
74
+ language: str
75
 
76
@app.get("/")
async def root():
    """Report that the API is up, plus the languages, voices and model it serves."""
    # Liveness fields first, then the capability listing, then provenance.
    payload = {"status": "ok", "message": "Nigerian TTS API is running"}
    payload["available_languages"] = AVAILABLE_LANGUAGES
    payload["available_voices"] = AVAILABLE_VOICES
    payload["model_path"] = model_path
    # CURRENT_TIMESTAMP is a module-level constant, not the time of the request.
    payload["timestamp"] = CURRENT_TIMESTAMP
    payload["user"] = CURRENT_USER
    return payload
88
 
89
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Convert text to Nigerian-accented speech.

    The generated audio is written to ``audio_files/<uuid>.wav`` and returned
    both inline (base64) and as a URL served by the ``/audio/{filename}`` route.

    Args:
        request: Text to speak plus the requested language and voice.
        background_tasks: Used to schedule ``cleanup_old_files`` after the
            response has been sent.

    Raises:
        HTTPException: 400 if the text is blank or the language/voice is not
            in the advertised lists; 500 if generation or encoding fails.
    """
    # Validate inputs before touching the model.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")

    if request.language not in AVAILABLE_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")

    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if request.voice not in all_voices:
        raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")

    # Generate unique filename
    audio_id = str(uuid.uuid4())
    output_path = f"audio_files/{audio_id}.wav"
    os.makedirs("audio_files", exist_ok=True)

    try:
        # Create prompt and generate audio
        prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice)
        input_ids = audio_tokenizer.tokenize_prompt(prompt)

        output = model.generate(
            input_ids=input_ids,
            temperature=0.1,
            repetition_penalty=1.1,
            max_length=4000,
        )

        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Save audio file
        torchaudio.save(output_path, audio, sample_rate=24000)

        # Read the file back and encode as base64 for the inline payload
        with open(output_path, "rb") as audio_file:
            audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
    except Exception as e:
        # logger.exception keeps the traceback in the server log; the client
        # only receives the short message.
        logger.exception(f"Error generating audio: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") from e

    # Prune stale files once the response has been sent.
    background_tasks.add_task(cleanup_old_files)

    return TTSResponse(
        audio_base64=audio_base64,
        audio_url=f"/audio/{audio_id}.wav",
        text=request.text,
        voice=request.voice,
        language=request.language,
    )
143
+
144
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a previously generated WAV file from ``audio_files/``.

    Args:
        filename: File name taken from the URL path.

    Raises:
        HTTPException: 404 if ``filename`` is not a bare file name or no
            regular file with that name exists in ``audio_files/``.
    """
    # The name comes straight from the URL: accept only a bare file name so the
    # lookup can never resolve outside audio_files/.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = f"audio_files/{filename}"
    # isfile rather than exists: a directory would pass an existence check and
    # only fail later, inside FileResponse, as a server error.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/wav")
150
+
151
def cleanup_old_files(audio_dir="audio_files", max_age_hours=6):
    """Delete stale ``.wav`` files to manage disk space.

    Best-effort: failures are logged and never raised. A failure on one file
    does not stop the remaining files from being processed.

    Args:
        audio_dir: Directory to sweep. Defaults to ``"audio_files"``.
        max_age_hours: Files whose modification time is older than this many
            hours are removed. Defaults to 6.

    Returns:
        None.
    """
    cutoff = datetime.now() - timedelta(hours=max_age_hours)

    try:
        filenames = os.listdir(audio_dir)
    except FileNotFoundError:
        # Nothing has been generated yet.
        return
    except OSError as e:
        logger.error(f"Error cleaning up old files: {e}")
        return

    for filename in filenames:
        if not filename.endswith(".wav"):
            continue

        file_path = os.path.join(audio_dir, filename)
        try:
            file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mod_time < cutoff:
                os.remove(file_path)
                logger.info(f"Deleted old audio file: {filename}")
        except FileNotFoundError:
            # Another cleanup task removed the file between listdir and here.
            continue
        except (OSError, OverflowError, ValueError) as e:
            # Log and keep going so one bad file cannot block the sweep.
            logger.error(f"Error cleaning up old files: {e}")
 
 
 
173
 
174
  if __name__ == "__main__":
175
  import uvicorn
176
+ uvicorn.run(app, host="0.0.0.0", port=7860)