import json
import os
import re
import traceback
from datetime import datetime

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class AphasiaClassifier:
    """Text -> CHA -> JSON -> BioBERT pipeline producing an aphasia severity label."""

    # Speaker tag written in front of every participant line of the CHA text.
    PARTICIPANT_TAG = "*PAR:"
    MODEL_VERSION = "BioBERT-Aphasia-v1.0"

    def __init__(self, model_path="./pytorch_model.bin",
                 tokenizer_name="dmis-lab/biobert-base-cased-v1.1"):
        """
        Initialize the Aphasia Classifier.

        Args:
            model_path: Path to the fine-tuned pytorch_model.bin, or to the
                directory that holds it together with its config.json.
            tokenizer_name: Name of the tokenizer to use (BioBERT).

        If the model cannot be loaded, ``self.model`` is set to ``None`` and
        ``classify_text`` returns clearly-flagged placeholder results.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # from_pretrained() expects the directory containing config.json and the
        # weights, so derive it from model_path instead of hard-coding "./".
        if os.path.isdir(model_path):
            model_dir = model_path
        else:
            model_dir = os.path.dirname(model_path) or "."

        try:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, local_files_only=True
            )
            self.model.to(self.device)
            self.model.eval()
        except Exception as exc:  # deliberate fallback to placeholder mode
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate, and report the cause instead of swallowing it.
            print(f"Warning: Could not load model from {model_dir!r} ({exc}). "
                  "Using placeholder structure.")
            self.model = None

        # Class id -> display label. NOTE(review): assumes the fine-tuned head
        # was trained with this id order; confirm against the training config.
        self.severity_labels = {
            0: "Normal",
            1: "Mild Aphasia",
            2: "Moderate Aphasia",
            3: "Severe Aphasia",
        }

    def _label(self, class_id):
        """Return the display label for class_id, tolerating ids not in severity_labels."""
        return self.severity_labels.get(class_id, f"Class {class_id}")

    @staticmethod
    def _utterance_stats(utterances):
        """Return (number of utterances, mean words per utterance or 0 if none)."""
        word_counts = [len(u["utterance"].split()) for u in utterances]
        average = sum(word_counts) / len(word_counts) if word_counts else 0
        return len(word_counts), average

    def preprocess_to_cha(self, text_input):
        """
        Convert raw text into minimal CHA-style lines.

        Each non-blank input line becomes one ``*PAR:<TAB>utterance`` line.

        Args:
            text_input: Raw text input from the user.

        Returns:
            The CHA-formatted text (lines joined with '\\n').
        """
        # splitlines() also handles '\r\n' and bare '\r' line endings.
        return "\n".join(
            f"{self.PARTICIPANT_TAG}\t{line.strip()}"
            for line in text_input.strip().splitlines()
            if line.strip()
        )

    def cha_to_json(self, cha_text):
        """
        Convert CHA-style text into a JSON-serialisable dict.

        Args:
            cha_text: Text in the format produced by ``preprocess_to_cha``.

        Returns:
            Dict with "session_info" and an "utterances" list. Each utterance
            has speaker, utterance text and an ISO timestamp; the timestamp is
            the processing time, not a recording time.
        """
        utterances = []
        for line in cha_text.split("\n"):
            if not line.startswith(self.PARTICIPANT_TAG):
                continue
            # Strip only the leading tag; str.replace would also remove any
            # occurrence of "*PAR:" inside the utterance itself.
            content = line[len(self.PARTICIPANT_TAG):].strip()
            if content:
                utterances.append({
                    "speaker": "PAR",
                    "utterance": content,
                    "timestamp": datetime.now().isoformat(),
                })

        return {
            "session_info": {
                "date": datetime.now().strftime("%Y-%m-%d"),
                "participant": "PAR",
            },
            "utterances": utterances,
        }

    def classify_text(self, json_data):
        """
        Classify the processed text using the fine-tuned BioBERT model.

        Args:
            json_data: Dict as produced by ``cha_to_json``.

        Returns:
            Dict with prediction, confidence, severity_score, analysis,
            timestamp and model_version. Real predictions also include
            "class_probabilities". Placeholder results (model not loaded)
            carry ``"is_placeholder": True``.
        """
        utterances = json_data["utterances"]
        total_utterances, avg_length = self._utterance_stats(utterances)

        if self.model is None:
            # Placeholder values only. severity_score matches the label id of
            # "Mild Aphasia" in severity_labels (it previously said 2).
            return {
                "prediction": "Mild Aphasia",
                "confidence": 0.85,
                "severity_score": 1,
                "is_placeholder": True,
                "analysis": {
                    "total_utterances": total_utterances,
                    "avg_utterance_length": avg_length,
                    "linguistic_features": {
                        "word_finding_difficulties": 0.3,
                        "syntactic_complexity": 0.6,
                        "semantic_appropriateness": 0.8,
                    },
                },
                "timestamp": datetime.now().isoformat(),
                "model_version": self.MODEL_VERSION,
            }

        # All utterances are classified together as one (truncated) sequence.
        combined_text = " ".join(u["utterance"] for u in utterances)
        inputs = self.tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Single input sequence -> take row 0 of the (1, num_labels) output.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()
        predicted_class = int(torch.argmax(probabilities).item())
        confidence = float(probabilities[predicted_class].item())

        return {
            "prediction": self._label(predicted_class),
            "confidence": confidence,
            "severity_score": predicted_class,
            # Enumerate every output so a model with a different label count
            # is neither truncated nor a KeyError.
            "class_probabilities": {
                self._label(class_id): float(prob)
                for class_id, prob in enumerate(probabilities.tolist())
            },
            "analysis": {
                "total_utterances": total_utterances,
                "total_words": len(combined_text.split()),
                "avg_utterance_length": avg_length,
            },
            "timestamp": datetime.now().isoformat(),
            "model_version": self.MODEL_VERSION,
        }

    def process_pipeline(self, text_input):
        """
        Run text -> CHA -> JSON -> classification -> display formatting.

        Args:
            text_input: Raw text input.

        Returns:
            tuple: (cha_formatted, json_data as indented JSON string,
            classification results as indented JSON string, Markdown output)
        """
        cha_formatted = self.preprocess_to_cha(text_input)
        json_data = self.cha_to_json(cha_formatted)
        classification_results = self.classify_text(json_data)
        formatted_output = self.format_results(classification_results)

        return (
            cha_formatted,
            json.dumps(json_data, indent=2),
            json.dumps(classification_results, indent=2),
            formatted_output,
        )

    def format_results(self, results):
        """
        Render a classification result dict as Markdown for display.

        Args:
            results: Dict as returned by ``classify_text``.

        Returns:
            A Markdown string.
        """
        analysis = results["analysis"]
        max_score = max(self.severity_labels)  # highest severity id (3 by default)

        lines = ["", "# Aphasia Classification Results", ""]
        if results.get("is_placeholder"):
            lines += [
                "> ⚠️ **Model not loaded** – the values below are placeholders, "
                "not a real prediction.",
                "",
            ]
        lines += [
            f"## 🔍 **Prediction**: {results['prediction']}",
            f"## 📊 **Confidence**: {results['confidence']:.2%}",
            f"## 📈 **Severity Score**: {results['severity_score']}/{max_score}",
            "",
            "### Detailed Analysis:",
            f"- **Total Utterances**: {analysis['total_utterances']}",
            f"- **Total Words**: {analysis.get('total_words', 'N/A')}",
            f"- **Average Utterance Length**: {analysis['avg_utterance_length']:.1f} words",
            "",
        ]

        probabilities = results.get("class_probabilities")
        if probabilities:
            lines.append("### Class Probabilities:")
            for class_name, prob in probabilities.items():
                bar = "█" * int(prob * 20)  # simple 20-character progress bar
                lines.append(f"- **{class_name}**: {prob:.2%} {bar}")
            lines.append("")

        lines.append(f"*Analysis completed at: {results['timestamp']}*")
        lines.append(f"*Model: {results['model_version']}*")
        return "\n".join(lines)


# Initialize the classifier once at import time (loads tokenizer and model).
classifier = AphasiaClassifier()


def process_text(input_text):
    """
    Gradio callback: run the full pipeline on input_text.

    Returns the four output values (CHA text, JSON data, raw classification
    JSON, Markdown results). Errors are reported in the UI, not raised.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text to analyze.", "", "", ""

    try:
        return classifier.process_pipeline(input_text)
    except Exception as e:  # UI boundary: surface the error instead of crashing
        traceback.print_exc()  # keep the full traceback in the server log
        error_msg = f"Error processing text: {str(e)}"
        return error_msg, "", "", error_msg


# Define the Gradio interface
with gr.Blocks(title="Aphasia Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 Aphasia Classification System

    This application uses a fine-tuned BioBERT model to classify speech patterns and identify potential aphasia severity levels.

    **Pipeline**: Text Input → CHA Format → JSON Structure → BioBERT Classification → Results
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="📝 Speech Input",
                placeholder=(
                    "Enter the patient's speech sample here...\n"
                    "Example: 'The boy is... uh... the boy is climbing the tree. "
                    "No, wait. The tree... the boy goes up.'"
                ),
                lines=8,
                max_lines=20,
            )
            classify_btn = gr.Button("🔍 Analyze Speech", variant="primary", size="lg")

            gr.Markdown("""
            ### 💡 Tips:
            - Enter natural speech samples
            - Include hesitations, repetitions, and corrections as they occur
            - Multiple sentences provide better analysis
            - The model analyzes linguistic patterns and fluency
            """)

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("📊 Results"):
                    formatted_output = gr.Markdown(
                        label="Analysis Results",
                        value="Enter text and click 'Analyze Speech' to see results here.",
                    )
                with gr.TabItem("📄 CHA Format"):
                    cha_output = gr.Textbox(
                        label="CHA Formatted Output",
                        lines=6,
                        interactive=False,
                    )
                with gr.TabItem("🔧 JSON Data"):
                    json_output = gr.Textbox(
                        label="Structured JSON Data",
                        lines=8,
                        interactive=False,
                    )
                with gr.TabItem("⚙️ Raw Classification"):
                    classification_output = gr.Textbox(
                        label="Raw Classification Results",
                        lines=10,
                        interactive=False,
                    )

    # Connect the button to the processing function; output order matches
    # the tuple returned by process_text.
    classify_btn.click(
        fn=process_text,
        inputs=[input_text],
        outputs=[cha_output, json_output, classification_output, formatted_output],
    )

    # Example inputs
    gr.Examples(
        examples=[
            ["The boy is... uh... the boy is climbing the tree. No, wait. The tree... the boy goes up."],
            ["I want to... to go to the store. Buy some... what do you call it... bread. Yes, bread and milk."],
            ["The cat sat on the mat. It was a sunny day and the birds were singing in the trees."],
            ["Doctor, I feel... I feel not good. My head... it hurts here. Since yesterday."],
        ],
        inputs=[input_text],
    )

    gr.Markdown("""
    ---
    ### ⚠️ **Disclaimer**:
    This tool is for research and educational purposes only. It should not be used as a substitute for professional medical diagnosis or treatment. Always consult with qualified healthcare professionals for medical advice.

    ### 🔧 **Technical Details**:
    - **Model**: Fine-tuned BioBERT (dmis-lab/biobert-base-cased-v1.1)
    - **Input**: Natural language speech samples
    - **Output**: Severity classification (Normal, Mild, Moderate, Severe)
    - **Features**: CHA formatting, JSON structuring, confidence scores
    """)


# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Set to True if you want a public link
        debug=True,
    )