# app.py — Gradio front-end for the aphasia classification pipeline
import gradio as gr
import json
import os
import tempfile
import logging
import traceback
from pathlib import Path
# Import the pipeline stages (audio conversion, transcription, JSON export,
# model inference). These are project-local modules shipped alongside app.py.
try:
    from utils_audio import convert_to_wav
    from to_cha import to_cha_from_wav
    from cha_json import cha_to_json_file
    from output import predict_from_chajson
except ImportError as e:
    logging.error(f"Import error: {e}")
    # Fallback imports or error handling
    # NOTE(review): execution continues after a failed import, so any later
    # call to these names would raise NameError — confirm this is intended.

# Set up logging: module-level logger used by every handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
MODEL_DIR = "./adaptive_aphasia_model"  # Path to your trained model
# Extensions accepted by the upload widget; compared lowercase in the handlers.
SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
def run_complete_pipeline(audio_file_path: str) -> dict:
    """
    Complete pipeline: Audio -> WAV -> CHA -> JSON -> Model Prediction.

    Args:
        audio_file_path: Path to the uploaded audio file (any supported format).

    Returns:
        dict with keys:
            success (bool): True if every stage completed.
            results (dict, success only): output of predict_from_chajson.
            error (str, failure only): exception text.
            message (str): human-readable status line.
    """
    # Track every intermediate file as soon as it exists so cleanup happens
    # even when a later stage raises (the original leaked them on failure).
    temp_paths: list = []
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: normalize to 16 kHz mono WAV.
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = convert_to_wav(audio_file_path, sr=16000, mono=True)
        temp_paths.append(wav_path)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: transcribe/align into a CHAT (.cha) file using Batchalign.
        logger.info("Step 2: Generating CHA file...")
        cha_path = to_cha_from_wav(wav_path, lang="eng")
        temp_paths.append(cha_path)
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: convert the CHA transcript to the model's JSON input format.
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, json_data = cha_to_json_file(cha_path)
        temp_paths.append(chajson_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: run the trained classifier over the JSON transcript.
        logger.info("Step 4: Running aphasia classification...")
        results = predict_from_chajson(MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully"
        }
    except Exception as e:
        logger.error(f"Pipeline error: {str(e)}")
        logger.error(traceback.format_exc())
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}"
        }
    finally:
        # Best-effort removal of all intermediates created so far,
        # on both the success and the failure path.
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError as cleanup_error:
                logger.warning(f"Cleanup error: {cleanup_error}")
def _format_audio_success(first_pred: dict, results: dict) -> tuple:
    """Build the five display strings for a successful audio prediction."""
    # Primary classification line.
    predicted_class = first_pred["prediction"]["predicted_class"]
    confidence = first_pred["prediction"]["confidence_percentage"]
    class_description = first_pred["class_description"]["name"]
    main_result = f"🧠 **Predicted Aphasia Type:** {predicted_class}\n"
    main_result += f"📊 **Confidence:** {confidence}\n"
    main_result += f"📋 **Description:** {class_description}"

    # Detailed analysis: bullet list of clinical features plus description.
    features = first_pred["class_description"].get("features", [])
    detailed_analysis = "**Key Features:**\n"
    for feature in features:
        detailed_analysis += f"• {feature}\n"
    detailed_analysis += "\n**Clinical Description:**\n"
    detailed_analysis += first_pred["class_description"].get("description", "No description available")

    # Additional metrics: severity (0-3 scale per the "/3" label) and fluency.
    additional_info = first_pred["additional_predictions"]
    severity_level = additional_info["predicted_severity_level"]
    fluency_score = additional_info["fluency_score"]
    fluency_rating = additional_info["fluency_rating"]
    additional_metrics = f"**Severity Level:** {severity_level}/3\n"
    additional_metrics += f"**Fluency Score:** {fluency_score:.3f} ({fluency_rating})\n"

    # NOTE(review): takes the first three entries in insertion order — assumes
    # probability_distribution is already sorted by probability; confirm in output.py.
    prob_dist = first_pred["probability_distribution"]
    top_3 = list(prob_dist.items())[:3]
    probability_breakdown = "**Top 3 Classifications:**\n"
    for i, (aphasia_type, info) in enumerate(top_3, 1):
        probability_breakdown += f"{i}. {aphasia_type}: {info['percentage']}\n"

    # Run-level summary statistics.
    summary = results.get("summary", {})
    summary_text = "**Processing Summary:**\n"
    summary_text += f"• Total sentences analyzed: {results.get('total_sentences', 'N/A')}\n"
    summary_text += f"• Average confidence: {summary.get('average_confidence', 'N/A')}\n"
    summary_text += f"• Average fluency: {summary.get('average_fluency_score', 'N/A')}\n"

    return (main_result, detailed_analysis, additional_metrics,
            probability_breakdown, summary_text)


def process_audio_input(audio_file):
    """
    Gradio handler: validate the upload, run the pipeline, format results.

    Args:
        audio_file: Filepath string (or Gradio file object) from the upload
            widget; None when nothing was uploaded.

    Returns:
        A 5-tuple of strings for the output textboxes:
        (primary classification, detailed analysis, additional metrics,
         probability breakdown, processing summary). Error states fill the
         first one or two slots and leave the rest empty.
    """
    try:
        if audio_file is None:
            return ("❌ Error: No audio file uploaded", "", "", "", "")

        # Gradio may hand us either a plain path or a file-like object.
        if isinstance(audio_file, str):
            file_path = audio_file
        else:
            file_path = audio_file.name if hasattr(audio_file, 'name') else str(audio_file)

        # Reject extensions the pipeline cannot convert.
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return (
                f"❌ Error: Unsupported file format {file_ext}",
                f"Supported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}",
                "", "", ""
            )

        # Run the complete pipeline (WAV -> CHA -> JSON -> prediction).
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return (
                f"❌ Pipeline Error: {pipeline_result['message']}",
                pipeline_result.get('error', ''),
                "", "", ""
            )

        results = pipeline_result["results"]
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            # Per-sentence classification errors are reported inline.
            if "error" in first_pred:
                return (f"❌ Classification Error: {first_pred['error']}", "", "", "", "")
            return _format_audio_success(first_pred, results)
        else:
            return (
                "❌ No predictions generated",
                "The audio file may not contain analyzable speech",
                "", "", ""
            )
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        logger.error(traceback.format_exc())
        return (
            f"❌ Processing Error: {str(e)}",
            "Please check the logs for more details",
            "", "", ""
        )
def process_text_input(text_input):
    """
    Gradio handler: classify aphasia type directly from transcription text.

    Builds a minimal single-sentence JSON structure compatible with the
    pipeline's chajson format — POS ids, grammar ids, and word durations are
    zero-filled placeholders since there is no audio — writes it to a temp
    file, and runs the classifier on it.

    Args:
        text_input: Free-form transcription text from the textbox.

    Returns:
        A 5-tuple of display strings (main result, clinical description,
        severity, fluency, status). Error states fill only the first slot.
    """
    try:
        if not text_input or not text_input.strip():
            return ("❌ Error: Please enter some text for analysis", "", "", "", "")

        # Tokenize once (the original re-split the input four times).
        tokens = text_input.split()

        # Minimal chajson-compatible structure: one participant utterance.
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * len(tokens),
                        "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                        "word_durations": [0.0] * len(tokens),
                        "utterance_text": text_input
                    }]
                }]
            }],
            "text_all": text_input
        }

        # Save to a temporary file for predict_from_chajson to consume.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(temp_json, f, ensure_ascii=False, indent=2)
            temp_json_path = f.name

        try:
            results = predict_from_chajson(MODEL_DIR, temp_json_path, output_file=None)
        finally:
            # Always remove the temp file — the original leaked it when
            # prediction raised, and swallowed errors with a bare except.
            try:
                os.unlink(temp_json_path)
            except OSError:
                pass

        # Format results (abbreviated version of the audio formatting).
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            return (
                f"🧠 **Predicted:** {predicted_class} ({confidence})",
                first_pred["class_description"]["description"],
                f"Severity: {first_pred['additional_predictions']['predicted_severity_level']}/3",
                f"Fluency: {first_pred['additional_predictions']['fluency_rating']}",
                "Text-based analysis completed"
            )
        else:
            return ("❌ No predictions generated", "", "", "", "")
    except Exception as e:
        logger.error(f"Text processing error: {str(e)}")
        return (f"❌ Error: {str(e)}", "", "", "", "")
# Create Gradio interface
def create_interface():
    """Create the main Gradio interface.

    Returns:
        gr.Blocks: a two-tab app (audio upload / direct text input), each tab
        wired to its processing function with five output textboxes.
    """
    with gr.Blocks(
        title="Advanced Aphasia Classification System",
        theme=gr.themes.Soft(),
        css="""
        .main-header { text-align: center; margin-bottom: 2rem; }
        .upload-section { border: 2px dashed #ccc; padding: 2rem; border-radius: 10px; }
        .results-section { margin-top: 2rem; }
        """
    ) as demo:
        # Header banner shown above both tabs.
        gr.HTML("""
        <div class="main-header">
        <h1>🧠 Advanced Aphasia Classification System</h1>
        <p>Upload audio files (MP3, MP4, WAV) or enter text to analyze speech patterns and classify aphasia types</p>
        </div>
        """)
        with gr.Tabs():
            # ---- Audio Input Tab: upload -> full pipeline -> 5 result boxes ----
            with gr.TabItem("🎵 Audio Analysis", id="audio_tab"):
                gr.Markdown("### Upload Audio File")
                gr.Markdown("Supported formats: MP3, MP4, WAV, M4A, FLAC, OGG")
                with gr.Row():
                    # Left column: upload widget and trigger button.
                    with gr.Column(scale=1):
                        audio_input = gr.File(
                            label="Upload Audio File",
                            file_types=["audio"],
                            type="filepath"
                        )
                        process_audio_btn = gr.Button(
                            "🔍 Analyze Audio",
                            variant="primary",
                            size="lg"
                        )
                        gr.Markdown("**Note:** Processing may take 1-3 minutes depending on audio length")
                    # Right column: the five read-only result boxes, matching
                    # the 5-tuple returned by process_audio_input.
                    with gr.Column(scale=2, visible=True) as audio_results:
                        gr.Markdown("### 📊 Analysis Results")
                        audio_main_result = gr.Textbox(
                            label="🎯 Primary Classification",
                            lines=3,
                            interactive=False
                        )
                        with gr.Row():
                            audio_detailed = gr.Textbox(
                                label="📋 Detailed Analysis",
                                lines=6,
                                interactive=False
                            )
                            audio_metrics = gr.Textbox(
                                label="📈 Additional Metrics",
                                lines=6,
                                interactive=False
                            )
                        with gr.Row():
                            audio_probabilities = gr.Textbox(
                                label="📊 Probability Breakdown",
                                lines=4,
                                interactive=False
                            )
                            audio_summary = gr.Textbox(
                                label="📝 Processing Summary",
                                lines=4,
                                interactive=False
                            )
            # ---- Text Input Tab: fallback path that skips audio processing ----
            with gr.TabItem("📝 Text Analysis", id="text_tab"):
                gr.Markdown("### Direct Text Input")
                gr.Markdown("Enter speech transcription or text for analysis (fallback option)")
                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter speech transcription or text for analysis...",
                            lines=5
                        )
                        process_text_btn = gr.Button(
                            "🔍 Analyze Text",
                            variant="secondary",
                            size="lg"
                        )
                    # Result boxes matching the 5-tuple from process_text_input.
                    with gr.Column() as text_results:
                        gr.Markdown("### 📊 Analysis Results")
                        text_main_result = gr.Textbox(
                            label="🎯 Primary Classification",
                            lines=2,
                            interactive=False
                        )
                        with gr.Row():
                            text_detailed = gr.Textbox(
                                label="📋 Clinical Description",
                                lines=4,
                                interactive=False
                            )
                            text_metrics = gr.Textbox(
                                label="📈 Metrics",
                                lines=4,
                                interactive=False
                            )
                        with gr.Row():
                            text_probabilities = gr.Textbox(
                                label="📊 Assessment",
                                lines=2,
                                interactive=False
                            )
                            text_summary = gr.Textbox(
                                label="📝 Status",
                                lines=2,
                                interactive=False
                            )
        # Event handlers: each button feeds its input widget into the
        # matching processor and fans the 5-tuple out to the textboxes.
        process_audio_btn.click(
            fn=process_audio_input,
            inputs=[audio_input],
            outputs=[
                audio_main_result,
                audio_detailed,
                audio_metrics,
                audio_probabilities,
                audio_summary
            ]
        )
        process_text_btn.click(
            fn=process_text_input,
            inputs=[text_input],
            outputs=[
                text_main_result,
                text_detailed,
                text_metrics,
                text_probabilities,
                text_summary
            ]
        )
        # Footer with usage disclaimer.
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
        <p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia from speech samples.</p>
        <p><em>For research and clinical assessment purposes.</em></p>
        </div>
        """)
    return demo
# Application entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    try:
        logger.info("Starting Aphasia Classification System...")

        # Warn loudly (but still launch) when the trained model is absent,
        # so the UI comes up and shows errors instead of dying silently.
        if not os.path.exists(MODEL_DIR):
            logger.error(f"Model directory not found: {MODEL_DIR}")
            print(f"❌ Error: Model directory not found: {MODEL_DIR}")
            print("Please ensure your trained model is in the correct directory.")

        # Build the interface and serve it on all interfaces, port 7860.
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "show_error": True,
            "share": False,
        }
        app = create_interface()
        app.launch(**launch_options)
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        logger.error(traceback.format_exc())
        print(f"❌ Application startup failed: {e}")