Hameed13 committed on
Commit
e0814ef
·
verified ·
1 Parent(s): 3e6d1d6

Upload main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.py +224 -1
main.py CHANGED
@@ -1 +1,224 @@
1
- [Previous main.py content...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import importlib.util
import os
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from io import BytesIO

import torch
import torchaudio
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
17
+
18
+ # Add YarnGPT to path
19
+ sys.path.append(os.path.join(os.getcwd(), "yarngpt"))
20
+
21
+ # Initialize FastAPI
22
+ app = FastAPI(
23
+ title="Nigerian Text-to-Speech API",
24
+ version="1.0.0",
25
+ description="A FastAPI service for Nigerian Text-to-Speech generation"
26
+ )
27
+
28
+ # Configure CORS
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"],
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ # Models directory
38
+ MODELS_DIR = os.path.join(os.getcwd(), "models")
39
+ AUDIO_DIR = os.path.join(os.getcwd(), "audio_files")
40
+
41
+ # Ensure directories exist
42
+ os.makedirs(MODELS_DIR, exist_ok=True)
43
+ os.makedirs(AUDIO_DIR, exist_ok=True)
44
+
45
+ # Model configuration
46
+ MODEL_CONFIG = {
47
+ "config_file": "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
48
+ "model_file": "wavtokenizer_large_speech_320_24k.ckpt",
49
+ "repo_id": "Hameed13/nigerian-tts-model"
50
+ }
51
+
def get_current_timestamp() -> str:
    """Return the current UTC time formatted as "YYYY-MM-DD HH:MM:SS".

    Used as a prefix for every log line printed by this module.
    """
    # datetime.utcnow() is deprecated (Python 3.12+); an aware UTC datetime
    # formats to the identical string, so callers see no change.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
54
+
# Download model files if they don't exist
def _download_if_missing(filename: str, label: str, token: str) -> bool:
    """Fetch *filename* from the model repo into MODELS_DIR unless present.

    *label* is the human-readable name used in log messages ("config file" /
    "model file"). Returns True when the file exists on disk afterwards.
    """
    local_path = os.path.join(MODELS_DIR, filename)
    if not os.path.exists(local_path):
        print(f"[{get_current_timestamp()}] Downloading {label} from Hugging Face Hub...")
        try:
            hf_hub_download(
                repo_id=MODEL_CONFIG["repo_id"],
                filename=filename,
                local_dir=MODELS_DIR,
                token=token
            )
        except Exception as e:
            print(f"[{get_current_timestamp()}] Error downloading {label}: {e}")
            return False
    return os.path.exists(local_path)


def ensure_model_files() -> bool:
    """Ensure both model artifacts exist locally, downloading them if needed.

    Requires the HF_TOKEN environment variable (the Hub repo needs auth).
    Returns True only when both the config and checkpoint files are on disk;
    never raises — all failures are logged and reported as False.
    """
    try:
        # First check for HF_TOKEN
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print(f"[{get_current_timestamp()}] HF_TOKEN environment variable not set")
            return False

        # Short-circuits: if the config download fails we do not attempt the
        # (much larger) checkpoint download, matching the original early return.
        return (
            _download_if_missing(MODEL_CONFIG["config_file"], "config file", hf_token)
            and _download_if_missing(MODEL_CONFIG["model_file"], "model file", hf_token)
        )
    except Exception as e:
        print(f"[{get_current_timestamp()}] Error in ensure_model_files: {e}")
        return False
99
+
# Initialize YarnGPT
def initialize_yarngpt():
    """Build and return a YarnGPT TextToSpeech engine, or None if setup fails."""
    try:
        from yarngpt.generate import TextToSpeech

        cfg_path = os.path.join(MODELS_DIR, MODEL_CONFIG["config_file"])
        ckpt_path = os.path.join(MODELS_DIR, MODEL_CONFIG["model_file"])
        engine = TextToSpeech(
            wavtokenizer_config_path=cfg_path,
            wavtokenizer_ckpt_path=ckpt_path
        )
    except Exception as e:
        # Import or construction failure is logged and surfaced as None so the
        # caller can translate it into an HTTP error.
        print(f"[{get_current_timestamp()}] Error initializing YarnGPT: {e}")
        return None
    return engine
112
+
# Request models
class TextRequest(BaseModel):
    # Text to synthesize into speech.
    text: str
    # Requested accent/voice. NOTE(review): currently unused by the /tts
    # handler, which only reads `text` — confirm whether it should be wired in.
    accent: str = "nigerian"
117
+
# Health check endpoint
@app.get("/")
def read_root():
    """Health/status probe.

    Reports whether the model artifacts exist on disk and whether an HF token
    is configured, plus service metadata and the current UTC timestamp.
    """
    model_file = os.path.join(MODELS_DIR, MODEL_CONFIG["model_file"])
    config_file = os.path.join(MODELS_DIR, MODEL_CONFIG["config_file"])
    hf_token = os.environ.get("HF_TOKEN")

    status = {
        "status": "Nigerian Text-to-Speech API is running",
        "model_status": {
            "model_file_exists": os.path.exists(model_file),
            "config_file_exists": os.path.exists(config_file),
            "models_dir": MODELS_DIR,
            "hf_token_set": bool(hf_token),
            # Key kept for backward compatibility; the original expression
            # `bool(hf_token and len(hf_token) > 0)` is equivalent to this.
            "hf_token_valid": bool(hf_token)
        },
        # Bug fix: this was a hard-coded literal ("2025-04-22 13:39:07") and
        # never reflected the actual request time.
        "timestamp": get_current_timestamp(),
        "version": "1.0.0",
        "author": "Abdulhameed556"
    }
    return JSONResponse(content=status)
139
+
# Text to speech endpoint
@app.post("/tts")
async def text_to_speech(request: TextRequest):
    """Synthesize the request text to a WAV file and return it.

    Raises HTTP 500 when the token, model files, or TTS engine are
    unavailable, or when generation itself fails.

    NOTE(review): request.accent is currently ignored — only request.text is
    used for generation.
    """
    try:
        # Check HF_TOKEN first
        if not os.environ.get("HF_TOKEN"):
            raise HTTPException(
                status_code=500,
                detail="HF_TOKEN environment variable not set. Please configure your Hugging Face token."
            )

        # Ensure model files are available
        if not ensure_model_files():
            raise HTTPException(
                status_code=500,
                detail="Failed to download or locate model files. Please check logs for details."
            )

        # Initialize YarnGPT
        tts = initialize_yarngpt()
        if not tts:
            raise HTTPException(
                status_code=500,
                detail="Failed to initialize YarnGPT. Please check logs for details."
            )

        # Generate audio into a uniquely named file so concurrent requests
        # never collide on the output path.
        audio_file_id = str(uuid.uuid4())
        output_path = os.path.join(AUDIO_DIR, f"{audio_file_id}.wav")

        print(f"[{get_current_timestamp()}] Generating audio for text: {request.text[:50]}...")
        tts.read_text(request.text, output_path)

        # Return the audio file
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename=f"{audio_file_id}.wav"
        )
    except HTTPException:
        # Bug fix: the deliberate HTTPExceptions raised above were previously
        # swallowed by the generic handler below and re-wrapped with a
        # different, less specific detail message. Re-raise them unchanged.
        raise
    except Exception as e:
        print(f"[{get_current_timestamp()}] Error in text_to_speech: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error generating audio: {str(e)}"
        )
185
+
# List available files
@app.get("/list_files")
def list_files():
    """Return the names of every generated .wav file, with a count and timestamp."""
    try:
        wav_names = [entry for entry in os.listdir(AUDIO_DIR) if entry.endswith('.wav')]
        payload = {
            "files": wav_names,
            "count": len(wav_names),
            "timestamp": get_current_timestamp()
        }
        return payload
    except Exception as e:
        # Surface filesystem problems (e.g. missing directory) as a 500.
        raise HTTPException(
            status_code=500,
            detail=f"Error listing files: {str(e)}"
        )
201
+
202
+ # Get audio file by id
203
+ @app.get("/audio/{file_id}")
204
+ def get_audio(file_id: str):
205
+ file_path = os.path.join(AUDIO_DIR, file_id)
206
+ if not os.path.exists(file_path):
207
+ raise HTTPException(
208
+ status_code=404,
209
+ detail=f"Audio file not found: {file_id}"
210
+ )
211
+ return FileResponse(file_path, media_type="audio/wav")
212
+
# Add startup event to ensure model is downloaded when the container starts
@app.on_event("startup")
async def startup_event():
    """On boot: warn if HF_TOKEN is missing and pre-fetch the model files."""
    print(f"[{get_current_timestamp()}] Starting up Nigerian Text-to-Speech API...")
    token_missing = not os.environ.get("HF_TOKEN")
    if token_missing:
        print(f"[{get_current_timestamp()}] Warning: HF_TOKEN environment variable not set")
    files_ready = ensure_model_files()
    if not files_ready:
        # Non-fatal: the /tts endpoint retries the download per request.
        print(f"[{get_current_timestamp()}] Warning: Failed to initialize model files during startup")
221
+
if __name__ == "__main__":
    # Local/dev entry point; reload=True enables auto-restart on code changes.
    import uvicorn
    # Port 7860 — presumably chosen to match the Hugging Face Spaces
    # convention; confirm against the deployment config.
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)