# Source: zlaqa-version-c-ai-enginee / ui/gradio_interface.py
# Uploaded by anfastech — commit 1cd6149:
# "New: Phoneme-level speech pathology diagnosis MVP with real-time streaming"
"""
Gradio Interface for Speech Pathology Diagnosis
This module provides a user-friendly web interface for speech pathology analysis
using Gradio. Supports both file upload and microphone input with structured
output display.
"""
import logging
import time
import tempfile
import os
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import numpy as np
import gradio as gr
from diagnosis.ai_engine.model_loader import get_inference_pipeline
from api.routes import get_phoneme_mapper, get_error_mapper
from models.error_taxonomy import ErrorType, SeverityLevel
from config import GradioConfig, default_gradio_config
logger = logging.getLogger(__name__)
# Global inference pipeline instance
_inference_pipeline = None
def get_inference_pipeline_instance():
    """Return the process-wide inference pipeline, creating it on first use.

    The pipeline is expensive to construct, so it is cached in the module
    global ``_inference_pipeline`` and shared by all Gradio callbacks.

    Raises:
        Exception: re-raised unchanged if the pipeline fails to load.
    """
    global _inference_pipeline
    if _inference_pipeline is not None:
        return _inference_pipeline
    try:
        _inference_pipeline = get_inference_pipeline()
        logger.info("βœ… Inference pipeline loaded for Gradio interface")
    except Exception as exc:
        logger.error(f"❌ Failed to load inference pipeline: {exc}", exc_info=True)
        raise
    return _inference_pipeline
def format_articulation_issues(articulation_scores: list) -> str:
    """
    Format articulation issues from prediction results.

    Counts how often each known articulation class appears in the per-frame
    predictions and reports the non-normal classes as percentages of the
    total frame count. Frames with an unrecognized ``class_name`` still
    count toward the total but are not reported individually.

    Args:
        articulation_scores: List of per-frame articulation predictions,
            each a dict with an optional "class_name" key (defaults to
            "normal" when missing).

    Returns:
        Formatted string describing articulation issues, or a success
        message when only normal articulation was observed.
    """
    if not articulation_scores:
        return "No articulation data available"
    # Count occurrences of each known articulation type.
    articulation_counts = {
        "normal": 0,
        "substitution": 0,
        "omission": 0,
        "distortion": 0
    }
    for score in articulation_scores:
        class_name = score.get("class_name", "normal")
        if class_name in articulation_counts:
            articulation_counts[class_name] += 1
    # Non-empty list guaranteed by the early return above, so this is > 0.
    # (The original code re-checked total_frames == 0 here; that branch was
    # unreachable and has been removed.)
    total_frames = len(articulation_scores)
    # Report each non-normal class as a percentage of all frames.
    issues = []
    for art_type, count in articulation_counts.items():
        if art_type != "normal" and count > 0:
            percentage = (count / total_frames) * 100
            issues.append(f"{art_type.capitalize()}: {percentage:.1f}% ({count}/{total_frames} frames)")
    if not issues:
        return "βœ… No articulation issues detected - Normal articulation"
    return "⚠️ Articulation Issues Detected:\n" + "\n".join(f" β€’ {issue}" for issue in issues)
def _render_error_report(error_table_rows: list) -> str:
    """Render detected phoneme errors as an HTML summary plus a detail table.

    Args:
        error_table_rows: List of row dicts with keys "phoneme", "time",
            "error_type", "wrong_sound", "severity", "severity_color",
            "therapy" (as built by ``analyze_speech``).

    Returns:
        HTML string: a success banner when the list is empty, otherwise a
        per-sound summary followed by a table of up to the first 20 errors.
    """
    if not error_table_rows:
        return """
        <div style='background-color: #d4edda; border: 2px solid #28a745; border-radius: 8px; padding: 20px; text-align: center;'>
            <h3 style='color: #155724; margin-top: 0;'>βœ… No Errors Detected</h3>
            <p style='color: #155724; font-size: 16px;'>
                All sounds/phonemes were produced correctly!<br/>
                <span style='font-size: 14px; color: #666;'>Great job! πŸŽ‰</span>
            </p>
        </div>
        """
    # Group errors by phoneme so the summary shows which sounds have issues.
    phoneme_errors = {}
    for row in error_table_rows:
        phoneme = row['phoneme']
        if phoneme not in phoneme_errors:
            phoneme_errors[phoneme] = {
                'count': 0,
                'types': set(),
                'severity': 'low',
                'examples': []
            }
        phoneme_errors[phoneme]['count'] += 1
        phoneme_errors[phoneme]['types'].add(row['error_type'])
        # Escalate the displayed severity if any occurrence is medium/high.
        if row['severity'] in ['high', 'medium']:
            phoneme_errors[phoneme]['severity'] = row['severity']
        if len(phoneme_errors[phoneme]['examples']) < 2:
            phoneme_errors[phoneme]['examples'].append(row)
    problematic_sounds = sorted(phoneme_errors.keys())
    summary_html = f"""
    <div style='background-color: #fff3cd; border: 2px solid #ffc107; border-radius: 8px; padding: 15px; margin-bottom: 20px;'>
        <h3 style='color: #856404; margin-top: 0;'>⚠️ Problematic Sounds Detected</h3>
        <p style='color: #856404; font-size: 14px; margin-bottom: 10px;'>
            <strong>{len(problematic_sounds)} sound(s) with issues:</strong> {', '.join([f'<strong style="color: red;">/{p}/</strong>' for p in problematic_sounds[:10]])}
            {f'<span style="color: #666;">(+{len(problematic_sounds) - 10} more)</span>' if len(problematic_sounds) > 10 else ''}
        </p>
        <div style='display: flex; flex-wrap: wrap; gap: 10px;'>
    """
    # One card per problematic sound (capped at 10 to keep the page compact).
    for phoneme in problematic_sounds[:10]:
        error_info = phoneme_errors[phoneme]
        severity_color = 'red' if error_info['severity'] == 'high' else 'orange' if error_info['severity'] == 'medium' else '#666'
        summary_html += f"""
        <div style='background-color: white; border: 1px solid {severity_color}; border-radius: 4px; padding: 8px; min-width: 120px;'>
            <strong style='color: {severity_color}; font-size: 18px;'>/{phoneme}/</strong>
            <div style='font-size: 12px; color: #666;'>
                {error_info['count']} error(s)<br/>
                Types: {', '.join(error_info['types'])}
            </div>
        </div>
        """
    summary_html += """
        </div>
    </div>
    """
    # Detailed table: one row per error occurrence, first 20 only.
    report_html = summary_html + """
    <h4 style='color: #333; margin-top: 20px;'>πŸ“‹ Detailed Error Report</h4>
    <table style='width: 100%; border-collapse: collapse; margin: 10px 0; font-size: 13px;'>
        <thead>
            <tr style='background-color: #f0f0f0;'>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Sound</th>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Time</th>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Error Type</th>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Wrong Sound</th>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Severity</th>
                <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Therapy Recommendation</th>
            </tr>
        </thead>
        <tbody>
    """
    for row in error_table_rows[:20]:  # Limit to first 20 errors
        severity_bg = {
            'high': '#ffebee',
            'medium': '#fff3e0',
            'low': '#f3e5f5',
            'none': '#e8f5e9'
        }.get(row['severity'], '#f5f5f5')
        report_html += f"""
        <tr style='background-color: {severity_bg};'>
            <td style='padding: 10px; border: 1px solid #ddd;'>
                <strong style='color: {row['severity_color']}; font-size: 16px;'>/{row['phoneme']}/</strong>
            </td>
            <td style='padding: 10px; border: 1px solid #ddd;'>{row['time']}</td>
            <td style='padding: 10px; border: 1px solid #ddd;'>
                <span style='background-color: {row['severity_color']}; color: white; padding: 3px 8px; border-radius: 3px; font-size: 11px;'>
                    {row['error_type'].upper()}
                </span>
            </td>
            <td style='padding: 10px; border: 1px solid #ddd;'>
                {f"<strong style='color: red;'>/{row['wrong_sound']}/</strong>" if row['wrong_sound'] != 'N/A' else '<span style="color: #999;">N/A</span>'}
            </td>
            <td style='padding: 10px; border: 1px solid #ddd;'>
                <strong style='color: {row['severity_color']};'>{row['severity'].upper()}</strong>
            </td>
            <td style='padding: 10px; border: 1px solid #ddd; font-size: 12px;'>{row['therapy']}</td>
        </tr>
        """
    report_html += """
        </tbody>
    </table>
    """
    if len(error_table_rows) > 20:
        report_html += f"<p style='color: #666; font-size: 12px; margin-top: 10px;'>πŸ“Š Showing first 20 of <strong>{len(error_table_rows)}</strong> total errors detected</p>"
    return report_html


def analyze_speech(
    audio_input: Optional[Tuple[int, np.ndarray]],
    audio_file: Optional[str],
    expected_text: Optional[str] = None
) -> Tuple[str, str, str, str, str, Dict[str, Any]]:
    """
    Analyze speech audio for fluency and articulation issues.

    Args:
        audio_input: Tuple of (sample_rate, audio_array) from microphone.
        audio_file: Path to uploaded audio file. Takes priority over
            ``audio_input`` when both are provided.
        expected_text: Optional expected transcript; when set (and a phoneme
            mapper is available) phonemes are mapped onto analysis frames for
            per-sound error reporting.

    Returns:
        Six-element tuple of (fluency_html, articulation_text,
        confidence_html, processing_time_html, error_table_html,
        json_output) — one value per Gradio output component.
        (BUGFIX: the old no-input path returned only five values, which
        broke Gradio's output unpacking.)
    """
    start_time = time.time()
    temp_path = None  # set only when mic audio is written to disk; cleaned up in `finally`
    try:
        # --- Resolve the audio source (upload wins over microphone) -------
        audio_path = None
        if audio_file is not None and audio_file != "":
            audio_path = audio_file
            logger.info(f"Processing uploaded file: {audio_path}")
        elif audio_input is not None:
            # Persist microphone samples to a temporary WAV so the batch
            # pipeline can read them from a file path.
            sample_rate, audio_array = audio_input
            logger.info(f"Processing microphone input: {len(audio_array)} samples at {sample_rate}Hz")
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            temp_path = temp_file.name
            temp_file.close()
            try:
                import soundfile as sf
                sf.write(temp_path, audio_array, sample_rate)
                audio_path = temp_path
                logger.info(f"Saved microphone input to: {temp_path}")
            except Exception as e:
                logger.error(f"Failed to save microphone audio: {e}")
                raise ValueError(f"Cannot process microphone audio: {e}")
        else:
            # No input at all: fail fast without loading the model.
            return (
                "<p style='color: red;'>❌ Error: No audio input provided</p>",
                "No audio provided",
                "N/A",
                "N/A",
                "<p style='color: #999;'>No error details available</p>",
                {"error": "No audio input provided"}
            )

        # Loaded after input validation so an empty request never pays the
        # model-load cost.
        pipeline = get_inference_pipeline_instance()

        # --- Run batch prediction -----------------------------------------
        logger.info(f"Running batch prediction on: {audio_path}")
        result = pipeline.predict_batch(audio_path, return_timestamps=True, apply_smoothing=True)

        # Get phoneme and error mappers
        phoneme_mapper = get_phoneme_mapper()
        error_mapper = get_error_mapper()

        # Map phonemes to frames if text provided; fall back to empty labels.
        if expected_text and phoneme_mapper:
            try:
                frame_phonemes = phoneme_mapper.map_text_to_frames(
                    expected_text,
                    num_frames=result.num_frames,
                    audio_duration=result.duration
                )
                logger.info(f"βœ… Mapped {len(frame_phonemes)} phonemes to frames")
            except Exception as e:
                logger.warning(f"⚠️ Phoneme mapping failed: {e}")
                frame_phonemes = [''] * result.num_frames
        else:
            frame_phonemes = [''] * result.num_frames

        # --- Per-frame error mapping --------------------------------------
        errors = []  # (frame_index, time, error_detail) for non-normal frames
        error_table_rows = []
        for i, frame_pred in enumerate(result.frame_predictions):
            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
            # 8-class scheme: articulation classes 0-3, shifted to 4-7 when
            # the frame is labeled as a stutter.
            class_id = frame_pred.articulation_class
            if frame_pred.fluency_label == 'stutter':
                class_id += 4
            if error_mapper:
                try:
                    error_detail = error_mapper.map_classifier_output(
                        class_id=class_id,
                        confidence=frame_pred.confidence,
                        phoneme=phoneme if phoneme else 'unknown',
                        fluency_label=frame_pred.fluency_label
                    )
                    if error_detail.error_type != ErrorType.NORMAL:
                        errors.append((i, frame_pred.time, error_detail))
                        severity_level = error_mapper.get_severity_level(error_detail.severity)
                        severity_color = {
                            SeverityLevel.NONE: "green",
                            SeverityLevel.LOW: "orange",
                            SeverityLevel.MEDIUM: "orange",
                            SeverityLevel.HIGH: "red"
                        }.get(severity_level, "gray")
                        error_table_rows.append({
                            "phoneme": error_detail.phoneme,
                            "time": f"{frame_pred.time:.2f}s",
                            "error_type": error_detail.error_type.value,
                            "wrong_sound": error_detail.wrong_sound or "N/A",
                            "severity": severity_level.value,
                            "severity_color": severity_color,
                            "therapy": error_detail.therapy[:80] + "..." if len(error_detail.therapy) > 80 else error_detail.therapy
                        })
                except Exception as e:
                    logger.warning(f"Error mapping failed for frame {i}: {e}")

        processing_time_ms = (time.time() - start_time) * 1000

        # --- Aggregate metrics (PhoneLevelResult format) ------------------
        aggregate = result.aggregate
        mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
        # The aggregate "fluency_score" is a stutter probability; invert it
        # to get a fluency percentage.
        fluency_percentage = (1.0 - mean_fluency_stutter) * 100
        fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
        fluent_frames_percentage = (fluent_frames / result.num_frames * 100) if result.num_frames > 0 else 0.0

        # Color-code the fluency score: >=80 green, >=60 orange, else red.
        if fluency_percentage >= 80:
            fluency_color, fluency_emoji = "green", "βœ…"
        elif fluency_percentage >= 60:
            fluency_color, fluency_emoji = "orange", "⚠️"
        else:
            fluency_color, fluency_emoji = "red", "❌"
        fluency_html = f"""
        <div style='text-align: center; padding: 20px;'>
            <h2 style='color: {fluency_color}; font-size: 48px; margin: 10px 0;'>
                {fluency_emoji} {fluency_percentage:.1f}%
            </h2>
            <p style='color: #666; font-size: 14px;'>
                Mean Fluency Score<br/>
                Fluent Frames: {fluent_frames_percentage:.1f}%
            </p>
        </div>
        """

        # --- Articulation summary: dominant class + per-label breakdown ---
        articulation_class = aggregate.get("articulation_class", 0)
        articulation_label = aggregate.get("articulation_label", "normal")
        articulation_text = f"**Dominant Class:** {articulation_label.capitalize()}\n\n"
        articulation_text += f"**Frame Breakdown:**\n"
        class_counts = {}
        for fp in result.frame_predictions:
            label = fp.articulation_label
            class_counts[label] = class_counts.get(label, 0) + 1
        for label, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
            percentage = (count / result.num_frames * 100) if result.num_frames > 0 else 0.0
            articulation_text += f"- {label.capitalize()}: {count} frames ({percentage:.1f}%)\n"

        # --- Confidence and timing panels ---------------------------------
        avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
        confidence_percentage = avg_confidence * 100
        confidence_html = f"""
        <div style='text-align: center; padding: 10px;'>
            <h3 style='color: #2196F3; font-size: 32px; margin: 5px 0;'>
                {confidence_percentage:.1f}%
            </h3>
            <p style='color: #666; font-size: 12px;'>Overall Confidence</p>
        </div>
        """
        processing_time_html = f"""
        <div style='text-align: center; padding: 10px;'>
            <p style='color: #666; font-size: 14px;'>
                ⏱️ Processing Time: <strong>{processing_time_ms:.0f}ms</strong>
            </p>
            <p style='color: #999; font-size: 12px;'>
                Analyzed {result.num_frames} frames ({result.duration:.2f}s audio)
            </p>
        </div>
        """

        # --- Error report and JSON payload --------------------------------
        error_table_html = _render_error_report(error_table_rows)
        json_output = {
            "status": "success",
            "fluency_metrics": {
                "mean_fluency": fluency_percentage / 100.0,
                "fluency_percentage": fluency_percentage,
                "fluent_frames_ratio": fluent_frames / result.num_frames if result.num_frames > 0 else 0.0,
                "fluent_frames_percentage": fluent_frames_percentage,
                "stutter_probability": mean_fluency_stutter
            },
            "articulation_results": {
                "total_frames": result.num_frames,
                "dominant_class": articulation_class,
                "dominant_label": articulation_label,
                "class_distribution": class_counts
            },
            "confidence": avg_confidence,
            "confidence_percentage": confidence_percentage,
            "processing_time_ms": processing_time_ms,
            "error_count": len(errors),
            "errors": [
                {
                    "phoneme": err[2].phoneme,
                    "time": err[1],
                    "error_type": err[2].error_type.value,
                    "wrong_sound": err[2].wrong_sound,
                    "severity": error_mapper.get_severity_level(err[2].severity).value if error_mapper else "unknown",
                    "therapy": err[2].therapy
                }
                for err in errors[:20]
            ] if errors else [],
            "frame_predictions": [
                {
                    "time": fp.time,
                    "fluency_prob": fp.fluency_prob,
                    "fluency_label": fp.fluency_label,
                    "articulation_class": fp.articulation_class,
                    "articulation_label": fp.articulation_label,
                    "confidence": fp.confidence,
                    "phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
                }
                for i, fp in enumerate(result.frame_predictions[:20])  # First 20 frames for preview
            ]
        }
        logger.info(f"βœ… Analysis complete: fluency={fluency_percentage:.1f}%, "
                    f"confidence={confidence_percentage:.1f}%, "
                    f"time={processing_time_ms:.0f}ms")
        return (
            fluency_html,
            articulation_text,
            confidence_html,
            processing_time_html,
            error_table_html,
            json_output
        )
    except Exception as e:
        logger.error(f"❌ Analysis failed: {e}", exc_info=True)
        error_html = f"<p style='color: red;'>❌ Error: {str(e)}</p>"
        error_table_html = "<p style='color: #999;'>No error details available</p>"
        return (
            error_html,
            f"Error: {str(e)}",
            "N/A",
            "N/A",
            error_table_html,
            {"error": str(e), "status": "error"}
        )
    finally:
        # BUGFIX: delete only the temp file *we* created. The old code keyed
        # cleanup on `audio_input is not None`, which deleted the uploaded
        # file when both inputs were set, and never ran on exceptions.
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
                logger.debug(f"Cleaned up temporary file: {temp_path}")
            except Exception as e:
                logger.warning(f"Could not clean up temp file: {e}")
def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.Blocks:
    """
    Create the Gradio interface for speech pathology diagnosis.

    Builds a two-column layout: audio inputs (file upload, microphone, and an
    optional expected-text box) on the left; analysis outputs (fluency score,
    articulation summary, confidence, timing, error table, raw JSON) on the
    right. Wires the Analyze button to ``analyze_speech``.

    Args:
        gradio_config: Gradio configuration. Uses default if None.

    Returns:
        Gradio Blocks interface
    """
    config = gradio_config or default_gradio_config
    logger.info(f"Creating Gradio interface: {config.title}")
    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .output-box {
        border: 2px solid #e0e0e0;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
        background-color: #f9f9f9;
    }
    """
    with gr.Blocks(title=config.title, css=custom_css, theme=config.theme) as interface:
        # Page header with feature list.
        gr.Markdown(f"""
        # 🎀 {config.title}
        {config.description}
        **Features:**
        - πŸ“ Upload audio files (WAV, MP3, FLAC, M4A)
        - πŸŽ™οΈ Record audio directly from microphone
        - πŸ“Š Real-time fluency and articulation analysis
        - ⚑ Phone-level analysis (20ms frames)
        """)
        with gr.Row():
            # Left column: audio inputs and controls.
            with gr.Column(scale=1):
                gr.Markdown("### πŸ“₯ Audio Input")
                # type="filepath" — analyze_speech expects a path string for uploads.
                audio_file = gr.Audio(
                    type="filepath",
                    label="Upload Audio File",
                    sources=["upload"],
                    format="wav"
                )
                # type="numpy" — mic input arrives as (sample_rate, ndarray).
                audio_mic = gr.Audio(
                    type="numpy",
                    label="Record from Microphone",
                    sources=["microphone"],
                    format="wav"
                )
                expected_text_input = gr.Textbox(
                    label="Expected Text (Optional)",
                    placeholder="Enter the expected text/transcript for phoneme mapping",
                    lines=2,
                    info="Provide the expected text to enable phoneme-level error detection"
                )
                analyze_btn = gr.Button(
                    "πŸ” Analyze Speech",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("""
                **Instructions:**
                1. Upload an audio file OR record from microphone
                2. Click "Analyze Speech" button
                3. View results below
                """)
            # Right column: analysis output components.
            with gr.Column(scale=1):
                gr.Markdown("### πŸ“Š Analysis Results")
                fluency_output = gr.HTML(
                    label="Fluency Score",
                    elem_classes=["output-box"]
                )
                articulation_output = gr.Textbox(
                    label="Articulation Issues",
                    lines=8,
                    interactive=False,
                    elem_classes=["output-box"]
                )
                with gr.Row():
                    confidence_output = gr.HTML(
                        label="Confidence",
                        elem_classes=["output-box"]
                    )
                    processing_time_output = gr.HTML(
                        label="Processing Info",
                        elem_classes=["output-box"]
                    )
                error_table_output = gr.HTML(
                    label="Error Details",
                    elem_classes=["output-box"]
                )
                json_output = gr.JSON(
                    label="Detailed Results (JSON)",
                    elem_classes=["output-box"]
                )
        # Set up event handlers.
        # NOTE: six outputs here must match the number of values returned by
        # analyze_speech.
        analyze_btn.click(
            fn=analyze_speech,
            inputs=[audio_mic, audio_file, expected_text_input],
            outputs=[
                fluency_output,
                articulation_output,
                confidence_output,
                processing_time_output,
                error_table_output,
                json_output
            ]
        )
        # Examples if provided
        if config.examples:
            gr.Examples(
                examples=config.examples,
                inputs=audio_file,
                label="Example Audio Files"
            )
        gr.Markdown("""
        ---
        **About:**
        - Uses Wav2Vec2-XLSR-53 for speech analysis
        - Phone-level granularity (20ms frames)
        - Detects fluency issues and articulation problems
        - Processing time: <200ms per chunk
        """)
    return interface
def launch_gradio_interface(
    gradio_config: Optional[GradioConfig] = None,
    share: Optional[bool] = None
) -> None:
    """Build and launch the Gradio interface as a standalone server.

    Args:
        gradio_config: Gradio configuration; the module default is used
            when omitted.
        share: Whether to create a public share link. When given, this
            overrides the value from the configuration.
    """
    cfg = gradio_config or default_gradio_config
    effective_share = cfg.share if share is None else share
    app = create_gradio_interface(cfg)
    logger.info(f"πŸš€ Launching Gradio interface on port {cfg.port}")
    # Bind to all interfaces so the app is reachable inside containers.
    app.launch(server_name="0.0.0.0", server_port=cfg.port, share=effective_share)
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.INFO)
launch_gradio_interface()