# kid-coach-api / main.py
# akpande2's picture
# Update main.py
# 0e0a1bb verified
"""
Production FastAPI Server for Public Speaking Coach
With LLM Tips and Avatar Voice Support
"""
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, status, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from kid_coach_pipeline import EnhancedPublicSpeakingCoach
# ================= APP CONFIGURATION =================
# FastAPI application object; title/description/version are the metadata
# FastAPI publishes in its generated API docs.
app = FastAPI(
    title="Public Speaking Coach API",
    description="AI-powered speech analysis with LLM tips and avatar voice",
    version="3.0.0"
)
# CORS Configuration
# Every origin, method and header is allowed.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# combination browsers refuse for a literal "*" origin header; check how the
# CORS middleware resolves it and prefer an explicit origin allow-list if this
# API ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Audio output directory
# Created at import time, before the StaticFiles mount below references it.
# NOTE(review): presumably the pipeline writes its generated avatar audio into
# this same directory; the path must agree with kid_coach_pipeline - confirm.
AUDIO_OUTPUT_DIR = "/tmp/audio_outputs"
os.makedirs(AUDIO_OUTPUT_DIR, exist_ok=True)
# Mount static files for audio serving
# NOTE(review): this mount is registered before every route in this module, so
# it presumably handles all /audio/... requests ahead of the get_audio route.
app.mount("/audio", StaticFiles(directory=AUDIO_OUTPUT_DIR), name="audio")
# Global engine instance
# Assigned by startup_event(); remains None when model loading fails, which the
# endpoints surface as "degraded" health and 503 responses from /coach.
coach_engine: Optional[EnhancedPublicSpeakingCoach] = None
# Supported audio formats
# Lower-case file extensions (leading dot included); /coach compares the
# lower-cased suffix of the uploaded filename against this set.
SUPPORTED_FORMATS = {
    '.wav', '.mp3', '.m4a', '.flac', '.ogg',
    '.wma', '.aac', '.mp4', '.webm'
}
# Maximum file size (50MB)
# Upload size limit in bytes; larger uploads are rejected with 413 by /coach.
MAX_FILE_SIZE = 50 * 1024 * 1024
# ================= RESPONSE MODELS =================
class HealthResponse(BaseModel):
    """Health check response.

    Response schema shared by ``GET /`` and ``GET /health``.
    """
    # "online" from GET /; "healthy" or "degraded" from GET /health.
    status: str
    # True when the coach engine was created successfully at startup.
    engine_loaded: bool
    # The engine's ``tts_enabled`` flag, or False when no engine is loaded.
    tts_enabled: bool
    # File extensions accepted by the analysis endpoints.
    supported_formats: list
class ErrorResponse(BaseModel):
    """Error response format.

    NOTE(review): no route or handler in this module references this model;
    the HTTPException handler builds its JSON by hand with ``error`` and
    ``status_code`` keys rather than ``detail``. Confirm whether it is still
    needed (e.g. for OpenAPI ``responses=`` declarations elsewhere).
    """
    # Short, human-readable error message.
    error: str
    # Optional extra context about the failure.
    detail: Optional[str] = None
# ================= STARTUP/SHUTDOWN =================
@app.on_event("startup")
async def startup_event():
    """Initialize the coach engine on server start.

    Builds the module-level ``coach_engine``. Failure is deliberately
    non-fatal: the server still starts, ``coach_engine`` stays ``None``,
    ``/health`` reports "degraded" and ``/coach`` answers 503. The full
    traceback is printed so the cause of a failed model load can be diagnosed.
    """
    import traceback

    global coach_engine
    banner = "=" * 60
    print("\n" + banner)
    print("🚀 PUBLIC SPEAKING COACH API - STARTING")
    print(banner)
    try:
        print("\n📦 Loading AI models...")
        coach_engine = EnhancedPublicSpeakingCoach(
            whisper_model_size="base",
            enable_tts=True
        )
    except Exception as e:
        # Best effort: keep serving so health checks can report the problem,
        # but keep the traceback instead of only the exception message.
        print(f"\n❌ STARTUP FAILED: {e}")
        traceback.print_exc()
        print("Server will start but analysis will not work.\n")
        coach_engine = None
    else:
        print("✅ Coach engine ready!")
        print("\n" + banner)
        print("🎤 API is ready to analyze speeches!")
        print(banner + "\n")
@app.on_event("shutdown")
async def shutdown_event():
    """Announce that the server is stopping; nothing is released here."""
    farewell = "\n👋 Shutting down Public Speaking Coach API..."
    print(farewell)
# ================= ENDPOINTS =================
@app.get("/", response_model=HealthResponse)
async def root():
    """Root endpoint - API info.

    Reports that the server is online, whether the coach engine loaded,
    whether TTS is enabled, and the accepted file extensions. The extensions
    are sorted so the payload is stable: iterating a ``set`` of strings gives
    a different order in different processes because string hashing is
    randomized.
    """
    engine = coach_engine
    return {
        "status": "online",
        "engine_loaded": engine is not None,
        "tts_enabled": engine.tts_enabled if engine else False,
        "supported_formats": sorted(SUPPORTED_FORMATS),
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    ``status`` is "healthy" when the coach engine is loaded and "degraded"
    otherwise. The supported extensions are returned sorted so the response
    does not change order between processes (``set`` iteration order of
    strings depends on the per-process hash seed).
    """
    engine = coach_engine
    return {
        "status": "healthy" if engine else "degraded",
        "engine_loaded": engine is not None,
        "tts_enabled": engine.tts_enabled if engine else False,
        "supported_formats": sorted(SUPPORTED_FORMATS),
    }
@app.post("/coach")
async def analyze_speech(file: UploadFile = File(...), avatar_gender: str = Form('male')):
    """
    Main endpoint: Upload audio file and receive comprehensive analysis.

    Args:
        file: Uploaded audio. Its lower-cased extension must be in
            SUPPORTED_FORMATS and its size must not exceed MAX_FILE_SIZE.
        avatar_gender: Form field forwarded to the engine's
            ``analyze_speech`` call (defaults to 'male').

    Returns:
        JSONResponse wrapping the result returned by the engine.

    Raises:
        HTTPException: 503 if the engine is not loaded; 400 for a missing
            file, an empty filename, an unsupported extension, or an engine
            result containing an "error" key; 413 if the upload is too large;
            500 for any other failure during analysis.
    """
    import traceback

    # Check if engine is loaded
    if coach_engine is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Coach engine not initialized. Please contact administrator."
        )
    # Direct Python callers (such as an alias endpoint calling this function)
    # bypass FastAPI's form parsing and would otherwise receive the Form(...)
    # default marker object here instead of the string 'male'.
    if not isinstance(avatar_gender, str):
        avatar_gender = 'male'
    # Validate file exists
    if not file:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided"
        )
    # Validate filename
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid filename"
        )
    file_ext = Path(file.filename).suffix.lower()
    if file_ext not in SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported format '{file_ext}'. Supported: {', '.join(sorted(SUPPORTED_FORMATS))}"
        )

    temp_file = None
    try:
        # Read at most one byte past the limit: enough to detect an oversized
        # upload without pulling an arbitrarily large body into memory.
        content = await file.read(MAX_FILE_SIZE + 1)
        if len(content) > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
            )
        # The engine is given a filesystem path, so persist the upload; the
        # original extension is kept as the suffix. The name is recorded before
        # writing so the finally-block can delete the file even if write fails.
        with tempfile.NamedTemporaryFile(
            delete=False,
            suffix=file_ext
        ) as temp:
            temp_file = temp.name
            temp.write(content)
        print(f"\n📁 Processing: {file.filename} ({len(content) / 1024:.1f} KB)")
        # NOTE(review): this call is synchronous and runs on the event loop,
        # so other requests (including /health) wait until it returns. Moving
        # it to a worker thread requires confirming the engine tolerates
        # concurrent use from several threads.
        result = coach_engine.analyze_speech(temp_file, enable_tts=True, avatar_gender=avatar_gender)
        # The engine reports recoverable problems through an "error" key.
        if "error" in result:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=result["error"]
            )
        print("✅ Analysis complete")
        return JSONResponse(content=result)
    except HTTPException:
        # Already carries the intended status code and message.
        raise
    except Exception as e:
        print("\n❌ ANALYSIS ERROR:")
        traceback.print_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Analysis failed: {str(e)}"
        ) from e
    finally:
        # Best-effort cleanup of the temporary upload copy.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            except Exception as e:
                print(f"⚠️ Failed to delete temp file: {e}")
@app.post("/analyze")
async def analyze_speech_alias(file: UploadFile = File(...), avatar_gender: str = Form('male')):
    """Alias endpoint for /coach (for compatibility).

    Accepts the same multipart fields as ``/coach``. ``avatar_gender`` is
    forwarded explicitly because calling ``analyze_speech`` as a plain
    function bypasses FastAPI's form parsing; an omitted argument would be
    bound to the ``Form('male')`` marker object rather than the string
    ``'male'``.
    """
    return await analyze_speech(file, avatar_gender=avatar_gender)
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve generated avatar audio files.

    Only regular files located directly inside AUDIO_OUTPUT_DIR are served.
    Missing files, directories (e.g. ``..``), and names or symlinks that
    resolve outside that directory all yield 404.

    NOTE(review): ``app.mount("/audio", ...)`` is registered earlier in this
    module, so requests under ``/audio/`` are presumably answered by that
    StaticFiles mount before this route is consulted - confirm whether this
    handler is reachable at all.
    """
    audio_root = os.path.realpath(AUDIO_OUTPUT_DIR)
    file_path = os.path.realpath(os.path.join(audio_root, filename))
    # os.path.exists alone would accept directories and paths escaping the
    # output directory; require a regular file whose parent is audio_root.
    if os.path.dirname(file_path) != audio_root or not os.path.isfile(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Audio file not found"
        )
    return FileResponse(
        file_path,
        media_type="audio/wav",
        filename=filename
    )
# ================= ERROR HANDLERS =================
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as ``{"error": detail, "status_code": code}``.

    Headers attached to the exception (``HTTPException(headers=...)``) are
    passed through on the response instead of being silently dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code
        },
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all exception handler.

    Prints the traceback of ``exc`` and returns a 500 JSON body whose
    ``detail`` is the exception text.

    NOTE(review): ``detail`` exposes internal error text to clients; consider
    removing it if that is not intended.
    """
    import traceback
    # Print from the exception object itself rather than relying on
    # sys.exc_info(), which is only populated inside an active except block.
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "Internal server error",
            "detail": str(exc)
        }
    )
# ================= MAIN =================
if __name__ == "__main__":
    # Local development entry point: listen on every interface, port 8000.
    run_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **run_options)