# hivecorp's picture
# Update app.py
# 68958ab verified
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List, Tuple
import time
import uvicorn
from datetime import datetime
import psutil
import asyncio
import edge_tts
from pydub import AudioSegment
import os
import uuid
import tempfile
from concurrent.futures import ThreadPoolExecutor
# Initialize FastAPI app
app = FastAPI(
    version="1.0.0",
    title="TTS API Service",
    description="Text-to-Speech API with real-time status monitoring",
)

# Configure CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged; no endpoint here appears to rely on cookies or auth headers,
# so allow_credentials=False is probably safe — confirm with API consumers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Core functionality (moved from app.py)
class TTSError(Exception):
    """Raised when speech synthesis fails or produces no audio file."""
class FileManager:
    """Own a private temporary directory and track audio files written into it.

    Only the most recently tracked files are kept on disk; older ones are
    deleted by ``cleanup_old_files`` to bound disk usage.
    """

    def __init__(self):
        # A fresh, uniquely named directory per instance, created eagerly.
        self.temp_dir = tempfile.mkdtemp(prefix="tts_api_")
        # Paths of generated files, oldest first (callers append).
        self.output_files: List[str] = []

    def get_temp_path(self, prefix: str) -> str:
        """Return a new, collision-free path inside ``temp_dir``.

        The path has no extension (callers append one) and the file is not
        created here.
        """
        return os.path.join(self.temp_dir, f"{prefix}_{uuid.uuid4()}")

    def cleanup_old_files(self, keep: int = 5) -> None:
        """Delete all but the ``keep`` most recently tracked files.

        Deletion is best-effort: files that are already gone or cannot be
        removed are dropped from tracking without raising.

        Args:
            keep: Number of newest files to retain (default 5). Values <= 0
                delete every tracked file.
        """
        cutoff = len(self.output_files) - max(keep, 0)
        if cutoff <= 0:
            return  # Nothing over the limit.
        stale = self.output_files[:cutoff]
        self.output_files = self.output_files[cutoff:]
        for path in stale:
            try:
                os.remove(path)
            except OSError:
                # Includes FileNotFoundError (already removed). A file we cannot
                # delete must not fail the request that triggered the cleanup.
                pass
# Global state management
class ProcessingState:
    """Process-wide bookkeeping shared by the request handlers.

    Job records are held in a plain in-memory dict; nothing is persisted.
    """

    def __init__(self):
        # Owns the temporary directory that generated audio is written into.
        self.file_manager = FileManager()
        # job_id -> job record (status, progress, timestamps, request, ...).
        self.active_jobs: Dict[str, Dict[str, Any]] = {}


# The single shared instance used throughout the module.
state = ProcessingState()
# Pydantic models
class TTSRequest(BaseModel):
    """Request body for ``POST /api/v1/tts``."""

    # Text to synthesize.
    text: str
    # Edge TTS voice identifier; ``voice_options`` lists a few known values.
    voice: str = "en-US-JennyNeural"
    # Signed pitch offset in Hz (rendered like "+5Hz" before synthesis).
    pitch: int = 0
    # Signed speaking-rate offset in percent (rendered like "-10%").
    rate: int = 0
class HealthResponse(BaseModel):
    """Response schema for ``GET /api/v1/health``."""

    status: str  # "healthy" whenever the endpoint answers.
    timestamp: str  # Server local time, ISO-8601, naive (no UTC offset).
    cpu_usage: float  # CPU utilisation, percent.
    memory_usage: float  # Virtual-memory utilisation, percent.
    active_jobs: int  # Number of job records currently held in memory.
# Friendly display name -> Edge TTS voice identifier (a small, simplified
# subset; requests may pass any voice id directly).
voice_options = dict(
    Jenny="en-US-JennyNeural",
    Guy="en-US-GuyNeural",
    Ana="en-US-AnaNeural",
    Aria="en-US-AriaNeural",
)
async def generate_tts(text: str, voice: str, rate: str, pitch: str) -> str:
    """Synthesize ``text`` to an MP3 file and return the file's path.

    Args:
        text: Text to speak.
        voice: Edge TTS voice identifier, e.g. ``"en-US-JennyNeural"``.
        rate: Signed speaking-rate offset string, e.g. ``"+0%"``.
        pitch: Signed pitch offset string, e.g. ``"+0Hz"``.

    Returns:
        Path of the generated file inside the shared file manager's temp
        directory. The file is tracked there, so it is deleted once enough
        newer outputs accumulate.

    Raises:
        TTSError: synthesis failed or produced no file. The message always
            starts with ``"TTS generation failed: "``.
    """
    audio_path = state.file_manager.get_temp_path("audio") + ".mp3"
    try:
        tts = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)
        await tts.save(audio_path)
    except Exception as e:
        # A partially written file is not tracked yet, so the file manager's
        # cleanup would never remove it; discard it here.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        raise TTSError(f"TTS generation failed: {e}") from e
    if not os.path.exists(audio_path):
        raise TTSError("TTS generation failed: Failed to generate audio file")
    state.file_manager.output_files.append(audio_path)
    state.file_manager.cleanup_old_files()
    return audio_path
# API endpoints
@app.post("/api/v1/tts")
async def create_tts(request: TTSRequest, background_tasks: BackgroundTasks):
    """Queue a text-to-speech job and return its id immediately.

    Synthesis runs as a background task. Poll ``GET /api/v1/status/{job_id}``
    for the job record. Once its status is ``"completed"``, fetch the audio
    from ``GET /api/v1/download/{job_id}``.

    Returns:
        ``{"job_id": <id>, "status": "queued"}``
    """
    # A random suffix keeps ids unique even when identical texts arrive in the
    # same second; a text-hash suffix would collide and overwrite the record.
    job_id = f"job_{int(time.time())}_{uuid.uuid4().hex}"
    now = datetime.now().isoformat()
    state.active_jobs[job_id] = {
        "id": job_id,
        "status": "queued",
        "progress": 0,
        "created_at": now,
        "last_update": now,
        "request": request.dict(),
    }

    def _update_job(**fields: Any) -> None:
        """Merge ``fields`` into this job's record and refresh ``last_update``."""
        state.active_jobs[job_id].update(fields, last_update=datetime.now().isoformat())

    async def process_tts():
        """Run synthesis and record the outcome on the job record."""
        _update_job(status="processing")
        try:
            # generate_tts expects signed offsets with units, e.g. "+0Hz" / "-10%".
            pitch_str = f"{request.pitch:+d}Hz"
            rate_str = f"{request.rate:+d}%"
            audio_path = await generate_tts(
                request.text,
                request.voice,
                rate_str,
                pitch_str,
            )
        except Exception as e:
            # Record the failure on the job; otherwise it would only surface
            # inside the background task, where no client can see it.
            _update_job(status="failed", error=str(e))
        else:
            _update_job(
                status="completed",
                progress=1.0,
                result={"audio_path": audio_path},
            )

    background_tasks.add_task(process_tts)
    return {"job_id": job_id, "status": "queued"}
@app.get("/api/v1/status/{job_id}")
async def get_job_status(job_id: str):
    """Return the full job record for ``job_id``; 404 if the id is unknown."""
    record = state.active_jobs.get(job_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return record
@app.get("/api/v1/download/{job_id}")
async def download_file(job_id: str):
    """Serve the MP3 produced by a completed job.

    Raises:
        HTTPException(404): unknown ``job_id``.
        HTTPException(400): the job has not completed (still pending or failed).
        HTTPException(410): the job completed, but its audio file is no longer
            on disk. Only the most recent outputs are retained.
    """
    job = state.active_jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != "completed":
        raise HTTPException(status_code=400, detail="Job not completed")
    file_path = job["result"]["audio_path"]
    # Older outputs are pruned by the file manager. Without this check,
    # FileResponse raises while serving the missing path and the client sees
    # an opaque 500 instead of a clear error.
    if not os.path.exists(file_path):
        raise HTTPException(status_code=410, detail="Audio file no longer available")
    return FileResponse(
        file_path,
        media_type="audio/mpeg",
        filename="tts_output.mp3",
    )
@app.get("/api/v1/health", response_model=HealthResponse)
async def health_check():
    """Report liveness together with basic host load figures."""
    snapshot = dict(
        status="healthy",
        timestamp=datetime.now().isoformat(),
        # Called without an interval, so it does not block; per psutil docs
        # the value is measured since the previous call.
        cpu_usage=psutil.cpu_percent(),
        memory_usage=psutil.virtual_memory().percent,
        # Counts every tracked job record, including finished and failed ones.
        active_jobs=len(state.active_jobs),
    )
    return snapshot
@app.get("/api/v1/voices")
async def list_voices():
    """Return the built-in friendly-name -> voice-id mapping."""
    payload = {"voices": voice_options}
    return payload
if __name__ == "__main__":
    # Reload mode needs the app as an import string ("module:attribute")
    # rather than the app object. This assumes the file is importable as
    # module `app`; adjust the string if the file is renamed.
    # NOTE(review): reload=True is a development setting.
    uvicorn.run(
        "app:app",
        reload=True,
        host="0.0.0.0",
        port=8000,
    )