# anfastech's picture
# Fix: Gradio UI (default landing page) #2
# e765887
import logging
import os
import sys
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
# Logging is configured before anything else in this module so that the
# import-time messages emitted further down already use this format.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every handler in this file.
logger = logging.getLogger(__name__)
# Put the directory that contains this file at the front of sys.path so the
# import below is resolved relative to it (presumably the `diagnosis`
# package lives next to this file; verify against the repo layout).
_project_root = Path(__file__).parent
sys.path.insert(0, str(_project_root))

# Import the detector factory; a failure is logged and then re-raised so the
# module does not continue without it.
try:
    from diagnosis.ai_engine.model_loader import get_stutter_detector
    logger.info("✅ Successfully imported model loader")
except ImportError as import_error:
    logger.error(f"❌ Failed to import model loader: {import_error}")
    raise
# FastAPI application object; routes and the Gradio UI are attached below.
app = FastAPI(
    version="1.0.0",
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
)

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# combination the FastAPI CORS documentation advises against; confirm whether
# credentialed cross-origin requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Detector instance shared by all handlers. It stays None until the startup
# hook assigns it; handlers check for None before using it.
detector = None
@app.on_event("startup")
async def startup_event():
    """Populate the module-level ``detector`` when the application starts.

    Calls ``get_stutter_detector()`` and stores the result in the global
    ``detector``. Any exception raised while doing so is logged together
    with its traceback and then re-raised.
    """
    global detector
    try:
        logger.info("🚀 Startup event: Loading AI models...")
        detector = get_stutter_detector()
        logger.info("✅ Models loaded successfully!")
    except Exception as load_error:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"❌ Failed to load models: {load_error}")
        raise
def gradio_analyze(audio_path, transcript=""):
    """Analyze audio for stuttering on behalf of the Gradio interface.

    Parameters:
    - audio_path: Path of the audio file to analyze. Falsy (``None`` or
      empty) when the form is submitted without any audio.
    - transcript: Optional expected transcript; ``None`` is treated as "".

    Returns: The detector's analysis result, or a dict with a single
    ``"error"`` key describing why no analysis was produced. This function
    never raises, so the UI always receives something it can render as JSON.
    """
    # Reject a missing file up front instead of handing None to the detector
    # and surfacing whatever low-level error it would raise.
    if not audio_path:
        return {"error": "No audio provided. Please upload or record an audio file."}

    # Identity check: only "not loaded yet" (None) should take this branch.
    if detector is None:
        return {"error": "Models not loaded yet. Please try again later."}

    try:
        return detector.analyze_audio(audio_path, transcript or "")
    except Exception as e:
        # Keep the failure visible in the server log; previously it was only
        # returned to the UI and left no trace for operators.
        logger.error(f"❌ Gradio analysis failed: {e}", exc_info=True)
        return {"error": f"Analysis failed: {str(e)}"}
# Gradio UI definition: an audio upload plus an optional transcript box as
# inputs, with whatever gradio_analyze returns rendered as JSON.
_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
_transcript_input = gr.Textbox(
    label="Optional Transcript",
    placeholder="Enter expected transcript here...",
    lines=2,
)

gradio_app = gr.Interface(
    fn=gradio_analyze,
    inputs=[_audio_input, _transcript_input],
    outputs=gr.JSON(label="Analysis Results"),
    title="Stutter Detection",
    description="Upload an audio file and optionally provide a transcript to analyze for stuttering.",
)
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns: a dict with ``status`` (always "healthy"), ``models_loaded``
    (whether the module-level ``detector`` has been set) and ``timestamp``
    (current UTC time in ISO 8601 format with a trailing "Z").
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    from datetime import datetime, timezone

    # Swap the "+00:00" offset for "Z" to keep the original wire format.
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        "timestamp": timestamp,
    }


@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = Form("")
):
    """
    Analyze audio file for stuttering

    Parameters:
    - audio: WAV or MP3 audio file
    - transcript: Optional expected transcript

    Returns: Complete stutter analysis results

    Raises:
    - HTTPException(503) when the models have not been loaded yet
    - HTTPException(500) when saving or analyzing the upload fails
    """
    import tempfile

    temp_file = None
    try:
        if detector is None:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")
        logger.info(f"📥 Processing: {audio.filename}")

        # Create temp directory if needed
        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # The client-supplied filename is untrusted: it may be missing, may
        # contain path separators ("../../x") and may collide with another
        # in-flight upload. Only a plain alphanumeric extension is kept (in
        # case the detector uses it as a format hint); the rest of the name
        # is generated by mkstemp, which guarantees a unique file.
        suffix = Path(audio.filename or "").suffix
        if not suffix[1:].isalnum():
            suffix = ""

        # Save uploaded file
        content = await audio.read()
        fd, temp_file = tempfile.mkstemp(prefix="upload_", suffix=suffix, dir=temp_dir)
        with os.fdopen(fd, "wb") as f:
            f.write(content)
        logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        # Analyze
        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
        result = detector.analyze_audio(temp_file, transcript)

        # Logging must never turn a successful analysis into a 500, so every
        # field is read with .get() and missing transcripts fall back to "".
        actual = result.get('actual_transcript') or ''
        target = result.get('target_transcript') or ''
        logger.info(f"✅ Analysis complete: severity={result.get('severity')}, mismatch={result.get('mismatch_percentage')}%")
        logger.info(f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
        return result
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        # Cleanup: best effort, a leftover temp file must not fail the request.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except OSError as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")


# Mount Gradio app to FastAPI at root path.
# This has to come AFTER the API routes above: a mount at "/" matches every
# request path and routes are tried in registration order, so mounting first
# would send /health and /analyze to the Gradio app instead of the handlers.
gr.mount_gradio_app(app, gradio_app, path="/")
if __name__ == "__main__":
    # Direct execution only: serve the app with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here so that merely importing this
    # module does not require it.
    import uvicorn

    logger.info("🚀 Starting SLAQ Stutter Detector API...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")