# HuggingFace Spaces scrape header (kept as comments so the file stays valid Python):
# Spaces:
# Sleeping
| import gradio as gr | |
| import json | |
| import os | |
| import tempfile | |
| import logging | |
| import traceback | |
| from pathlib import Path | |
# Import your pipeline modules
# NOTE(review): if any of these imports fail, the error is only logged and the
# module continues loading; later calls to the missing names will raise
# NameError at request time — confirm whether startup should abort instead.
try:
    from utils_audio import convert_to_wav
    from to_cha import to_cha_from_wav
    from cha_json import cha_to_json_file
    from output import predict_from_chajson
except ImportError as e:
    logging.error(f"Import error: {e}")
    # Fallback imports or error handling
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
MODEL_DIR = "./adaptive_aphasia_model"  # Path to your trained model
# Extensions accepted by the upload handler (checked case-insensitively)
SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
def run_complete_pipeline(audio_file_path: str) -> dict:
    """Run the full analysis pipeline on a single audio file.

    Steps: convert audio to 16 kHz mono WAV -> generate a CHA transcript
    (Batchalign) -> convert CHA to JSON -> classify with the trained model
    in ``MODEL_DIR``.

    Args:
        audio_file_path: Path to the input audio file.

    Returns:
        dict with ``success`` (bool) and ``message`` (str); on success also
        ``results`` (the model's prediction payload), on failure ``error``
        (the exception text).
    """
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: Convert to WAV (16 kHz mono is what the downstream tools expect)
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = convert_to_wav(audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: Generate CHA file using Batchalign
        logger.info("Step 2: Generating CHA file...")
        cha_path = to_cha_from_wav(wav_path, lang="eng")
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: Convert CHA to JSON
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, json_data = cha_to_json_file(cha_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: Run aphasia classification
        logger.info("Step 4: Running aphasia classification...")
        results = predict_from_chajson(MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        # Best-effort cleanup. Unlink each intermediate file independently so
        # one failure does not skip the remaining files (the original single
        # try block stopped at the first failing unlink and leaked the rest).
        for tmp_path in (wav_path, cha_path, chajson_path):
            try:
                os.unlink(tmp_path)
            except Exception as cleanup_error:
                logger.warning(f"Cleanup error: {cleanup_error}")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully"
        }
    except Exception as e:
        # Catch-all boundary for the UI: the handler turns this into an
        # error message instead of crashing the Gradio worker.
        logger.error(f"Pipeline error: {str(e)}")
        logger.error(traceback.format_exc())
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}"
        }
def process_audio_input(audio_file):
    """Gradio handler: run the audio pipeline and format results for the UI.

    Args:
        audio_file: Filepath string from ``gr.File(type="filepath")``, or a
            file-like object exposing ``.name`` on older Gradio versions.

    Returns:
        A 5-tuple of display strings in the order the UI outputs expect:
        (primary classification, detailed analysis, additional metrics,
        probability breakdown, processing summary). On any error the first
        element carries the error text and the rest are mostly empty.
    """
    try:
        if audio_file is None:
            return ("β Error: No audio file uploaded", "", "", "", "")

        # Resolve the underlying filesystem path (the original code also
        # assigned file_path = audio_file before this check — a dead store).
        if isinstance(audio_file, str):
            file_path = audio_file
        else:
            # Handle Gradio file object
            file_path = audio_file.name if hasattr(audio_file, 'name') else str(audio_file)

        # Reject unsupported formats before doing any expensive work.
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return (
                f"β Error: Unsupported file format {file_ext}",
                f"Supported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}",
                "", "", ""
            )

        # Run the complete pipeline (WAV -> CHA -> JSON -> classification)
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return (
                f"β Pipeline Error: {pipeline_result['message']}",
                pipeline_result.get('error', ''),
                "", "", ""
            )

        results = pipeline_result["results"]

        # Guard clause: nothing analyzable came back from the model.
        if "predictions" not in results or len(results["predictions"]) == 0:
            return (
                "β No predictions generated",
                "The audio file may not contain analyzable speech",
                "", "", ""
            )

        first_pred = results["predictions"][0]
        if "error" in first_pred:
            return (f"β Classification Error: {first_pred['error']}", "", "", "", "")

        # --- Primary classification ---
        predicted_class = first_pred["prediction"]["predicted_class"]
        confidence = first_pred["prediction"]["confidence_percentage"]
        class_description = first_pred["class_description"]["name"]
        main_result = f"π§ **Predicted Aphasia Type:** {predicted_class}\n"
        main_result += f"π **Confidence:** {confidence}\n"
        main_result += f"π **Description:** {class_description}"

        # --- Detailed clinical analysis ---
        features = first_pred["class_description"].get("features", [])
        detailed_analysis = f"**Key Features:**\n"
        for feature in features:
            detailed_analysis += f"β’ {feature}\n"
        detailed_analysis += f"\n**Clinical Description:**\n"
        detailed_analysis += first_pred["class_description"].get("description", "No description available")

        # --- Severity / fluency metrics ---
        additional_info = first_pred["additional_predictions"]
        severity_level = additional_info["predicted_severity_level"]
        fluency_score = additional_info["fluency_score"]
        fluency_rating = additional_info["fluency_rating"]
        additional_metrics = f"**Severity Level:** {severity_level}/3\n"
        additional_metrics += f"**Fluency Score:** {fluency_score:.3f} ({fluency_rating})\n"

        # --- Top-3 probability breakdown ---
        # Assumes the predictor emits the distribution already ordered by
        # probability (dict insertion order) — TODO confirm against output.py.
        prob_dist = first_pred["probability_distribution"]
        top_3 = list(prob_dist.items())[:3]
        probability_breakdown = "**Top 3 Classifications:**\n"
        for i, (aphasia_type, info) in enumerate(top_3, 1):
            probability_breakdown += f"{i}. {aphasia_type}: {info['percentage']}\n"

        # --- Run-level summary statistics ---
        summary = results.get("summary", {})
        summary_text = f"**Processing Summary:**\n"
        summary_text += f"β’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}\n"
        summary_text += f"β’ Average confidence: {summary.get('average_confidence', 'N/A')}\n"
        summary_text += f"β’ Average fluency: {summary.get('average_fluency_score', 'N/A')}\n"

        return (
            main_result,
            detailed_analysis,
            additional_metrics,
            probability_breakdown,
            summary_text
        )
    except Exception as e:
        # Catch-all so the UI shows a message instead of a stack trace.
        logger.error(f"Processing error: {str(e)}")
        logger.error(traceback.format_exc())
        return (
            f"β Processing Error: {str(e)}",
            "Please check the logs for more details",
            "", "", ""
        )
def process_text_input(text_input):
    """Gradio handler: classify raw transcription text (no audio pipeline).

    Builds a minimal CHA-style JSON document around the text, writes it to a
    temporary file, and runs the model on it directly.

    Args:
        text_input: Free-form transcription text from the textbox.

    Returns:
        A 5-tuple of display strings (main result, clinical description,
        severity, fluency, status); on error the first element is the
        error message and the rest are empty.
    """
    try:
        if not text_input or not text_input.strip():
            return ("β Error: Please enter some text for analysis", "", "", "", "")

        # Minimal JSON structure for text-only input. POS / grammar ids and
        # word durations are zero-filled placeholders since there is no audio
        # alignment for plain text. (split() hoisted — it was called 4 times.)
        tokens = text_input.split()
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * len(tokens),
                        "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                        "word_durations": [0.0] * len(tokens),
                        "utterance_text": text_input
                    }]
                }]
            }],
            "text_all": text_input
        }

        # Persist to a temp file because the predictor reads from a path.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(temp_json, f, ensure_ascii=False, indent=2)
            temp_json_path = f.name

        # Run prediction
        results = predict_from_chajson(MODEL_DIR, temp_json_path, output_file=None)

        # Best-effort cleanup; narrowed from a bare except so that only
        # filesystem errors are swallowed, not e.g. KeyboardInterrupt.
        try:
            os.unlink(temp_json_path)
        except OSError:
            pass

        # Format results (similar to audio processing)
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            return (
                f"π§ **Predicted:** {predicted_class} ({confidence})",
                first_pred["class_description"]["description"],
                f"Severity: {first_pred['additional_predictions']['predicted_severity_level']}/3",
                f"Fluency: {first_pred['additional_predictions']['fluency_rating']}",
                "Text-based analysis completed"
            )
        return ("β No predictions generated", "", "", "", "")
    except Exception as e:
        # Catch-all boundary so the UI shows a message instead of crashing.
        logger.error(f"Text processing error: {str(e)}")
        return (f"β Error: {str(e)}", "", "", "", "")
# Create Gradio interface
def create_interface():
    """Create the main Gradio interface"""
    # Layout: two tabs (audio pipeline, text fallback); each tab has an input
    # column and a results column of five read-only textboxes that mirror the
    # 5-tuple returned by the corresponding handler.
    with gr.Blocks(
        title="Advanced Aphasia Classification System",
        theme=gr.themes.Soft(),
        css="""
        .main-header { text-align: center; margin-bottom: 2rem; }
        .upload-section { border: 2px dashed #ccc; padding: 2rem; border-radius: 10px; }
        .results-section { margin-top: 2rem; }
        """
    ) as demo:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>π§ Advanced Aphasia Classification System</h1>
            <p>Upload audio files (MP3, MP4, WAV) or enter text to analyze speech patterns and classify aphasia types</p>
        </div>
        """)
        with gr.Tabs():
            # Audio Input Tab
            with gr.TabItem("π΅ Audio Analysis", id="audio_tab"):
                gr.Markdown("### Upload Audio File")
                gr.Markdown("Supported formats: MP3, MP4, WAV, M4A, FLAC, OGG")
                with gr.Row():
                    with gr.Column(scale=1):
                        audio_input = gr.File(
                            label="Upload Audio File",
                            file_types=["audio"],
                            type="filepath"  # handler receives a path string
                        )
                        process_audio_btn = gr.Button(
                            "π Analyze Audio",
                            variant="primary",
                            size="lg"
                        )
                        gr.Markdown("**Note:** Processing may take 1-3 minutes depending on audio length")
                    # Results section for audio
                    with gr.Column(scale=2, visible=True) as audio_results:
                        gr.Markdown("### π Analysis Results")
                        audio_main_result = gr.Textbox(
                            label="π― Primary Classification",
                            lines=3,
                            interactive=False
                        )
                        with gr.Row():
                            audio_detailed = gr.Textbox(
                                label="π Detailed Analysis",
                                lines=6,
                                interactive=False
                            )
                            audio_metrics = gr.Textbox(
                                label="π Additional Metrics",
                                lines=6,
                                interactive=False
                            )
                        with gr.Row():
                            audio_probabilities = gr.Textbox(
                                label="π Probability Breakdown",
                                lines=4,
                                interactive=False
                            )
                            audio_summary = gr.Textbox(
                                label="π Processing Summary",
                                lines=4,
                                interactive=False
                            )
            # Text Input Tab (Fallback)
            with gr.TabItem("π Text Analysis", id="text_tab"):
                gr.Markdown("### Direct Text Input")
                gr.Markdown("Enter speech transcription or text for analysis (fallback option)")
                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter speech transcription or text for analysis...",
                            lines=5
                        )
                        process_text_btn = gr.Button(
                            "π Analyze Text",
                            variant="secondary",
                            size="lg"
                        )
                    # Results section for text
                    with gr.Column() as text_results:
                        gr.Markdown("### π Analysis Results")
                        text_main_result = gr.Textbox(
                            label="π― Primary Classification",
                            lines=2,
                            interactive=False
                        )
                        with gr.Row():
                            text_detailed = gr.Textbox(
                                label="π Clinical Description",
                                lines=4,
                                interactive=False
                            )
                            text_metrics = gr.Textbox(
                                label="π Metrics",
                                lines=4,
                                interactive=False
                            )
                        with gr.Row():
                            text_probabilities = gr.Textbox(
                                label="π Assessment",
                                lines=2,
                                interactive=False
                            )
                            text_summary = gr.Textbox(
                                label="π Status",
                                lines=2,
                                interactive=False
                            )
        # Event handlers — outputs are positional and must match the
        # handler's returned 5-tuple order exactly.
        process_audio_btn.click(
            fn=process_audio_input,
            inputs=[audio_input],
            outputs=[
                audio_main_result,
                audio_detailed,
                audio_metrics,
                audio_probabilities,
                audio_summary
            ]
        )
        process_text_btn.click(
            fn=process_text_input,
            inputs=[text_input],
            outputs=[
                text_main_result,
                text_detailed,
                text_metrics,
                text_probabilities,
                text_summary
            ]
        )
        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
            <p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia from speech samples.</p>
            <p><em>For research and clinical assessment purposes.</em></p>
        </div>
        """)
    return demo
# Launch the application
if __name__ == "__main__":
    try:
        logger.info("Starting Aphasia Classification System...")
        # Check if model directory exists
        # NOTE(review): a missing model directory only logs/prints a warning
        # here and the app still launches; predictions will then fail at
        # request time — confirm whether startup should abort instead.
        if not os.path.exists(MODEL_DIR):
            logger.error(f"Model directory not found: {MODEL_DIR}")
            print(f"β Error: Model directory not found: {MODEL_DIR}")
            print("Please ensure your trained model is in the correct directory.")
        # Create and launch interface
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",  # bind all interfaces (container/Space deployment)
            server_port=7860,       # conventional Gradio / HF Spaces port
            show_error=True,        # surface handler exceptions in the UI
            share=False
        )
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        logger.error(traceback.format_exc())
        print(f"β Application startup failed: {e}")