Hameed13 committed on
Commit
91499fa
·
verified ·
1 Parent(s): 3f6fb88

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -11
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
  from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
@@ -12,6 +12,7 @@ import logging
12
  import traceback
13
  from typing import Optional
14
  import torch
 
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO,
@@ -23,6 +24,7 @@ os.environ["OUTETTS_NO_PORTAUDIO"] = "1"
23
 
24
  # Import the TextToSpeech class from generate.py
25
  try:
 
26
  from generate import TextToSpeech
27
  logger.info("Successfully imported TextToSpeech class from generate.py")
28
  except ImportError as e:
@@ -42,7 +44,7 @@ app.add_middleware(
42
  allow_headers=["*"],
43
  )
44
 
45
- # YarnGPT TTS configuration - This can be adjusted based on model availability
46
  MODEL_CONFIG = {
47
  "model_name_or_path": "yarngpt/yarn-tts-demo",
48
  "processor_name_or_path": "yarngpt/yarn-tts-demo"
@@ -144,23 +146,21 @@ async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks)
144
 
145
  logger.info(f"Processing TTS request: '{request.text[:50]}...' with voice '{request.voice}' and language '{request.language}'")
146
 
 
 
 
 
147
  # Generate speech
148
  try:
149
- # Map the language/voice to accent
150
- # For simplicity we're using nigerian accent for all, but this could be enhanced
151
- accent = "nigerian"
152
-
153
- # Generate audio data
154
  audio_data, sample_rate = yarngpt.tts(
155
  text=request.text,
156
- accent=accent,
157
  save_path=output_path,
158
  speed=request.speed,
159
  get_array=True
160
  )
161
 
162
  # Convert audio to base64
163
- import soundfile as sf
164
  sf.write(output_path, audio_data, sample_rate)
165
  with open(output_path, "rb") as audio_file:
166
  audio_bytes = audio_file.read()
@@ -195,7 +195,7 @@ async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks)
195
  traceback.print_exc()
196
  raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
197
 
198
- # File serving endpoint (for compatibility with direct requests)
199
  @app.get("/audio/{filename}")
200
  async def get_audio(filename: str):
201
  file_path = os.path.join(AUDIO_DIR, filename)
@@ -203,7 +203,6 @@ async def get_audio(filename: str):
203
  raise HTTPException(status_code=404, detail="Audio file not found")
204
  return FileResponse(file_path, media_type="audio/wav")
205
 
206
- # For local testing
207
  if __name__ == "__main__":
208
  import uvicorn
209
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
2
  from fastapi.responses import FileResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
 
12
  import traceback
13
  from typing import Optional
14
  import torch
15
+ import soundfile as sf
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO,
 
24
 
25
  # Import the TextToSpeech class from generate.py
26
  try:
27
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
28
  from generate import TextToSpeech
29
  logger.info("Successfully imported TextToSpeech class from generate.py")
30
  except ImportError as e:
 
44
  allow_headers=["*"],
45
  )
46
 
47
+ # YarnGPT TTS configuration
48
  MODEL_CONFIG = {
49
  "model_name_or_path": "yarngpt/yarn-tts-demo",
50
  "processor_name_or_path": "yarngpt/yarn-tts-demo"
 
146
 
147
  logger.info(f"Processing TTS request: '{request.text[:50]}...' with voice '{request.voice}' and language '{request.language}'")
148
 
149
+ # Select the accent from the requested language (only "nigerian" is supported; any other value falls back to it)
150
+ # This adapts to the colab-style API even though we're using a different backend
151
+ accent = request.language if request.language in ["nigerian"] else "nigerian"
152
+
153
  # Generate speech
154
  try:
 
 
 
 
 
155
  audio_data, sample_rate = yarngpt.tts(
156
  text=request.text,
157
+ accent=accent,
158
  save_path=output_path,
159
  speed=request.speed,
160
  get_array=True
161
  )
162
 
163
  # Convert audio to base64
 
164
  sf.write(output_path, audio_data, sample_rate)
165
  with open(output_path, "rb") as audio_file:
166
  audio_bytes = audio_file.read()
 
195
  traceback.print_exc()
196
  raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
197
 
198
+ # File serving endpoint (for backward compatibility)
199
  @app.get("/audio/{filename}")
200
  async def get_audio(filename: str):
201
  file_path = os.path.join(AUDIO_DIR, filename)
 
203
  raise HTTPException(status_code=404, detail="Audio file not found")
204
  return FileResponse(file_path, media_type="audio/wav")
205
 
 
206
  if __name__ == "__main__":
207
  import uvicorn
208
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)