"""FastAPI + Gradio front end for the stutter detector.

Exposes:
  * ``GET  /health``  - liveness / model-loaded probe
  * ``POST /analyze`` - multipart upload of an audio file (+ optional transcript)
  * ``/``             - Gradio UI wrapping the same detector

The detector itself is obtained from ``diagnosis.ai_engine.model_loader`` and
is loaded once, on application startup.
"""
import logging
import os
import re
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr

# Configure logging FIRST so that import-time messages below are formatted.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

# Add project root to path so the `diagnosis` package resolves when this file
# is executed directly.
sys.path.insert(0, str(Path(__file__).parent))

# Import detector using model loader
try:
    from diagnosis.ai_engine.model_loader import get_stutter_detector
    logger.info("✅ Successfully imported model loader")
except ImportError as e:
    logger.error(f"❌ Failed to import model loader: {e}")
    raise

# Directory that holds uploaded audio while it is being analyzed.
TEMP_DIR = "/tmp/stutter_analysis"

# Only short, purely alphanumeric extensions are carried over from the
# client-supplied filename; everything else about the name is discarded.
_SUFFIX_RE = re.compile(r"^\.[A-Za-z0-9]{1,10}$")

# Initialize FastAPI
app = FastAPI(
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended for the deployment target.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global detector instance; stays None until the startup hook has run.
detector = None


def _safe_suffix(filename: Optional[str]) -> str:
    """Return a sanitized file extension (including the dot) or ``""``.

    The upload filename is untrusted input, so only its extension is kept,
    and only when it is a short alphanumeric token.
    """
    if not filename:
        return ""
    suffix = Path(filename).suffix
    return suffix if _SUFFIX_RE.match(suffix) else ""


@app.on_event("startup")
async def startup_event():
    """Load models on startup.

    Raises:
        Exception: whatever ``get_stutter_detector`` raised, re-raised after
            logging so the server does not start without a detector.
    """
    global detector
    try:
        logger.info("🚀 Startup event: Loading AI models...")
        detector = get_stutter_detector()
        logger.info("✅ Models loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load models: {e}", exc_info=True)
        raise


def gradio_analyze(audio_path, transcript=""):
    """Analyze audio for stuttering using the Gradio interface.

    Args:
        audio_path: Filesystem path of the uploaded audio, as provided by the
            ``gr.Audio(type="filepath")`` input. May be empty/None when the
            user submits without attaching audio.
        transcript: Optional expected transcript.

    Returns:
        The detector's result, or a ``{"error": ...}`` dict when the models
        are not loaded, no audio was supplied, or the analysis failed.
    """
    if detector is None:
        return {"error": "Models not loaded yet. Please try again later."}
    if not audio_path:
        # Fail with a clear message instead of whatever the detector raises
        # when handed a missing path.
        return {"error": "No audio file provided."}
    try:
        return detector.analyze_audio(audio_path, transcript)
    except Exception as e:
        logger.exception("Gradio analysis failed")
        return {"error": f"Analysis failed: {str(e)}"}


# Create Gradio interface
gradio_app = gr.Interface(
    fn=gradio_analyze,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Textbox(
            label="Optional Transcript",
            placeholder="Enter expected transcript here...",
            lines=2
        )
    ],
    outputs=gr.JSON(label="Analysis Results"),
    title="Stutter Detection",
    description="Upload an audio file and optionally provide a transcript to analyze for stuttering."
)


@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A dict with ``status``, whether the detector has been loaded, and the
        current UTC time as an ISO-8601 string with a trailing ``Z``.
    """
    # Aware UTC "now", rendered without the "+00:00" offset so the payload
    # keeps the "<naive-iso>Z" shape.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        "timestamp": now_utc.isoformat() + "Z"
    }


@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = Form("")
):
    """
    Analyze audio file for stuttering

    Parameters:
    - audio: WAV or MP3 audio file
    - transcript: Optional expected transcript

    Returns: Complete stutter analysis results

    Errors:
    - 400 if the uploaded file is empty
    - 503 if the models have not finished loading
    - 500 if the analysis itself fails
    """
    temp_file: Optional[str] = None
    try:
        if detector is None:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

        # %r keeps control characters in the untrusted name out of the log.
        logger.info("📥 Processing: %r", audio.filename)

        content = await audio.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")

        # Create temp directory if needed
        os.makedirs(TEMP_DIR, exist_ok=True)

        # Save the upload under a server-generated unique name. The client's
        # filename is never used as a path component (it could contain "../"
        # or be absolute), and unique names keep concurrent uploads from
        # overwriting or deleting each other's files.
        with tempfile.NamedTemporaryFile(
            dir=TEMP_DIR,
            prefix="upload_",
            suffix=_safe_suffix(audio.filename),
            delete=False,
        ) as f:
            temp_file = f.name
            f.write(content)

        logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        # Analyze
        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
        result = detector.analyze_audio(temp_file, transcript)

        # Logging must never turn a successful analysis into a 500, so every
        # field is read defensively.
        actual = result.get('actual_transcript') or ''
        target = result.get('target_transcript') or ''
        logger.info(
            f"✅ Analysis complete: severity={result.get('severity')}, "
            f"mismatch={result.get('mismatch_percentage')}%"
        )
        logger.info(
            f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), "
            f"Target: '{target[:100]}' (len: {len(target)})"
        )

        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        # Cleanup (best effort: a failure here is logged, not raised)
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")


# Mount Gradio app to FastAPI at root path.
# This must come AFTER the API routes above: routes are matched in
# registration order and a mount at "/" matches every path, so anything
# registered after it would never be reached.
gr.mount_gradio_app(app, gradio_app, path="/")


if __name__ == "__main__":
    import uvicorn
    logger.info("🚀 Starting SLAQ Stutter Detector API...")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )