Hameed13 committed on
Commit
1f406e0
·
verified ·
1 Parent(s): cf42333

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +94 -139
main.py CHANGED
@@ -1,159 +1,114 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
- from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
- import os
6
- import uuid
 
7
  import torch
8
- import torchaudio
9
- import base64
10
- from transformers import AutoModelForCausalLM
11
- from yarngpt.audiotokenizer import AudioTokenizerV2
12
- import uvicorn
13
- from datetime import datetime, timedelta
14
-
15
- app = FastAPI(title="Nigerian TTS API")
16
-
17
- # Add CORS middleware to allow requests from any origin
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
- allow_origins=["*"], # Allows all origins
21
  allow_credentials=True,
22
- allow_methods=["*"], # Allows all methods
23
- allow_headers=["*"], # Allows all headers
24
  )
25
 
26
- # Model configuration paths
27
- tokenizer_path = "saheedniyi/YarnGPT2"
28
- wav_tokenizer_config_path = "./wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
29
- wav_tokenizer_model_path = "./wavtokenizer_large_speech_320_24k.ckpt"
30
-
31
- # Initialize model (only once when the API starts)
32
- print("Loading YarnGPT model and tokenizer...")
33
- audio_tokenizer = AudioTokenizerV2(
34
- tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
35
- )
36
- model = AutoModelForCausalLM.from_pretrained(tokenizer_path, torch_dtype="auto").to(audio_tokenizer.device)
37
- print("Model loaded successfully!")
38
-
39
- # Available voices and languages
40
- AVAILABLE_VOICES = {
41
- "female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
42
- "male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
43
- }
44
- AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
45
-
46
- # Input validation model
47
  class TTSRequest(BaseModel):
48
  text: str
49
- language: str = "english"
50
- voice: str = "idera"
51
-
52
- # Output model with base64-encoded audio
53
- class TTSResponse(BaseModel):
54
- audio_base64: str # Base64-encoded audio data
55
- audio_url: str # Keep for backward compatibility
56
- text: str
57
- voice: str
58
- language: str
59
 
60
  @app.get("/")
61
- async def root():
62
- """API health check and info"""
63
- return {
64
- "status": "ok",
65
- "message": "Nigerian TTS API is running",
66
- "available_languages": AVAILABLE_LANGUAGES,
67
- "available_voices": AVAILABLE_VOICES
68
- }
69
-
70
- @app.post("/tts", response_model=TTSResponse)
71
- async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
72
- """Convert text to Nigerian-accented speech"""
73
-
74
- # Validate inputs
75
- if request.language not in AVAILABLE_LANGUAGES:
76
- raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")
77
-
78
- all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
79
- if request.voice not in all_voices:
80
- raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")
81
-
82
- # Generate unique filename
83
- audio_id = str(uuid.uuid4())
84
- output_path = f"audio_files/{audio_id}.wav"
85
- os.makedirs("audio_files", exist_ok=True)
86
 
 
 
87
  try:
88
- # Create prompt and generate audio
89
- prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice)
90
- input_ids = audio_tokenizer.tokenize_prompt(prompt)
91
-
92
- output = model.generate(
93
- input_ids=input_ids,
94
- temperature=0.1,
95
- repetition_penalty=1.1,
96
- max_length=4000,
 
 
 
 
 
 
 
 
97
  )
98
-
99
- codes = audio_tokenizer.get_codes(output)
100
- audio = audio_tokenizer.get_audio(codes)
101
-
102
- # Save audio file
103
- torchaudio.save(output_path, audio, sample_rate=24000)
104
-
105
- # Read the file and encode as base64
106
- with open(output_path, "rb") as audio_file:
107
- audio_bytes = audio_file.read()
108
- audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
109
-
110
- # Clean up old files after a while
111
- background_tasks.add_task(cleanup_old_files)
112
-
113
- return TTSResponse(
114
- audio_base64=audio_base64,
115
- audio_url=f"/audio/{audio_id}.wav", # Keep for compatibility
116
- text=request.text,
117
- voice=request.voice,
118
- language=request.language
119
- )
120
-
121
  except Exception as e:
122
- raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}")
 
123
 
124
- # File serving endpoint (keep for backward compatibility)
125
- @app.get("/audio/{filename}")
126
- async def get_audio(filename: str):
127
- file_path = f"audio_files/{filename}"
128
- if not os.path.exists(file_path):
129
- raise HTTPException(status_code=404, detail="Audio file not found")
130
- return FileResponse(file_path, media_type="audio/wav")
131
-
132
- # Cleanup function to remove old files
133
- def cleanup_old_files():
134
- """Delete audio files older than 6 hours to manage disk space"""
135
- try:
136
- now = datetime.now()
137
- audio_dir = "audio_files"
138
-
139
- if not os.path.exists(audio_dir):
140
- return
141
-
142
- for filename in os.listdir(audio_dir):
143
- if not filename.endswith(".wav"):
144
- continue
145
-
146
- file_path = os.path.join(audio_dir, filename)
147
- file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
148
-
149
- # Delete files older than 6 hours
150
- if now - file_mod_time > timedelta(hours=6):
151
- os.remove(file_path)
152
- print(f"Deleted old audio file: {filename}")
153
- except Exception as e:
154
- print(f"Error cleaning up old files: {e}")
155
 
156
- # For Hugging Face Spaces, we'll use the default port 7860
157
  if __name__ == "__main__":
158
- print("Starting Nigerian TTS API server...")
159
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
+ from fastapi import FastAPI, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
+ import sys
6
+ import time
7
+ import numpy as np
8
  import torch
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.INFO,
14
+ format='%(asctime)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Set absolute paths for model files
18
+ MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
19
+ MODEL_CHECKPOINT = os.path.join(MODEL_DIR, "wavtokenizer_large_speech_320_24k.ckpt")
20
+ MODEL_CONFIG = os.path.join(MODEL_DIR, "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
21
+
22
+ # Check that model files exist
23
+ if not os.path.exists(MODEL_CHECKPOINT):
24
+ logger.error(f"Model checkpoint not found: {MODEL_CHECKPOINT}")
25
+ raise FileNotFoundError(f"Model checkpoint not found: {MODEL_CHECKPOINT}")
26
+
27
+ if not os.path.exists(MODEL_CONFIG):
28
+ logger.error(f"Model config not found: {MODEL_CONFIG}")
29
+ raise FileNotFoundError(f"Model config not found: {MODEL_CONFIG}")
30
+
31
+ logger.info(f"Loading YarnGPT model from {MODEL_CHECKPOINT} and {MODEL_CONFIG}")
32
+
33
+ # Import TTS modules only after verifying files exist
34
+ try:
35
+ from yarngpt.generate import generate_audio, save_audio
36
+ logger.info("Successfully imported yarngpt modules")
37
+ except ImportError as e:
38
+ logger.error(f"Failed to import YarnGPT modules: {e}")
39
+ raise
40
+
41
+ # Create FastAPI app
42
+ app = FastAPI(title="YarnGPT TTS API")
43
+
44
+ # Configure CORS
45
  app.add_middleware(
46
  CORSMiddleware,
47
+ allow_origins=["*"], # Allow all origins
48
  allow_credentials=True,
49
+ allow_methods=["*"],
50
+ allow_headers=["*"],
51
  )
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class TTSRequest(BaseModel):
54
  text: str
55
+ temperature: float = 0.2
56
+ top_p: float = 0.7
57
+ top_k: int = 50
58
+ speed: float = 1.0
59
+ seed: int = 42
 
 
 
 
 
60
 
61
  @app.get("/")
62
+ def read_root():
63
+ return {"message": "YarnGPT TTS API is running. Send POST requests to /tts endpoint."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ @app.post("/tts")
66
+ async def text_to_speech(request: TTSRequest):
67
  try:
68
+ logger.info(f"Processing TTS request: {request.text[:50]}...")
69
+
70
+ # Set random seed if provided
71
+ if request.seed is not None:
72
+ torch.manual_seed(request.seed)
73
+ np.random.seed(request.seed)
74
+
75
+ # Generate audio
76
+ start_time = time.time()
77
+ audio = generate_audio(
78
+ request.text,
79
+ checkpoint_path=MODEL_CHECKPOINT,
80
+ config_path=MODEL_CONFIG,
81
+ temperature=request.temperature,
82
+ top_p=request.top_p,
83
+ top_k=request.top_k,
84
+ speed=request.speed
85
  )
86
+
87
+ # Convert audio to base64
88
+ import base64
89
+ import io
90
+ audio_io = io.BytesIO()
91
+ save_audio(audio_io, audio, sample_rate=24000)
92
+ audio_io.seek(0)
93
+ audio_base64 = base64.b64encode(audio_io.read()).decode('utf-8')
94
+
95
+ generation_time = time.time() - start_time
96
+ logger.info(f"Generated audio in {generation_time:.2f} seconds")
97
+
98
+ return {
99
+ "audio": audio_base64,
100
+ "generation_time": generation_time
101
+ }
102
+
 
 
 
 
 
 
103
  except Exception as e:
104
+ logger.error(f"Error generating speech: {str(e)}", exc_info=True)
105
+ raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
106
 
107
+ @app.get("/health")
108
+ def health_check():
109
+ return {"status": "ok", "models_loaded": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # For local testing
112
  if __name__ == "__main__":
113
+ import uvicorn
114
  uvicorn.run(app, host="0.0.0.0", port=7860)