Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import json | |
| import numpy as np | |
| import math | |
| from transformers import AutoTokenizer, AutoModel | |
| from typing import Dict, List, Optional | |
| import logging | |
# Set up logging
# Module-level logger shared by the model classes and the Gradio callbacks;
# INFO level so model-loading progress is visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| # Recreate the model classes (simplified versions) | |
class StablePositionalEncoding(nn.Module):
    """Positional encoding blending a fixed sinusoidal table with a learnable one.

    The classic sin/cos table is stored as a (1, max_len, d_model) buffer and a
    small learnable offset table is kept as a parameter; their sum, scaled by
    0.1, is added to the input embeddings.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)   # even dims: sine
        table[:, 1::2] = torch.cos(positions * freqs)   # odd dims: cosine
        self.register_buffer('pe', table.unsqueeze(0))
        # Tiny init keeps the learnable component a gentle perturbation at first.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        # x: (batch, seq, d_model); slice both tables to the current length.
        n_pos = x.size(1)
        fixed = self.pe[:, :n_pos, :].to(x.device)
        learned = self.learnable_pe[:n_pos, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
| class StableMultiHeadAttention(nn.Module): | |
| def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3): | |
| super().__init__() | |
| self.num_heads = num_heads | |
| self.feature_dim = feature_dim | |
| self.head_dim = feature_dim // num_heads | |
| assert feature_dim % num_heads == 0 | |
| self.query = nn.Linear(feature_dim, feature_dim) | |
| self.key = nn.Linear(feature_dim, feature_dim) | |
| self.value = nn.Linear(feature_dim, feature_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| self.output_proj = nn.Linear(feature_dim, feature_dim) | |
| self.layer_norm = nn.LayerNorm(feature_dim) | |
| def forward(self, x, mask=None): | |
| batch_size, seq_len, _ = x.size() | |
| Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) | |
| K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) | |
| V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) | |
| scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) | |
| if mask is not None: | |
| if mask.dim() == 2: | |
| mask = mask.unsqueeze(1).unsqueeze(1) | |
| scores.masked_fill_(mask == 0, -1e9) | |
| attn_weights = F.softmax(scores, dim=-1) | |
| attn_weights = self.dropout(attn_weights) | |
| context = torch.matmul(attn_weights, V) | |
| context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim) | |
| output = self.output_proj(context) | |
| return self.layer_norm(output + x) | |
class StableLinguisticFeatureExtractor(nn.Module):
    """Project word-level linguistic signals into one pooled vector per sequence.

    Four streams — POS tags, grammar indicators, word durations, and prosody
    features — are each projected separately, concatenated, compressed to half
    the concatenated width, then mean-pooled over valid (mask == 1) positions.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Simplified version - just return zeros if features not available
        # POS ids -> embeddings, contextualised with self-attention over the sequence.
        self.pos_embedding = nn.Embedding(config.get('pos_vocab_size', 150), config.get('pos_emb_dim', 64), padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.get('pos_emb_dim', 64), num_heads=4)
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.get('grammar_dim', 3), config.get('grammar_hidden_dim', 64)),
            nn.Tanh(),
            nn.LayerNorm(config.get('grammar_hidden_dim', 64)),
            nn.Dropout(config.get('dropout_rate', 0.3) * 0.3)
        )
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.get('duration_hidden_dim', 128)),
            nn.Tanh(),
            nn.LayerNorm(config.get('duration_hidden_dim', 128))
        )
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.get('prosody_dim', 32), config.get('prosody_dim', 32)),
            nn.ReLU(),
            nn.LayerNorm(config.get('prosody_dim', 32))
        )
        total_feature_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                             config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32))
        # Compress the concatenation to half its width; downstream consumers
        # rely on this // 2 output size when building their fusion layers.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.get('dropout_rate', 0.3))
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return (batch, total_feature_dim // 2) pooled linguistic features.

        Args:
            pos_ids: (batch, seq) long tensor of POS tag ids.
            grammar_ids: (batch, seq, grammar_dim) grammar indicators (cast to float).
            durations: (batch, seq) per-word durations (cast to float).
            prosody_features: (batch, seq, prosody_dim) prosody values.
            attention_mask: (batch, seq), 1 for valid positions, 0 for padding.
        """
        # Clamp ids defensively so out-of-vocabulary tags cannot crash the embedding.
        pos_ids_clamped = pos_ids.clamp(0, self.config.get('pos_vocab_size', 150) - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)
        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_proj = self.prosody_projection(prosody_features.float())
        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_proj
        ], dim=-1)
        fused_features = self.feature_fusion(combined_features)
        # Masked mean-pool. Clamping the denominator fixes a divide-by-zero:
        # an all-padding row previously produced NaNs via 0/0; it now yields zeros.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        denom = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / denom
        return pooled_features
class StableAphasiaClassifier(nn.Module):
    """BERT-based multi-task classifier for aphasia type.

    Combines an attention-pooled BERT representation with pooled linguistic
    features, fuses them, and feeds the result through a configurable MLP
    classifier plus two auxiliary heads (severity over 4 levels, fluency as a
    0-1 score).

    NOTE(review): submodule attribute names here are part of the saved
    state-dict keys loaded by `load_model` — do not rename them.
    """

    def __init__(self, config, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        try:
            # Load the base BERT model
            self.bert = AutoModel.from_pretrained(config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'))
            self.bert_config = self.bert.config
            # Freeze embeddings for stability
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False
            self.positional_encoder = StablePositionalEncoding(
                d_model=self.bert_config.hidden_size,
                max_len=config.get('max_length', 512)
            )
            self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
            bert_dim = self.bert_config.hidden_size
            # The extractor compresses its concatenated features to half size,
            # so this // 2 must mirror its fusion layer's output width.
            linguistic_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                              config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32)) // 2
            self.feature_fusion = nn.Sequential(
                nn.Linear(bert_dim + linguistic_dim, bert_dim),
                nn.LayerNorm(bert_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            )
            # Classifier
            self.classifier = self._build_classifier(bert_dim, num_labels, config)
            # Multi-task heads (simplified)
            # NOTE(review): Softmax/Sigmoid inside the heads means they emit
            # probabilities, not logits — any loss must expect that; confirm
            # against the training code.
            self.severity_head = nn.Sequential(
                nn.Linear(bert_dim, 4),
                nn.Softmax(dim=-1)
            )
            self.fluency_head = nn.Sequential(
                nn.Linear(bert_dim, 1),
                nn.Sigmoid()
            )
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def _build_classifier(self, input_dim: int, num_labels: int, config):
        """Stack Linear->LayerNorm->Tanh->Dropout per hidden dim, then a final
        Linear projecting to `num_labels` raw logits."""
        layers = []
        current_dim = input_dim
        classifier_hidden_dims = config.get('classifier_hidden_dims', [512, 256])
        for hidden_dim in classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            ])
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Fuse BERT and linguistic features; return a dict of head outputs.

        `labels` and extra kwargs are accepted for trainer compatibility but
        are not used here (no loss is computed in this simplified version).
        Returns {"logits", "severity_pred", "fluency_pred"}.
        """
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state
        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)
        # Handle missing linguistic features
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.get('prosody_dim', 32),
                    device=input_ids.device
                )
            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # Create dummy linguistic features
            # (zeros with the same width the extractor would have produced)
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.get('pos_emb_dim', 64) + self.config.get('grammar_hidden_dim', 64) +
                 self.config.get('duration_hidden_dim', 128) + self.config.get('prosody_dim', 32)) // 2,
                device=input_ids.device
            )
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)
        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Content-weighted pooling: softmax over per-token feature sums,
        re-masked and renormalised (with 1e-9 guarding the division), then
        used to average the token vectors into one vector per sequence."""
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled
# Load configuration and model
def load_model():
    """Load config.json, the tokenizer, and model weights from the working dir.

    Returns:
        (model, tokenizer, id2label) — the model is in eval mode; id2label maps
        stringified class ids to human-readable labels.

    Raises:
        Exception: re-raised after logging if config/tokenizer/model setup
        fails. A missing or unreadable weights file is tolerated: the model
        then keeps its randomly initialized weights.
    """
    try:
        # Load configuration
        with open("config.json", "r") as f:
            config = json.load(f)
        logger.info(f"Loaded config: {config}")
        # Initialize tokenizer
        model_name = config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Special tokens used by the training pipeline; the embedding matrix is
        # resized below so their ids are valid.
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
        # Initialize model
        num_labels = config.get('num_labels', 9)
        model_config = config.get('model_config', {})
        model = StableAphasiaClassifier(model_config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        # Load model weights
        try:
            # weights_only=True disables arbitrary pickle code execution when
            # deserializing the checkpoint — safe for plain state dicts
            # (requires torch >= 1.13; a failure lands in the except below).
            state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict)
            logger.info("Successfully loaded model weights")
        except Exception as e:
            logger.error(f"Error loading model weights: {e}")
            logger.info("Using randomly initialized weights")
        model.eval()
        # Get label mapping
        id2label = config.get('id2label', {})
        return model, tokenizer, id2label
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
# Initialize model (with error handling)
# Loaded once at import time so every Gradio callback reuses the same
# instances; on failure the app still starts and predict_aphasia reports
# the error to the user instead of crashing the Space.
try:
    model, tokenizer, id2label = load_model()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    model, tokenizer, id2label = None, None, {}
def predict_aphasia(text):
    """Predict aphasia type from a transcript string.

    Returns a 4-tuple of display strings for the Gradio result fields:
    (prediction, confidence, severity, fluency). All four elements are
    strings on every code path so the output widgets render consistently.
    """
    try:
        if model is None or tokenizer is None:
            return "Error: Model not loaded properly. Please check the logs.", "N/A", "N/A", "N/A"
        if not text or not text.strip():
            return "Please enter some text for analysis.", "N/A", "N/A", "N/A"
        # Tokenize input
        inputs = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Create dummy linguistic features (since we don't have them from raw text)
        batch_size, seq_len = inputs["input_ids"].size()
        dummy_pos = torch.zeros(batch_size, seq_len, dtype=torch.long)
        dummy_grammar = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
        dummy_durations = torch.zeros(batch_size, seq_len, dtype=torch.float)
        dummy_prosody = torch.zeros(batch_size, seq_len, 32, dtype=torch.float)
        # Make prediction
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                word_pos_ids=dummy_pos,
                word_grammar_ids=dummy_grammar,
                word_durations=dummy_durations,
                prosody_features=dummy_prosody
            )
        # Process outputs
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1)
        predicted_class_id = torch.argmax(probs, dim=1).item()
        confidence = torch.max(probs, dim=1)[0].item()
        # Get predicted label
        predicted_label = id2label.get(str(predicted_class_id), f"Class_{predicted_class_id}")
        # Get additional predictions
        severity = torch.argmax(outputs["severity_pred"], dim=1).item()
        fluency = outputs["fluency_pred"].item()
        # Format results
        result = f"Predicted Aphasia Type: {predicted_label}"
        confidence_str = f"Confidence: {confidence:.2%}"
        severity_str = f"Severity Level: {severity}/3"
        fluency_str = f"Fluency Score: {fluency:.3f}"
        # Fix: the formatted confidence_str was previously computed but
        # discarded — the raw float was returned instead.
        return result, confidence_str, severity_str, fluency_str
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Prediction error: %s", e)
        return f"Error during prediction: {str(e)}", "N/A", "N/A", "N/A"
# Create Gradio interface
def create_interface():
    """Create Gradio interface.

    Builds a two-column Blocks layout — input textbox and buttons on the left,
    four result fields on the right — wires the button events to
    predict_aphasia, and adds example transcripts. Returns the (unlaunched)
    gr.Blocks app.
    """
    with gr.Blocks(title="Aphasia Classification System") as demo:
        gr.Markdown("# 🧠 Advanced Aphasia Classification System")
        gr.Markdown("Enter speech or text data to classify aphasia type and analyze linguistic patterns.")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter speech transcription or text for analysis...",
                    lines=5
                )
                submit_btn = gr.Button("Analyze Text", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", lines=2)
                confidence_output = gr.Textbox(label="Confidence Score", lines=1)
                severity_output = gr.Textbox(label="Severity Analysis", lines=1)
                fluency_output = gr.Textbox(label="Fluency Analysis", lines=1)
        # Event handlers
        submit_btn.click(
            fn=predict_aphasia,
            inputs=[text_input],
            outputs=[prediction_output, confidence_output, severity_output, fluency_output]
        )
        # Clearing resets the input plus all four result fields in one shot.
        clear_btn.click(
            fn=lambda: ("", "", "", "", ""),
            inputs=[],
            outputs=[text_input, prediction_output, confidence_output, severity_output, fluency_output]
        )
        # Add examples
        gr.Examples(
            examples=[
                ["The patient... uh... wants to... go home but... cannot... find the words"],
                ["Woman is... is washing dishes and the... the... sink is overflowing with water everywhere"],
                ["Cookie is in the cookie jar on the... on the... what do you call it... the shelf thing"]
            ],
            inputs=text_input
        )
        gr.Markdown("### About")
        gr.Markdown("This system uses a specialized transformer model trained on clinical speech data to classify different types of aphasia.")
    return demo
# Launch the app
# Script entry point: build the Blocks app and serve it on all interfaces
# (0.0.0.0:7860, the standard Hugging Face Spaces port).
if __name__ == "__main__":
    try:
        create_interface().launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
        )
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        print(f"Application startup failed: {e}")