Spaces:
Build error
Build error
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 2 |
from fastapi.responses import FileResponse, JSONResponse
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from pydantic import BaseModel
|
|
@@ -12,6 +12,7 @@ import logging
|
|
| 12 |
import traceback
|
| 13 |
from typing import Optional
|
| 14 |
import torch
|
|
|
|
| 15 |
|
| 16 |
# Configure logging
|
| 17 |
logging.basicConfig(level=logging.INFO,
|
|
@@ -23,6 +24,7 @@ os.environ["OUTETTS_NO_PORTAUDIO"] = "1"
|
|
| 23 |
|
| 24 |
# Import the TextToSpeech class from generate.py
|
| 25 |
try:
|
|
|
|
| 26 |
from generate import TextToSpeech
|
| 27 |
logger.info("Successfully imported TextToSpeech class from generate.py")
|
| 28 |
except ImportError as e:
|
|
@@ -42,7 +44,7 @@ app.add_middleware(
|
|
| 42 |
allow_headers=["*"],
|
| 43 |
)
|
| 44 |
|
| 45 |
-
# YarnGPT TTS configuration
|
| 46 |
MODEL_CONFIG = {
|
| 47 |
"model_name_or_path": "yarngpt/yarn-tts-demo",
|
| 48 |
"processor_name_or_path": "yarngpt/yarn-tts-demo"
|
|
@@ -144,23 +146,21 @@ async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks)
|
|
| 144 |
|
| 145 |
logger.info(f"Processing TTS request: '{request.text[:50]}...' with voice '{request.voice}' and language '{request.language}'")
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
# Generate speech
|
| 148 |
try:
|
| 149 |
-
# Map the language/voice to accent
|
| 150 |
-
# For simplicity we're using nigerian accent for all, but this could be enhanced
|
| 151 |
-
accent = "nigerian"
|
| 152 |
-
|
| 153 |
-
# Generate audio data
|
| 154 |
audio_data, sample_rate = yarngpt.tts(
|
| 155 |
text=request.text,
|
| 156 |
-
accent=accent,
|
| 157 |
save_path=output_path,
|
| 158 |
speed=request.speed,
|
| 159 |
get_array=True
|
| 160 |
)
|
| 161 |
|
| 162 |
# Convert audio to base64
|
| 163 |
-
import soundfile as sf
|
| 164 |
sf.write(output_path, audio_data, sample_rate)
|
| 165 |
with open(output_path, "rb") as audio_file:
|
| 166 |
audio_bytes = audio_file.read()
|
|
@@ -195,7 +195,7 @@ async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks)
|
|
| 195 |
traceback.print_exc()
|
| 196 |
raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
|
| 197 |
|
| 198 |
-
# File serving endpoint (for compatibility
|
| 199 |
@app.get("/audio/{filename}")
|
| 200 |
async def get_audio(filename: str):
|
| 201 |
file_path = os.path.join(AUDIO_DIR, filename)
|
|
@@ -203,7 +203,6 @@ async def get_audio(filename: str):
|
|
| 203 |
raise HTTPException(status_code=404, detail="Audio file not found")
|
| 204 |
return FileResponse(file_path, media_type="audio/wav")
|
| 205 |
|
| 206 |
-
# For local testing
|
| 207 |
if __name__ == "__main__":
|
| 208 |
import uvicorn
|
| 209 |
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
| 2 |
from fastapi.responses import FileResponse, JSONResponse
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from pydantic import BaseModel
|
|
|
|
| 12 |
import traceback
|
| 13 |
from typing import Optional
|
| 14 |
import torch
|
| 15 |
+
import soundfile as sf
|
| 16 |
|
| 17 |
# Configure logging
|
| 18 |
logging.basicConfig(level=logging.INFO,
|
|
|
|
| 24 |
|
| 25 |
# Import the TextToSpeech class from generate.py
|
| 26 |
try:
|
| 27 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 28 |
from generate import TextToSpeech
|
| 29 |
logger.info("Successfully imported TextToSpeech class from generate.py")
|
| 30 |
except ImportError as e:
|
|
|
|
| 44 |
allow_headers=["*"],
|
| 45 |
)
|
| 46 |
|
| 47 |
+
# YarnGPT TTS configuration
|
| 48 |
MODEL_CONFIG = {
|
| 49 |
"model_name_or_path": "yarngpt/yarn-tts-demo",
|
| 50 |
"processor_name_or_path": "yarngpt/yarn-tts-demo"
|
|
|
|
| 146 |
|
| 147 |
logger.info(f"Processing TTS request: '{request.text[:50]}...' with voice '{request.voice}' and language '{request.language}'")
|
| 148 |
|
| 149 |
+
# Create prompt from voice and language
|
| 150 |
+
# This adapts to the colab-style API even though we're using a different backend
|
| 151 |
+
accent = request.language if request.language in ["nigerian"] else "nigerian"
|
| 152 |
+
|
| 153 |
# Generate speech
|
| 154 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
audio_data, sample_rate = yarngpt.tts(
|
| 156 |
text=request.text,
|
| 157 |
+
accent=accent,
|
| 158 |
save_path=output_path,
|
| 159 |
speed=request.speed,
|
| 160 |
get_array=True
|
| 161 |
)
|
| 162 |
|
| 163 |
# Convert audio to base64
|
|
|
|
| 164 |
sf.write(output_path, audio_data, sample_rate)
|
| 165 |
with open(output_path, "rb") as audio_file:
|
| 166 |
audio_bytes = audio_file.read()
|
|
|
|
| 195 |
traceback.print_exc()
|
| 196 |
raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
|
| 197 |
|
| 198 |
+
# File serving endpoint (for backward compatibility)
|
| 199 |
@app.get("/audio/{filename}")
|
| 200 |
async def get_audio(filename: str):
|
| 201 |
file_path = os.path.join(AUDIO_DIR, filename)
|
|
|
|
| 203 |
raise HTTPException(status_code=404, detail="Audio file not found")
|
| 204 |
return FileResponse(file_path, media_type="audio/wav")
|
| 205 |
|
|
|
|
| 206 |
if __name__ == "__main__":
|
| 207 |
import uvicorn
|
| 208 |
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|