"""
Gradio Interface for Speech Pathology Diagnosis

This module provides a user-friendly web interface for speech pathology analysis
using Gradio. Supports both file upload and microphone input with structured
output display.
"""

import html
import logging
import os
import tempfile
import time
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import gradio as gr

from diagnosis.ai_engine.model_loader import get_inference_pipeline
from api.routes import get_phoneme_mapper, get_error_mapper
from models.error_taxonomy import ErrorType, SeverityLevel
from config import GradioConfig, default_gradio_config

logger = logging.getLogger(__name__)

# Global inference pipeline instance (created lazily by get_inference_pipeline_instance).
_inference_pipeline = None

# Display limits for the results panel.
_MAX_SUMMARY_SOUNDS = 10   # problematic sounds listed in the summary card
_MAX_TABLE_ROWS = 20       # rows in the detailed table and errors in the JSON output
_MAX_PREVIEW_FRAMES = 20   # frame predictions included in the JSON preview
_MAX_THERAPY_CHARS = 80    # therapy text is truncated to this length in the table

# Ordering used to keep the *worst* severity seen for a sound. Keys match the
# severity values used for the table row colours below.
_SEVERITY_RANK = {"none": 0, "low": 1, "medium": 2, "high": 3}

# Row background colour in the detailed error table, keyed by severity value.
_SEVERITY_ROW_BG = {
    "high": "#ffebee",
    "medium": "#fff3e0",
    "low": "#f3e5f5",
    "none": "#e8f5e9",
}

_NO_ERROR_DETAILS_HTML = (
    '<div style="padding: 10px; color: #666;">'
    "<p>No error details available</p>"
    "</div>"
)


def get_inference_pipeline_instance():
    """Get or initialize the inference pipeline singleton.

    Raises:
        Exception: whatever ``get_inference_pipeline`` raises; it is logged and
            re-raised so the caller can report the failure.
    """
    global _inference_pipeline
    if _inference_pipeline is None:
        try:
            _inference_pipeline = get_inference_pipeline()
            logger.info("✅ Inference pipeline loaded for Gradio interface")
        except Exception as e:
            logger.error(f"❌ Failed to load inference pipeline: {e}", exc_info=True)
            raise
    return _inference_pipeline


def format_articulation_issues(articulation_scores: list) -> str:
    """
    Format articulation issues from prediction results.

    Args:
        articulation_scores: List of per-frame prediction dicts. Each dict's
            ``"class_name"`` key (default ``"normal"``) is counted if it is one
            of ``"normal"``, ``"substitution"``, ``"omission"`` or
            ``"distortion"``. Other class names are ignored but still count
            towards the frame total.

    Returns:
        Formatted string describing articulation issues.
    """
    if not articulation_scores:
        return "No articulation data available"

    # Count occurrences of each articulation type
    articulation_counts = {
        "normal": 0,
        "substitution": 0,
        "omission": 0,
        "distortion": 0,
    }
    for score in articulation_scores:
        class_name = score.get("class_name", "normal")
        if class_name in articulation_counts:
            articulation_counts[class_name] += 1

    # Non-empty input is guaranteed above, so total_frames > 0.
    total_frames = len(articulation_scores)
    issues = [
        f"{art_type.capitalize()}: {(count / total_frames) * 100:.1f}% "
        f"({count}/{total_frames} frames)"
        for art_type, count in articulation_counts.items()
        if art_type != "normal" and count > 0
    ]

    if not issues:
        return "✅ No articulation issues detected - Normal articulation"
    return "⚠️ Articulation Issues Detected:\n" + "\n".join(f" • {issue}" for issue in issues)


def _escape(value: Any) -> str:
    """HTML-escape any value before interpolating it into markup."""
    return html.escape(str(value))


def _remove_temp_file(path: str) -> None:
    """Best-effort deletion of a temporary file created by this module."""
    try:
        os.unlink(path)
        logger.debug(f"Cleaned up temporary file: {path}")
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not clean up temp file: {e}")


def _save_microphone_audio(audio_input: Tuple[int, np.ndarray]) -> str:
    """Write microphone audio to a temporary WAV file and return its path.

    The caller owns the returned file and is responsible for deleting it.

    Raises:
        ValueError: if the audio cannot be written. In that case the temporary
            file has already been removed.
    """
    sample_rate, audio_array = audio_input
    logger.info(f"Processing microphone input: {len(audio_array)} samples at {sample_rate}Hz")

    # delete=False + close() so soundfile can reopen the path (needed on Windows).
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_path = temp_file.name
    temp_file.close()

    try:
        # Imported lazily: soundfile is only needed for microphone input.
        import soundfile as sf
        sf.write(temp_path, audio_array, sample_rate)
    except Exception as e:
        logger.error(f"Failed to save microphone audio: {e}")
        _remove_temp_file(temp_path)
        raise ValueError(f"Cannot process microphone audio: {e}") from e

    logger.info(f"Saved microphone input to: {temp_path}")
    return temp_path


def _map_frame_phonemes(phoneme_mapper, expected_text: Optional[str], num_frames: int, duration) -> list:
    """Align the expected text's phonemes to model frames.

    Returns a list of empty strings (one per frame) when no text or mapper is
    available, or when mapping fails.
    """
    blank = [""] * num_frames
    if not (expected_text and expected_text.strip() and phoneme_mapper):
        return blank
    try:
        frame_phonemes = phoneme_mapper.map_text_to_frames(
            expected_text,
            num_frames=num_frames,
            audio_duration=duration,
        )
    except Exception as e:
        logger.warning(f"⚠️ Phoneme mapping failed: {e}")
        return blank
    logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
    return frame_phonemes


def _collect_errors(frame_predictions, frame_phonemes: list, error_mapper) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Map each frame prediction to an error detail and keep the non-normal ones.

    Returns:
        ``(errors, table_rows)``. These are parallel lists: raw error records
        for the JSON output, and display rows (formatted time, truncated
        therapy, severity colour) for the HTML table. A frame whose mapping
        raises is logged and skipped in both lists, so they stay in sync.
    """
    errors: List[Dict[str, Any]] = []
    table_rows: List[Dict[str, Any]] = []
    if not error_mapper:
        return errors, table_rows

    severity_colors = {
        SeverityLevel.NONE: "green",
        SeverityLevel.LOW: "orange",
        SeverityLevel.MEDIUM: "orange",
        SeverityLevel.HIGH: "red",
    }

    for i, frame_pred in enumerate(frame_predictions):
        phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ""

        # 8-class system: articulation classes 0-3, shifted to 4-7 for stuttered frames.
        class_id = frame_pred.articulation_class
        if frame_pred.fluency_label == "stutter":
            class_id += 4

        try:
            detail = error_mapper.map_classifier_output(
                class_id=class_id,
                confidence=frame_pred.confidence,
                phoneme=phoneme if phoneme else "unknown",
                fluency_label=frame_pred.fluency_label,
            )
            if detail.error_type == ErrorType.NORMAL:
                continue

            severity_level = error_mapper.get_severity_level(detail.severity)
            therapy = detail.therapy or ""
            if len(therapy) > _MAX_THERAPY_CHARS:
                short_therapy = therapy[:_MAX_THERAPY_CHARS] + "..."
            else:
                short_therapy = therapy

            error_record = {
                "phoneme": detail.phoneme,
                "time": frame_pred.time,
                "error_type": detail.error_type.value,
                "wrong_sound": detail.wrong_sound,
                "severity": severity_level.value,
                "therapy": detail.therapy,
            }
            table_row = {
                "phoneme": detail.phoneme,
                "time": f"{frame_pred.time:.2f}s",
                "error_type": detail.error_type.value,
                "wrong_sound": detail.wrong_sound or "N/A",
                "severity": severity_level.value,
                "severity_color": severity_colors.get(severity_level, "gray"),
                "therapy": short_therapy,
            }
        except Exception as e:
            logger.warning(f"Error mapping failed for frame {i}: {e}")
            continue

        errors.append(error_record)
        table_rows.append(table_row)

    return errors, table_rows


def _format_error_html(message: str) -> str:
    """Render an error message card (message is HTML-escaped)."""
    return (
        '<div style="color: red; padding: 10px; text-align: center;">'
        f"<p><strong>❌ Error: {_escape(message)}</strong></p>"
        "</div>"
    )


def _format_fluency_html(fluency_percentage: float, fluent_frames_percentage: float) -> str:
    """Render the colour-coded fluency score card."""
    if fluency_percentage >= 80:
        color, emoji = "green", "✅"
    elif fluency_percentage >= 60:
        color, emoji = "orange", "⚠️"
    else:
        color, emoji = "red", "❌"
    return (
        '<div style="text-align: center; padding: 10px;">'
        f'<h2 style="color: {color}; margin: 0;">{emoji} {fluency_percentage:.1f}%</h2>'
        '<p style="margin: 5px 0; color: #666;">Mean Fluency Score</p>'
        f'<p style="margin: 5px 0; font-size: 0.9em;">Fluent Frames: {fluent_frames_percentage:.1f}%</p>'
        "</div>"
    )


def _format_confidence_html(confidence_percentage: float) -> str:
    """Render the overall-confidence card."""
    return (
        '<div style="text-align: center; padding: 10px;">'
        f'<h3 style="margin: 0;">{confidence_percentage:.1f}%</h3>'
        '<p style="margin: 5px 0; color: #666;">Overall Confidence</p>'
        "</div>"
    )


def _format_processing_html(processing_time_ms: float, num_frames: int, duration) -> str:
    """Render the processing-time card."""
    return (
        '<div style="text-align: center; padding: 10px;">'
        f'<p style="margin: 0;">⏱️ Processing Time: {processing_time_ms:.0f}ms</p>'
        '<p style="margin: 5px 0; color: #666; font-size: 0.9em;">'
        f"Analyzed {num_frames} frames ({duration:.2f}s audio)</p>"
        "</div>"
    )


def _format_sound_summary_html(table_rows: List[Dict[str, Any]]) -> str:
    """Render a per-sound summary: error count, error types and worst severity."""
    phoneme_errors: Dict[Any, Dict[str, Any]] = {}
    for row in table_rows:
        info = phoneme_errors.setdefault(
            row["phoneme"], {"count": 0, "types": set(), "severity": "low"}
        )
        info["count"] += 1
        info["types"].add(row["error_type"])
        # Keep the worst severity seen for this sound (never downgrade).
        if _SEVERITY_RANK.get(row["severity"], 0) > _SEVERITY_RANK[info["severity"]]:
            info["severity"] = row["severity"]

    sounds = sorted(phoneme_errors)
    shown = sounds[:_MAX_SUMMARY_SOUNDS]
    listing = ", ".join(f"/{_escape(p)}/" for p in shown)
    if len(sounds) > len(shown):
        listing += f" (+{len(sounds) - len(shown)} more)"

    parts = [
        '<div style="padding: 10px; border: 2px solid orange; border-radius: 8px; margin-bottom: 10px;">',
        '<h3 style="margin-top: 0;">⚠️ Problematic Sounds Detected</h3>',
        f"<p><strong>{len(sounds)} sound(s) with issues:</strong> {listing}</p>",
    ]
    for phoneme in shown:
        info = phoneme_errors[phoneme]
        color = {"high": "red", "medium": "orange"}.get(info["severity"], "#666")
        types = ", ".join(sorted(info["types"]))
        parts.append(
            f'<div style="display: inline-block; margin: 5px; padding: 8px; '
            f'border: 2px solid {color}; border-radius: 6px;">'
            f'<strong style="color: {color};">/{_escape(phoneme)}/</strong><br>'
            f"{info['count']} error(s)<br>"
            f"<small>Types: {_escape(types)}</small>"
            "</div>"
        )
    parts.append("</div>")
    return "".join(parts)


def _format_error_table_html(table_rows: List[Dict[str, Any]]) -> str:
    """Render the sound summary plus a detailed table of up to _MAX_TABLE_ROWS errors."""
    if not table_rows:
        return (
            '<div style="padding: 10px; text-align: center; color: green;">'
            "<h3>✅ No Errors Detected</h3>"
            "<p>All sounds/phonemes were produced correctly!</p>"
            "<p>Great job! 🎉</p>"
            "</div>"
        )

    headers = ["Sound", "Time", "Error Type", "Wrong Sound", "Severity", "Therapy Recommendation"]
    cell = 'style="padding: 6px; border: 1px solid #ddd;"'
    parts = [
        _format_sound_summary_html(table_rows),
        '<div style="margin-top: 10px;">',
        "<h3>📋 Detailed Error Report</h3>",
        '<table style="width: 100%; border-collapse: collapse;">',
        "<thead><tr>" + "".join(f"<th {cell}>{h}</th>" for h in headers) + "</tr></thead>",
        "<tbody>",
    ]
    for row in table_rows[:_MAX_TABLE_ROWS]:
        bg = _SEVERITY_ROW_BG.get(row["severity"], "#f5f5f5")
        wrong = row["wrong_sound"]
        wrong_cell = f"/{_escape(wrong)}/" if wrong != "N/A" else "N/A"
        parts.append(
            f'<tr style="background-color: {bg};">'
            f"<td {cell}><strong>/{_escape(row['phoneme'])}/</strong></td>"
            f"<td {cell}>{_escape(row['time'])}</td>"
            f"<td {cell}>{_escape(row['error_type'].upper())}</td>"
            f"<td {cell}>{wrong_cell}</td>"
            f'<td {cell}><span style="color: {row["severity_color"]}; font-weight: bold;">'
            f"{_escape(row['severity'].upper())}</span></td>"
            f"<td {cell}>{_escape(row['therapy'])}</td>"
            "</tr>"
        )
    parts.append("</tbody></table>")
    if len(table_rows) > _MAX_TABLE_ROWS:
        parts.append(
            '<p style="color: #666; font-size: 0.9em;">'
            f"📊 Showing first {_MAX_TABLE_ROWS} of {len(table_rows)} total errors detected</p>"
        )
    parts.append("</div>")
    return "".join(parts)


def analyze_speech(
    audio_input: Optional[Tuple[int, np.ndarray]],
    audio_file: Optional[str],
    expected_text: Optional[str] = None,
) -> Tuple[str, str, str, str, str, Dict[str, Any]]:
    """
    Analyze speech audio for fluency and articulation issues.

    An uploaded file takes precedence over microphone input when both are given.

    Args:
        audio_input: Tuple of (sample_rate, audio_array) from the microphone.
        audio_file: Path to an uploaded audio file.
        expected_text: Optional transcript. When given, it enables
            phoneme-level error mapping.

    Returns:
        Tuple of (fluency_score_html, articulation_issues, confidence_html,
        processing_time_html, error_table_html, json_output). On failure every
        slot is still filled with an error placeholder, so the six Gradio
        outputs stay consistent.
    """
    start_time = time.time()
    temp_path: Optional[str] = None

    try:
        # Determine audio source (uploaded file wins over microphone).
        if audio_file:
            audio_path = audio_file
            logger.info(f"Processing uploaded file: {audio_path}")
        elif audio_input is not None:
            temp_path = _save_microphone_audio(audio_input)
            audio_path = temp_path
        else:
            message = "No audio input provided"
            return (
                _format_error_html(message),
                "No audio provided",
                "N/A",
                "N/A",
                _NO_ERROR_DETAILS_HTML,
                {"error": message, "status": "error"},
            )

        pipeline = get_inference_pipeline_instance()
        logger.info(f"Running batch prediction on: {audio_path}")
        result = pipeline.predict_batch(audio_path, return_timestamps=True, apply_smoothing=True)

        num_frames = result.num_frames
        frame_predictions = result.frame_predictions

        phoneme_mapper = get_phoneme_mapper()
        error_mapper = get_error_mapper()
        frame_phonemes = _map_frame_phonemes(phoneme_mapper, expected_text, num_frames, result.duration)
        errors, error_table_rows = _collect_errors(frame_predictions, frame_phonemes, error_mapper)

        processing_time_ms = (time.time() - start_time) * 1000

        # Fluency metrics. The aggregate "fluency_score" is treated as a mean
        # stutter probability, so it is inverted into a fluency percentage.
        aggregate = result.aggregate
        mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
        fluency_percentage = (1.0 - mean_fluency_stutter) * 100
        fluent_frames = sum(1 for fp in frame_predictions if fp.fluency_label == "normal")
        fluent_frames_ratio = fluent_frames / num_frames if num_frames > 0 else 0.0
        fluent_frames_percentage = (fluent_frames / num_frames * 100) if num_frames > 0 else 0.0

        # Articulation breakdown, most frequent label first.
        articulation_class = aggregate.get("articulation_class", 0)
        articulation_label = aggregate.get("articulation_label", "normal")
        class_counts = dict(Counter(fp.articulation_label for fp in frame_predictions))
        articulation_lines = [
            f"**Dominant Class:** {articulation_label.capitalize()}",
            "",
            "**Frame Breakdown:**",
        ]
        for label, count in sorted(class_counts.items(), key=lambda kv: kv[1], reverse=True):
            percentage = (count / num_frames * 100) if num_frames > 0 else 0.0
            articulation_lines.append(f"- {label.capitalize()}: {count} frames ({percentage:.1f}%)")
        articulation_text = "\n".join(articulation_lines) + "\n"

        avg_confidence = (
            sum(fp.confidence for fp in frame_predictions) / num_frames if num_frames > 0 else 0.0
        )
        confidence_percentage = avg_confidence * 100

        json_output = {
            "status": "success",
            "fluency_metrics": {
                "mean_fluency": fluency_percentage / 100.0,
                "fluency_percentage": fluency_percentage,
                "fluent_frames_ratio": fluent_frames_ratio,
                "fluent_frames_percentage": fluent_frames_percentage,
                "stutter_probability": mean_fluency_stutter,
            },
            "articulation_results": {
                "total_frames": num_frames,
                "dominant_class": articulation_class,
                "dominant_label": articulation_label,
                "class_distribution": class_counts,
            },
            "confidence": avg_confidence,
            "confidence_percentage": confidence_percentage,
            "processing_time_ms": processing_time_ms,
            "error_count": len(errors),
            "errors": errors[:_MAX_TABLE_ROWS],
            "frame_predictions": [
                {
                    "time": fp.time,
                    "fluency_prob": fp.fluency_prob,
                    "fluency_label": fp.fluency_label,
                    "articulation_class": fp.articulation_class,
                    "articulation_label": fp.articulation_label,
                    "confidence": fp.confidence,
                    "phoneme": frame_phonemes[i] if i < len(frame_phonemes) else "",
                }
                for i, fp in enumerate(frame_predictions[:_MAX_PREVIEW_FRAMES])
            ],
        }

        logger.info(
            f"✅ Analysis complete: fluency={fluency_percentage:.1f}%, "
            f"confidence={confidence_percentage:.1f}%, "
            f"time={processing_time_ms:.0f}ms"
        )

        return (
            _format_fluency_html(fluency_percentage, fluent_frames_percentage),
            articulation_text,
            _format_confidence_html(confidence_percentage),
            _format_processing_html(processing_time_ms, num_frames, result.duration),
            _format_error_table_html(error_table_rows),
            json_output,
        )

    except Exception as e:
        logger.error(f"❌ Analysis failed: {e}", exc_info=True)
        return (
            _format_error_html(str(e)),
            f"Error: {str(e)}",
            "N/A",
            "N/A",
            _NO_ERROR_DETAILS_HTML,
            {"error": str(e), "status": "error"},
        )

    finally:
        # Only delete files this function created. Never delete the user's upload.
        if temp_path is not None:
            _remove_temp_file(temp_path)


def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.Blocks:
    """
    Create the Gradio interface for speech pathology diagnosis.

    Args:
        gradio_config: Gradio configuration. Uses default if None.

    Returns:
        Gradio Blocks interface
    """
    config = gradio_config or default_gradio_config
    logger.info(f"Creating Gradio interface: {config.title}")

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .output-box {
        border: 2px solid #e0e0e0;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
        background-color: #f9f9f9;
    }
    """

    with gr.Blocks(title=config.title, css=custom_css, theme=config.theme) as interface:
        gr.Markdown(
            f"# 🎤 {config.title}\n\n"
            f"{config.description}\n\n"
            "**Features:**\n"
            "- 📁 Upload audio files (WAV, MP3, FLAC, M4A)\n"
            "- 🎙️ Record audio directly from microphone\n"
            "- 📊 Real-time fluency and articulation analysis\n"
            "- ⚡ Phone-level analysis (20ms frames)\n"
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Audio Input")
                audio_file = gr.Audio(
                    type="filepath",
                    label="Upload Audio File",
                    sources=["upload"],
                    format="wav",
                )
                audio_mic = gr.Audio(
                    type="numpy",
                    label="Record from Microphone",
                    sources=["microphone"],
                    format="wav",
                )
                expected_text_input = gr.Textbox(
                    label="Expected Text (Optional)",
                    placeholder="Enter the expected text/transcript for phoneme mapping",
                    lines=2,
                    info="Provide the expected text to enable phoneme-level error detection",
                )
                analyze_btn = gr.Button(
                    "🔍 Analyze Speech",
                    variant="primary",
                    size="lg",
                )
                gr.Markdown(
                    "**Instructions:**\n"
                    "1. Upload an audio file OR record from microphone\n"
                    '2. Click "Analyze Speech" button\n'
                    "3. View results below\n"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 📊 Analysis Results")
                fluency_output = gr.HTML(
                    label="Fluency Score",
                    elem_classes=["output-box"],
                )
                articulation_output = gr.Textbox(
                    label="Articulation Issues",
                    lines=8,
                    interactive=False,
                    elem_classes=["output-box"],
                )
                with gr.Row():
                    confidence_output = gr.HTML(
                        label="Confidence",
                        elem_classes=["output-box"],
                    )
                    processing_time_output = gr.HTML(
                        label="Processing Info",
                        elem_classes=["output-box"],
                    )

        error_table_output = gr.HTML(
            label="Error Details",
            elem_classes=["output-box"],
        )
        json_output = gr.JSON(
            label="Detailed Results (JSON)",
            elem_classes=["output-box"],
        )

        # Input order must match analyze_speech(audio_input, audio_file, expected_text);
        # output order must match its 6-tuple return value.
        analyze_btn.click(
            fn=analyze_speech,
            inputs=[audio_mic, audio_file, expected_text_input],
            outputs=[
                fluency_output,
                articulation_output,
                confidence_output,
                processing_time_output,
                error_table_output,
                json_output,
            ],
        )

        # Examples if provided
        if config.examples:
            gr.Examples(
                examples=config.examples,
                inputs=audio_file,
                label="Example Audio Files",
            )

        gr.Markdown(
            "---\n"
            "**About:**\n"
            "- Uses Wav2Vec2-XLSR-53 for speech analysis\n"
            "- Phone-level granularity (20ms frames)\n"
            "- Detects fluency issues and articulation problems\n"
            "- Processing time: <200ms per chunk\n"
        )

    return interface


def launch_gradio_interface(
    gradio_config: Optional[GradioConfig] = None,
    share: Optional[bool] = None,
) -> None:
    """
    Launch the Gradio interface standalone.

    Args:
        gradio_config: Gradio configuration. Uses default if None.
        share: Whether to create a public link (overrides ``config.share``).
    """
    config = gradio_config or default_gradio_config
    share = share if share is not None else config.share
    interface = create_gradio_interface(config)
    logger.info(f"🚀 Launching Gradio interface on port {config.port}")
    interface.launch(
        server_name="0.0.0.0",
        server_port=config.port,
        share=share,
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    launch_gradio_interface()