Hameed13 commited on
Commit
cae635c
·
verified ·
1 Parent(s): 4ab0bc1

Upload main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.py +82 -21
main.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, F
2
  from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
 
5
  import os
6
  import sys
7
  import uuid
@@ -36,28 +37,58 @@ AUDIO_DIR = os.path.join(os.getcwd(), "audio_files")
36
  os.makedirs(MODELS_DIR, exist_ok=True)
37
  os.makedirs(AUDIO_DIR, exist_ok=True)
38
 
 
 
 
 
 
 
 
39
  # Download model files if they don't exist
40
  def ensure_model_files():
41
- config_file = os.path.join(MODELS_DIR, "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
42
- model_file = os.path.join(MODELS_DIR, "wavtokenizer_large_speech_320_24k.ckpt")
43
-
44
- if not os.path.exists(config_file):
45
- print("Downloading config file...")
46
- os.system(f"wget -O {config_file} https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
47
 
48
- if not os.path.exists(model_file):
49
- print("Downloading model file...")
50
- os.system(f"gdown -O {model_file} 1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt")
51
-
52
- return os.path.exists(config_file) and os.path.exists(model_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Initialize YarnGPT
55
  def initialize_yarngpt():
56
  try:
57
  from yarngpt.generate import TextToSpeech
58
  tts = TextToSpeech(
59
- wavtokenizer_config_path=os.path.join(MODELS_DIR, "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"),
60
- wavtokenizer_ckpt_path=os.path.join(MODELS_DIR, "wavtokenizer_large_speech_320_24k.ckpt")
61
  )
62
  return tts
63
  except Exception as e:
@@ -72,7 +103,12 @@ class TextRequest(BaseModel):
72
  # Health check endpoint
73
  @app.get("/")
74
  def read_root():
75
- return {"status": "Nigerian Text-to-Speech API is running"}
 
 
 
 
 
76
 
77
  # Text to speech endpoint
78
  @app.post("/tts")
@@ -80,12 +116,18 @@ async def text_to_speech(request: TextRequest):
80
  try:
81
  # Ensure model files are available
82
  if not ensure_model_files():
83
- raise HTTPException(status_code=500, detail="Failed to download model files")
 
 
 
84
 
85
  # Initialize YarnGPT
86
  tts = initialize_yarngpt()
87
  if not tts:
88
- raise HTTPException(status_code=500, detail="Failed to initialize YarnGPT")
 
 
 
89
 
90
  # Generate audio
91
  audio_file_id = str(uuid.uuid4())
@@ -100,22 +142,41 @@ async def text_to_speech(request: TextRequest):
100
  filename=f"{audio_file_id}.wav"
101
  )
102
  except Exception as e:
103
- raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}")
 
 
 
104
 
105
  # List available files
106
  @app.get("/list_files")
107
  def list_files():
108
- files = os.listdir(AUDIO_DIR)
109
- return {"files": files}
 
 
 
 
 
 
110
 
111
  # Get audio file by id
112
  @app.get("/audio/{file_id}")
113
  def get_audio(file_id: str):
114
- file_path = os.path.join(AUDIO_DIR, f"{file_id}")
115
  if not os.path.exists(file_path):
116
- raise HTTPException(status_code=404, detail="Audio file not found")
 
 
 
117
  return FileResponse(file_path, media_type="audio/wav")
118
 
 
 
 
 
 
 
 
119
  if __name__ == "__main__":
120
  import uvicorn
121
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
2
  from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
+ from huggingface_hub import hf_hub_download
6
  import os
7
  import sys
8
  import uuid
 
37
  os.makedirs(MODELS_DIR, exist_ok=True)
38
  os.makedirs(AUDIO_DIR, exist_ok=True)
39
 
40
# Model configuration: filenames of the WavTokenizer config/checkpoint expected
# in MODELS_DIR, plus the Hugging Face Hub repo the checkpoint is pulled from.
MODEL_CONFIG = {
    "config_file": "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
    "model_file": "wavtokenizer_large_speech_320_24k.ckpt",
    "repo_id": "Hameed13/nigerian-tts-model"
}
46
+
47
# Download model files if they don't exist
def ensure_model_files():
    """Ensure the WavTokenizer config and checkpoint exist in MODELS_DIR.

    The YAML config is copied from the working directory; the checkpoint is
    downloaded from the Hugging Face Hub (requires the HF_TOKEN environment
    variable for a private repo).

    Returns:
        bool: True when both files are present on disk, False otherwise.
        Errors are printed rather than raised so callers can turn a failure
        into an HTTP error of their choosing.
    """
    # BUG FIX: shutil was used below but never imported at module level,
    # which raised NameError the first time the config had to be copied.
    import shutil

    config_file = os.path.join(MODELS_DIR, MODEL_CONFIG["config_file"])
    model_file = os.path.join(MODELS_DIR, MODEL_CONFIG["model_file"])

    try:
        if not os.path.exists(config_file):
            print("Copying config file...")
            # The config is expected alongside main.py in the working directory.
            source_config = MODEL_CONFIG["config_file"]
            if os.path.exists(source_config):
                shutil.copy(source_config, config_file)
            else:
                print(f"Config file not found: {source_config}")
                return False

        if not os.path.exists(model_file):
            print("Downloading model file from Hugging Face Hub...")
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                print("HF_TOKEN environment variable not set")
                return False

            try:
                # hf_hub_download with local_dir places the file directly in
                # MODELS_DIR under its repo filename.
                hf_hub_download(
                    repo_id=MODEL_CONFIG["repo_id"],
                    filename=MODEL_CONFIG["model_file"],
                    local_dir=MODELS_DIR,
                    token=hf_token,
                )
            except Exception as e:
                print(f"Error downloading model file: {e}")
                return False

        return os.path.exists(config_file) and os.path.exists(model_file)
    except Exception as e:
        print(f"Error in ensure_model_files: {e}")
        return False
84
 
85
  # Initialize YarnGPT
86
  def initialize_yarngpt():
87
  try:
88
  from yarngpt.generate import TextToSpeech
89
  tts = TextToSpeech(
90
+ wavtokenizer_config_path=os.path.join(MODELS_DIR, MODEL_CONFIG["config_file"]),
91
+ wavtokenizer_ckpt_path=os.path.join(MODELS_DIR, MODEL_CONFIG["model_file"])
92
  )
93
  return tts
94
  except Exception as e:
 
103
  # Health check endpoint
104
  @app.get("/")
105
  def read_root():
106
+ model_status = "available" if os.path.exists(os.path.join(MODELS_DIR, MODEL_CONFIG["model_file"])) else "not available"
107
+ return {
108
+ "status": "Nigerian Text-to-Speech API is running",
109
+ "model_status": model_status,
110
+ "timestamp": "2025-04-22 04:45:43"
111
+ }
112
 
113
  # Text to speech endpoint
114
  @app.post("/tts")
 
116
  try:
117
  # Ensure model files are available
118
  if not ensure_model_files():
119
+ raise HTTPException(
120
+ status_code=500,
121
+ detail="Failed to download or locate model files. Please check logs for details."
122
+ )
123
 
124
  # Initialize YarnGPT
125
  tts = initialize_yarngpt()
126
  if not tts:
127
+ raise HTTPException(
128
+ status_code=500,
129
+ detail="Failed to initialize YarnGPT. Please check logs for details."
130
+ )
131
 
132
  # Generate audio
133
  audio_file_id = str(uuid.uuid4())
 
142
  filename=f"{audio_file_id}.wav"
143
  )
144
  except Exception as e:
145
+ raise HTTPException(
146
+ status_code=500,
147
+ detail=f"Error generating audio: {str(e)}"
148
+ )
149
 
150
  # List available files
151
  @app.get("/list_files")
152
  def list_files():
153
+ try:
154
+ files = [f for f in os.listdir(AUDIO_DIR) if f.endswith('.wav')]
155
+ return {"files": files}
156
+ except Exception as e:
157
+ raise HTTPException(
158
+ status_code=500,
159
+ detail=f"Error listing files: {str(e)}"
160
+ )
161
 
162
  # Get audio file by id
163
  @app.get("/audio/{file_id}")
164
  def get_audio(file_id: str):
165
+ file_path = os.path.join(AUDIO_DIR, file_id)
166
  if not os.path.exists(file_path):
167
+ raise HTTPException(
168
+ status_code=404,
169
+ detail=f"Audio file not found: {file_id}"
170
+ )
171
  return FileResponse(file_path, media_type="audio/wav")
172
 
173
+ # Add startup event to ensure model is downloaded when the container starts
174
+ @app.on_event("startup")
175
+ async def startup_event():
176
+ print("Starting up Nigerian Text-to-Speech API...")
177
+ if not ensure_model_files():
178
+ print("Warning: Failed to initialize model files during startup")
179
+
180
  if __name__ == "__main__":
181
  import uvicorn
182
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)