Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import gradio as gr | |
# Configure logging before anything else so that the messages logged around
# the model-loader import below are already formatted and sent to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(name=__name__)
# Put this file's directory first on the import path so the local
# ``diagnosis`` package resolves regardless of the current working directory.
sys.path.insert(0, os.fspath(Path(__file__).parent))
# Import the detector factory. If the ``diagnosis`` package cannot be
# imported, log why and re-raise: the service cannot work without it.
try:
    from diagnosis.ai_engine.model_loader import get_stutter_detector
except ImportError as exc:
    logger.error(f"❌ Failed to import model loader: {exc}")
    raise
else:
    logger.info("✅ Successfully imported model loader")
# Create the FastAPI application; the Gradio UI is mounted on it further down.
app = FastAPI(
    version="1.0.0",
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Global detector instance; remains None until startup_event() assigns it.
detector = None


# Registered as a startup handler. Without this registration nothing ever
# called the function, so ``detector`` stayed None forever.
@app.on_event("startup")
async def startup_event():
    """Load the stutter-detection models once, when the application starts.

    Sets the module-level ``detector`` to the object returned by
    ``get_stutter_detector()``.

    Raises:
        Exception: any error raised while loading the models is logged with
            its traceback and re-raised, so it reaches the ASGI server's
            startup handling instead of leaving the app running without models.
    """
    global detector
    try:
        logger.info("🚀 Startup event: Loading AI models...")
        detector = get_stutter_detector()
        logger.info("✅ Models loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load models: {e}", exc_info=True)
        raise
def gradio_analyze(audio_path, transcript=""):
    """Analyze an uploaded audio file for stuttering (Gradio callback).

    Args:
        audio_path: Filesystem path supplied by the ``gr.Audio(type="filepath")``
            input, or ``None``/empty when the user submits without a file.
        transcript: Optional expected transcript; ``None`` is treated as "".

    Returns:
        The detector's analysis result on success. Otherwise a dict with a
        single ``"error"`` key. Errors are returned rather than raised so they
        show up in the ``gr.JSON`` output component.
    """
    # Compare with None explicitly: a loaded detector must never be treated as
    # "not loaded" merely because the object happens to be falsy.
    if detector is None:
        return {"error": "Models not loaded yet. Please try again later."}
    # Gradio passes None when no file was uploaded; don't hand that to the model.
    if not audio_path:
        return {"error": "No audio file provided. Please upload an audio file."}
    try:
        return detector.analyze_audio(audio_path, transcript or "")
    except Exception as e:
        # Keep the traceback in the server log; the UI only shows the message.
        logger.exception("Gradio analysis failed")
        return {"error": f"Analysis failed: {e}"}
# Gradio UI: an audio upload (delivered to the callback as a file path) plus an
# optional transcript box; the callback's return value is rendered as JSON.
gradio_app = gr.Interface(
    title="Stutter Detection",
    description="Upload an audio file and optionally provide a transcript to analyze for stuttering.",
    fn=gradio_analyze,
    inputs=[
        gr.Audio(label="Upload Audio File", type="filepath"),
        gr.Textbox(
            lines=2,
            label="Optional Transcript",
            placeholder="Enter expected transcript here...",
        ),
    ],
    outputs=gr.JSON(label="Analysis Results"),
)
# NOTE(review): without a route decorator this endpoint was unreachable;
# confirm "/health" is the path clients and health probes expect.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        dict with ``status`` ("healthy"), ``models_loaded`` (whether the
        module-level ``detector`` has been set) and ``timestamp`` (current UTC
        time in ISO 8601 form with a trailing "Z").
    """
    from datetime import datetime, timezone

    # datetime.utcnow() is deprecated since Python 3.12. Build an aware UTC time
    # and drop the tzinfo so the output keeps the original "...Z" format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        "timestamp": now_utc.isoformat() + "Z",
    }


# NOTE(review): without a route decorator this endpoint was unreachable;
# confirm "/analyze" is the path clients call.
@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = Form("")
):
    """
    Analyze an uploaded audio file for stuttering.

    Parameters:
    - audio: WAV or MP3 audio file
    - transcript: Optional expected transcript

    Returns: the detector's stutter analysis result.

    Errors:
    - 400 if the uploaded file is empty
    - 503 if the models have not been loaded yet
    - 500 if the analysis fails for any other reason
    """
    import tempfile

    temp_file = None
    try:
        if detector is None:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

        logger.info(f"📥 Processing: {audio.filename}")

        content = await audio.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")

        # Create temp directory if needed
        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # Never build a path from the client-supplied filename: it may contain
        # "../" (path traversal) or be None, and concurrent uploads sharing a
        # name would overwrite each other. Keep only the extension, in case the
        # audio decoder relies on it to detect the format.
        suffix = Path(audio.filename or "").suffix
        fd, temp_file = tempfile.mkstemp(suffix=suffix, dir=temp_dir)
        with os.fdopen(fd, "wb") as f:
            f.write(content)
        logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
        # NOTE(review): this synchronous call runs inside an async endpoint and
        # blocks the event loop while it runs; consider a threadpool once the
        # detector is confirmed safe to call concurrently.
        result = detector.analyze_audio(temp_file, transcript)

        # Use .get() (and `or ''`) so that a result missing these keys cannot
        # turn a successful analysis into a 500 raised by a log statement.
        actual = result.get('actual_transcript') or ''
        target = result.get('target_transcript') or ''
        logger.info(f"✅ Analysis complete: severity={result.get('severity')}, mismatch={result.get('mismatch_percentage')}%")
        logger.info(f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        # Always remove the temporary copy of the upload.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")


# Mount the Gradio UI at the root path LAST. A mount at "/" matches every
# request path, and Starlette dispatches to the first matching route, so API
# routes registered after the mount would be shadowed by the UI.
gr.mount_gradio_app(app, gradio_app, path="/")
if __name__ == "__main__":
    # Script entry point: serve the combined FastAPI + Gradio app with uvicorn.
    import uvicorn

    logger.info("🚀 Starting SLAQ Stutter Detector API...")
    server_options = dict(host="0.0.0.0", port=7860, log_level="info")
    uvicorn.run(app, **server_options)