import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import numpy as np
import math
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List, Optional
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Recreate the model classes (simplified versions)
class StablePositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding blended with a small learnable one.

    Both components are scaled by 0.1 before being added to the input so the
    positional signal stays small relative to the BERT hidden states.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        # Standard transformer sinusoidal table: sin on even dims, cos on odd.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer (not a parameter): moves with .to(device) but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))
        # Small learnable correction on top of the fixed encoding.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Add scaled positional information to x of shape (batch, seq, d_model)."""
        seq_len = x.size(1)
        sinusoidal = self.pe[:, :seq_len, :].to(x.device)
        learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (sinusoidal + learnable)


class StableMultiHeadAttention(nn.Module):
    """Multi-head self-attention with a residual connection and post-LayerNorm."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        assert feature_dim % num_heads == 0
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        """Self-attend over x (batch, seq, feature_dim); mask is 1 for valid tokens."""
        batch_size, seq_len, _ = x.size()
        # Project and split into heads: (batch, heads, seq, head_dim).
        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            if mask.dim() == 2:
                # Broadcast (batch, seq) -> (batch, 1, 1, seq) over heads and queries.
                mask = mask.unsqueeze(1).unsqueeze(1)
            scores.masked_fill_(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq, feature_dim).
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)
        output = self.output_proj(context)
        # Residual + post-norm.
        return self.layer_norm(output + x)


class StableLinguisticFeatureExtractor(nn.Module):
    """Fuse POS, grammar, duration and prosody features into one pooled vector.

    Each feature stream is projected/normalized independently, concatenated,
    fused through a bottleneck, then mean-pooled over valid (masked) tokens.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Simplified version - just return zeros if features not available
        self.pos_embedding = nn.Embedding(
            config.get('pos_vocab_size', 150),
            config.get('pos_emb_dim', 64),
            padding_idx=0,
        )
        self.pos_attention = StableMultiHeadAttention(config.get('pos_emb_dim', 64), num_heads=4)
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.get('grammar_dim', 3), config.get('grammar_hidden_dim', 64)),
            nn.Tanh(),
            nn.LayerNorm(config.get('grammar_hidden_dim', 64)),
            nn.Dropout(config.get('dropout_rate', 0.3) * 0.3),
        )
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.get('duration_hidden_dim', 128)),
            nn.Tanh(),
            nn.LayerNorm(config.get('duration_hidden_dim', 128)),
        )
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.get('prosody_dim', 32), config.get('prosody_dim', 32)),
            nn.ReLU(),
            nn.LayerNorm(config.get('prosody_dim', 32)),
        )
        total_feature_dim = (config.get('pos_emb_dim', 64)
                             + config.get('grammar_hidden_dim', 64)
                             + config.get('duration_hidden_dim', 128)
                             + config.get('prosody_dim', 32))
        # Halve the concatenated dimension; downstream code relies on this // 2.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.get('dropout_rate', 0.3)),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return a pooled feature vector of shape (batch, total_feature_dim // 2)."""
        batch_size, seq_len = pos_ids.size()
        # Simple processing - can be expanded later.
        # Clamp guards against out-of-vocabulary POS ids crashing the embedding.
        pos_ids_clamped = pos_ids.clamp(0, self.config.get('pos_vocab_size', 150) - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)
        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_features = self.prosody_projection(prosody_features.float())
        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)
        fused_features = self.feature_fusion(combined_features)
        # Masked mean pooling over the sequence dimension.
        # clamp(min=1e-9) prevents NaN when a sample's mask is all zeros.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled_features = (torch.sum(fused_features * mask_expanded, dim=1)
                           / torch.sum(mask_expanded, dim=1).clamp(min=1e-9))
        return pooled_features


class StableAphasiaClassifier(nn.Module):
    """BERT-based aphasia-type classifier with auxiliary severity/fluency heads.

    Combines attention-pooled BERT representations with (optional) linguistic
    features, then runs a main classification head plus two multi-task heads.
    """

    def __init__(self, config, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        try:
            # Load the base BERT model
            self.bert = AutoModel.from_pretrained(
                config.get('model_name',
                           'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'))
            self.bert_config = self.bert.config
            # Freeze embeddings for stability
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False
            self.positional_encoder = StablePositionalEncoding(
                d_model=self.bert_config.hidden_size,
                max_len=config.get('max_length', 512)
            )
            self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
            bert_dim = self.bert_config.hidden_size
            # Must mirror the // 2 bottleneck in StableLinguisticFeatureExtractor.
            linguistic_dim = (config.get('pos_emb_dim', 64)
                              + config.get('grammar_hidden_dim', 64)
                              + config.get('duration_hidden_dim', 128)
                              + config.get('prosody_dim', 32)) // 2
            self.feature_fusion = nn.Sequential(
                nn.Linear(bert_dim + linguistic_dim, bert_dim),
                nn.LayerNorm(bert_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            )
            # Classifier
            self.classifier = self._build_classifier(bert_dim, num_labels, config)
            # Multi-task heads (simplified)
            self.severity_head = nn.Sequential(
                nn.Linear(bert_dim, 4),
                nn.Softmax(dim=-1)
            )
            self.fluency_head = nn.Sequential(
                nn.Linear(bert_dim, 1),
                nn.Sigmoid()
            )
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def _build_classifier(self, input_dim: int, num_labels: int, config):
        """Build an MLP classifier: [Linear, LayerNorm, Tanh, Dropout] * N + Linear."""
        layers = []
        current_dim = input_dim
        classifier_hidden_dims = config.get('classifier_hidden_dims', [512, 256])
        for hidden_dim in classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            ])
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Return a dict with 'logits', 'severity_pred' and 'fluency_pred'.

        Linguistic features are optional: when any of pos/grammar/duration is
        missing, a zero vector of the expected linguistic dimension is used.
        """
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state
        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)
        # Handle missing linguistic features
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.get('prosody_dim', 32),
                    device=input_ids.device
                )
            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # Create dummy linguistic features
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.get('pos_emb_dim', 64)
                 + self.config.get('grammar_hidden_dim', 64)
                 + self.config.get('duration_hidden_dim', 128)
                 + self.config.get('prosody_dim', 32)) // 2,
                device=input_ids.device
            )
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)
        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Pool (batch, seq, dim) -> (batch, dim) with content-derived weights."""
        # Weights come from a softmax over the per-token feature sums, then are
        # re-masked and renormalized so padding tokens contribute nothing.
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (
            torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled


# Load configuration and model
def load_model():
    """Load config.json, the tokenizer and the model weights.

    Returns (model, tokenizer, id2label). Falls back to randomly initialized
    weights (with a logged error) if pytorch_model.bin cannot be loaded.
    Raises on any other failure (missing config, bad model name, ...).
    """
    try:
        # Load configuration
        with open("config.json", "r") as f:
            config = json.load(f)
        logger.info(f"Loaded config: {config}")
        # Initialize tokenizer
        model_name = config.get('model_name',
                                'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Add special tokens
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
        # Initialize model
        num_labels = config.get('num_labels', 9)
        model_config = config.get('model_config', {})
        model = StableAphasiaClassifier(model_config, num_labels)
        # Resize embeddings to account for the added special tokens.
        model.bert.resize_token_embeddings(len(tokenizer))
        # Load model weights
        try:
            # weights_only=True: a state_dict contains only tensors, and this
            # avoids arbitrary-code execution through the pickle payload.
            state_dict = torch.load("pytorch_model.bin", map_location="cpu",
                                    weights_only=True)
            model.load_state_dict(state_dict)
            logger.info("Successfully loaded model weights")
        except Exception as e:
            logger.error(f"Error loading model weights: {e}")
            logger.info("Using randomly initialized weights")
        model.eval()
        # Get label mapping
        id2label = config.get('id2label', {})
        return model, tokenizer, id2label
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


# Initialize model (with error handling)
try:
    model, tokenizer, id2label = load_model()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    model, tokenizer, id2label = None, None, {}


def predict_aphasia(text):
    """Predict aphasia type from text.

    Returns a 4-tuple of display strings:
    (prediction, confidence, severity, fluency).
    """
    try:
        if model is None or tokenizer is None:
            return "Error: Model not loaded properly. Please check the logs.", "N/A", "N/A", "N/A"
        if not text or not text.strip():
            return "Please enter some text for analysis.", "N/A", "N/A", "N/A"
        # Tokenize input
        inputs = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Create dummy linguistic features (since we don't have them from raw text)
        batch_size, seq_len = inputs["input_ids"].size()
        dummy_pos = torch.zeros(batch_size, seq_len, dtype=torch.long)
        dummy_grammar = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
        dummy_durations = torch.zeros(batch_size, seq_len, dtype=torch.float)
        dummy_prosody = torch.zeros(batch_size, seq_len, 32, dtype=torch.float)
        # Make prediction
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                word_pos_ids=dummy_pos,
                word_grammar_ids=dummy_grammar,
                word_durations=dummy_durations,
                prosody_features=dummy_prosody
            )
        # Process outputs
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1)
        predicted_class_id = torch.argmax(probs, dim=1).item()
        confidence = torch.max(probs, dim=1)[0].item()
        # Get predicted label
        predicted_label = id2label.get(str(predicted_class_id), f"Class_{predicted_class_id}")
        # Get additional predictions
        severity = torch.argmax(outputs["severity_pred"], dim=1).item()
        fluency = outputs["fluency_pred"].item()
        # Format results
        result = f"Predicted Aphasia Type: {predicted_label}"
        confidence_str = f"Confidence: {confidence:.2%}"
        severity_str = f"Severity Level: {severity}/3"
        fluency_str = f"Fluency Score: {fluency:.3f}"
        # Fix: return the formatted confidence string (it was previously
        # computed but discarded, and the raw float was shown in the Textbox).
        return result, confidence_str, severity_str, fluency_str
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return f"Error during prediction: {str(e)}", "N/A", "N/A", "N/A"


# Create Gradio interface
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(title="Aphasia Classification System") as demo:
        gr.Markdown("# 🧠 Advanced Aphasia Classification System")
        gr.Markdown("Enter speech or text data to classify aphasia type and analyze linguistic patterns.")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter speech transcription or text for analysis...",
                    lines=5
                )
                submit_btn = gr.Button("Analyze Text", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", lines=2)
                confidence_output = gr.Textbox(label="Confidence Score", lines=1)
                severity_output = gr.Textbox(label="Severity Analysis", lines=1)
                fluency_output = gr.Textbox(label="Fluency Analysis", lines=1)
        # Event handlers
        submit_btn.click(
            fn=predict_aphasia,
            inputs=[text_input],
            outputs=[prediction_output, confidence_output, severity_output, fluency_output]
        )
        clear_btn.click(
            fn=lambda: ("", "", "", "", ""),
            inputs=[],
            outputs=[text_input, prediction_output, confidence_output,
                     severity_output, fluency_output]
        )
        # Add examples
        gr.Examples(
            examples=[
                ["The patient... uh... wants to... go home but... cannot... find the words"],
                ["Woman is... is washing dishes and the... the... sink is overflowing with water everywhere"],
                ["Cookie is in the cookie jar on the... on the... what do you call it... the shelf thing"]
            ],
            inputs=text_input
        )
        gr.Markdown("### About")
        gr.Markdown("This system uses a specialized transformer model trained on clinical speech data to classify different types of aphasia.")
    return demo


# Launch the app
if __name__ == "__main__":
    try:
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        print(f"Application startup failed: {e}")