| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import json |
| | import numpy as np |
| | import math |
| | from transformers import AutoTokenizer, AutoModel |
| | from typing import Dict, List, Optional |
| | import logging |
| |
|
| | |
# Module-wide logging: INFO level, logger named after this module so records
# can be filtered per-file in a larger deployment.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| | |
class StablePositionalEncoding(nn.Module):
    """Additive positional encoding mixing a fixed sinusoidal table with a
    small learnable per-position offset.

    The combined encoding is scaled by 0.1 before being added to the input,
    keeping the positional signal small relative to the token embeddings.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Standard transformer sinusoid table: sin on even dims, cos on odd.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          (-math.log(10000.0) / d_model))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)

        # Buffer (moves with the module, excluded from gradients) plus a
        # learnable component initialized near zero.
        self.register_buffer('pe', table.unsqueeze(0))
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Add scaled positional information to `x` of shape (batch, seq, d_model)."""
        length = x.size(1)
        fixed = self.pe[:, :length, :].to(x.device)
        learned = self.learnable_pe[:length, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
| |
|
| | class StableMultiHeadAttention(nn.Module): |
| | def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3): |
| | super().__init__() |
| | self.num_heads = num_heads |
| | self.feature_dim = feature_dim |
| | self.head_dim = feature_dim // num_heads |
| | |
| | assert feature_dim % num_heads == 0 |
| | |
| | self.query = nn.Linear(feature_dim, feature_dim) |
| | self.key = nn.Linear(feature_dim, feature_dim) |
| | self.value = nn.Linear(feature_dim, feature_dim) |
| | self.dropout = nn.Dropout(dropout) |
| | self.output_proj = nn.Linear(feature_dim, feature_dim) |
| | self.layer_norm = nn.LayerNorm(feature_dim) |
| | |
| | def forward(self, x, mask=None): |
| | batch_size, seq_len, _ = x.size() |
| | |
| | Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| | K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| | V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| | |
| | scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) |
| | |
| | if mask is not None: |
| | if mask.dim() == 2: |
| | mask = mask.unsqueeze(1).unsqueeze(1) |
| | scores.masked_fill_(mask == 0, -1e9) |
| | |
| | attn_weights = F.softmax(scores, dim=-1) |
| | attn_weights = self.dropout(attn_weights) |
| | |
| | context = torch.matmul(attn_weights, V) |
| | context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim) |
| | |
| | output = self.output_proj(context) |
| | return self.layer_norm(output + x) |
| |
|
class StableLinguisticFeatureExtractor(nn.Module):
    """Encodes word-level linguistic streams (POS tags, grammar features,
    word durations, prosody vectors), fuses them, and masked-mean-pools
    into one feature vector per sequence.

    Config keys (all optional, defaults in parentheses):
        pos_vocab_size (150), pos_emb_dim (64), grammar_dim (3),
        grammar_hidden_dim (64), duration_hidden_dim (128),
        prosody_dim (32), dropout_rate (0.3)

    Output dimension is (pos_emb_dim + grammar_hidden_dim +
    duration_hidden_dim + prosody_dim) // 2.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # POS stream: embedding followed by self-attention over the tag sequence.
        self.pos_embedding = nn.Embedding(config.get('pos_vocab_size', 150), config.get('pos_emb_dim', 64), padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.get('pos_emb_dim', 64), num_heads=4)

        self.grammar_projection = nn.Sequential(
            nn.Linear(config.get('grammar_dim', 3), config.get('grammar_hidden_dim', 64)),
            nn.Tanh(),
            nn.LayerNorm(config.get('grammar_hidden_dim', 64)),
            # Lighter dropout (30% of the base rate) on the small grammar stream.
            nn.Dropout(config.get('dropout_rate', 0.3) * 0.3)
        )

        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.get('duration_hidden_dim', 128)),
            nn.Tanh(),
            nn.LayerNorm(config.get('duration_hidden_dim', 128))
        )

        self.prosody_projection = nn.Sequential(
            nn.Linear(config.get('prosody_dim', 32), config.get('prosody_dim', 32)),
            nn.ReLU(),
            nn.LayerNorm(config.get('prosody_dim', 32))
        )

        total_feature_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                             config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32))
        # Fuse the concatenated streams and halve the dimensionality.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.get('dropout_rate', 0.3))
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return pooled features of shape (batch, total_feature_dim // 2).

        Args:
            pos_ids: (batch, seq) long POS-tag ids.
            grammar_ids: (batch, seq, grammar_dim) grammar indicators (cast to float).
            durations: (batch, seq) per-word durations (cast to float).
            prosody_features: (batch, seq, prosody_dim) prosody vectors.
            attention_mask: (batch, seq) with 1 for real tokens, 0 for padding.
        """
        batch_size, seq_len = pos_ids.size()

        # Clamp out-of-vocabulary ids into range instead of crashing on bad input.
        pos_ids_clamped = pos_ids.clamp(0, self.config.get('pos_vocab_size', 150) - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_features = self.prosody_projection(prosody_features.float())

        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)

        fused_features = self.feature_fusion(combined_features)

        # Masked mean pool. Fix: clamp the denominator so a fully padded row
        # (mask all zeros) yields zeros instead of NaN from a 0/0 division.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        token_counts = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / token_counts

        return pooled_features
| |
|
class StableAphasiaClassifier(nn.Module):
    """BERT-based aphasia-type classifier with auxiliary severity and fluency heads.

    Combines an attention-pooled pretrained transformer encoder (augmented by
    an extra positional encoding) with pooled word-level linguistic features,
    then routes the fused vector through three heads:
      - ``classifier``: aphasia-type logits (``num_labels`` classes)
      - ``severity_head``: 4-way softmax distribution over severity levels
      - ``fluency_head``: scalar fluency score in (0, 1) via sigmoid
    """

    def __init__(self, config, num_labels: int):
        """Build all sub-modules.

        Args:
            config: Hyperparameter dict (model_name, max_length, feature dims,
                dropout_rate, classifier_hidden_dims, ...).
            num_labels: Number of aphasia classes to predict.

        Raises:
            Exception: Re-raised after logging if any sub-module fails to
                initialize (notably the pretrained backbone download).
        """
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        try:
            # Pretrained biomedical BERT backbone (downloads on first use).
            self.bert = AutoModel.from_pretrained(config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'))
            self.bert_config = self.bert.config

            # Freeze only the embedding layer; encoder layers stay trainable.
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            self.positional_encoder = StablePositionalEncoding(
                d_model=self.bert_config.hidden_size,
                max_len=config.get('max_length', 512)
            )

            self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

            bert_dim = self.bert_config.hidden_size
            # Must match the extractor's pooled output size: (sum of stream dims) // 2.
            linguistic_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                              config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32)) // 2

            # Project [BERT pooled | linguistic pooled] back down to bert_dim.
            self.feature_fusion = nn.Sequential(
                nn.Linear(bert_dim + linguistic_dim, bert_dim),
                nn.LayerNorm(bert_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            )

            # Main aphasia-type classification head (MLP built from config).
            self.classifier = self._build_classifier(bert_dim, num_labels, config)

            # Auxiliary head: probability distribution over 4 severity levels.
            self.severity_head = nn.Sequential(
                nn.Linear(bert_dim, 4),
                nn.Softmax(dim=-1)
            )

            # Auxiliary head: scalar fluency score squashed to (0, 1).
            self.fluency_head = nn.Sequential(
                nn.Linear(bert_dim, 1),
                nn.Sigmoid()
            )

        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def _build_classifier(self, input_dim: int, num_labels: int, config):
        """Build an MLP: [Linear, LayerNorm, Tanh, Dropout] per hidden dim,
        then a final Linear projecting to ``num_labels`` logits."""
        layers = []
        current_dim = input_dim

        classifier_hidden_dims = config.get('classifier_hidden_dims', [512, 256])
        for hidden_dim in classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the full model.

        Returns a dict with "logits", "severity_pred", and "fluency_pred".

        NOTE(review): ``labels`` and extra ``kwargs`` are accepted but unused —
        no loss is computed here even when labels are provided; confirm callers
        compute the loss externally.
        """
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        # Use real linguistic features only if the full word-level triple is present.
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                # Prosody is optional; substitute zeros of the configured width.
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.get('prosody_dim', 32),
                    device=input_ids.device
                )

            # NOTE(review): the token-level attention_mask is reused as the
            # word-level mask here — assumes word features align with tokens;
            # confirm against the data pipeline.
            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # No word-level inputs: zero vector matching the extractor's output size.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.get('pos_emb_dim', 64) + self.config.get('grammar_hidden_dim', 64) +
                 self.config.get('duration_hidden_dim', 128) + self.config.get('prosody_dim', 32)) // 2,
                device=input_ids.device
            )

        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Content-weighted pooling: softmax over each token's hidden-state sum,
        zeroed at padding, renormalized, then used to weight-average tokens."""
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        # Zero out padding positions, then renormalize (epsilon guards 0/0).
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled
| |
|
| | |
def load_model():
    """Load config, tokenizer, model, and label map for inference.

    Reads ``config.json`` from the working directory, builds the tokenizer
    (adding dialogue special tokens), constructs the classifier, and tries to
    load weights from ``pytorch_model.bin``. Weight-loading failures are
    logged and tolerated (random init is used); any other failure is re-raised.

    Returns:
        Tuple of (model in eval mode, tokenizer, id2label dict).

    Raises:
        Exception: Re-raised after logging if config/tokenizer/model setup fails.
    """
    try:
        with open("config.json", "r") as f:
            config = json.load(f)

        logger.info(f"Loaded config: {config}")

        model_name = config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Markers used in dialogue transcripts; must be added before resizing
        # the embedding matrix below.
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        num_labels = config.get('num_labels', 9)
        model_config = config.get('model_config', {})

        model = StableAphasiaClassifier(model_config, num_labels)
        # Grow the embedding table to cover the newly added special tokens.
        model.bert.resize_token_embeddings(len(tokenizer))

        try:
            # weights_only=True restricts unpickling to tensors/state dicts,
            # preventing arbitrary-code execution from an untrusted checkpoint
            # (requires torch >= 1.13; a state_dict loads fine under it).
            state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict)
            logger.info("Successfully loaded model weights")
        except Exception as e:
            # Best-effort: the demo still runs (with random weights) if the
            # checkpoint is missing or incompatible.
            logger.error(f"Error loading model weights: {e}")
            logger.info("Using randomly initialized weights")

        model.eval()

        id2label = config.get('id2label', {})

        return model, tokenizer, id2label

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
| |
|
| | |
# Load the model once at import time so the Gradio handlers can reuse it.
# On failure the app still starts; predict_aphasia detects the None sentinels
# and reports the error to the user instead of crashing.
try:
    model, tokenizer, id2label = load_model()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    model, tokenizer, id2label = None, None, {}
| |
|
def predict_aphasia(text):
    """Classify the aphasia type of a transcript for the Gradio UI.

    Args:
        text: Speech transcription (or free text) to analyze.

    Returns:
        Tuple of four display strings:
        (prediction, confidence, severity, fluency).
        On any failure, the first slot carries the message and the analysis
        slots are "N/A".
    """
    try:
        if model is None or tokenizer is None:
            return "Error: Model not loaded properly. Please check the logs.", "N/A", "N/A", "N/A"

        if not text or not text.strip():
            return "Please enter some text for analysis.", "N/A", "N/A", "N/A"

        inputs = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # No word-level annotations are derivable from raw text here, so feed
        # zero placeholders for the linguistic feature channels.
        batch_size, seq_len = inputs["input_ids"].size()
        dummy_pos = torch.zeros(batch_size, seq_len, dtype=torch.long)
        dummy_grammar = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
        dummy_durations = torch.zeros(batch_size, seq_len, dtype=torch.float)
        dummy_prosody = torch.zeros(batch_size, seq_len, 32, dtype=torch.float)

        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                word_pos_ids=dummy_pos,
                word_grammar_ids=dummy_grammar,
                word_durations=dummy_durations,
                prosody_features=dummy_prosody
            )

        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1)
        predicted_class_id = torch.argmax(probs, dim=1).item()
        confidence = torch.max(probs, dim=1)[0].item()

        # id2label keys come from JSON, hence string-keyed.
        predicted_label = id2label.get(str(predicted_class_id), f"Class_{predicted_class_id}")

        severity = torch.argmax(outputs["severity_pred"], dim=1).item()
        fluency = outputs["fluency_pred"].item()

        result = f"Predicted Aphasia Type: {predicted_label}"
        confidence_str = f"Confidence: {confidence:.2%}"
        severity_str = f"Severity Level: {severity}/3"
        fluency_str = f"Fluency Score: {fluency:.3f}"

        # Fix: return the formatted confidence string. Previously the raw
        # float was returned while confidence_str was computed and discarded,
        # making the confidence slot inconsistent with the other text outputs.
        return result, confidence_str, severity_str, fluency_str

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return f"Error during prediction: {str(e)}", "N/A", "N/A", "N/A"
| |
|
| | |
def create_interface():
    """Build and return the Gradio Blocks UI for the aphasia classifier."""

    with gr.Blocks(title="Aphasia Classification System") as app:
        gr.Markdown("# 🧠 Advanced Aphasia Classification System")
        gr.Markdown("Enter speech or text data to classify aphasia type and analyze linguistic patterns.")

        with gr.Row():
            # Left column: input and action buttons.
            with gr.Column():
                transcript_box = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter speech transcription or text for analysis...",
                    lines=5,
                )

                analyze_button = gr.Button("Analyze Text", variant="primary")
                reset_button = gr.Button("Clear", variant="secondary")

            # Right column: the four result fields.
            with gr.Column():
                result_box = gr.Textbox(label="Prediction Result", lines=2)
                confidence_box = gr.Textbox(label="Confidence Score", lines=1)
                severity_box = gr.Textbox(label="Severity Analysis", lines=1)
                fluency_box = gr.Textbox(label="Fluency Analysis", lines=1)

        result_fields = [result_box, confidence_box, severity_box, fluency_box]

        analyze_button.click(
            fn=predict_aphasia,
            inputs=[transcript_box],
            outputs=result_fields,
        )

        # Clearing resets the input plus all four result fields.
        reset_button.click(
            fn=lambda: ("", "", "", "", ""),
            inputs=[],
            outputs=[transcript_box] + result_fields,
        )

        gr.Examples(
            examples=[
                ["The patient... uh... wants to... go home but... cannot... find the words"],
                ["Woman is... is washing dishes and the... the... sink is overflowing with water everywhere"],
                ["Cookie is in the cookie jar on the... on the... what do you call it... the shelf thing"]
            ],
            inputs=transcript_box,
        )

        gr.Markdown("### About")
        gr.Markdown("This system uses a specialized transformer model trained on clinical speech data to classify different types of aphasia.")

    return app
| |
|
| | |
if __name__ == "__main__":
    try:
        demo = create_interface()
        # Bind to all interfaces on the standard Gradio/Spaces port so the
        # app is reachable from inside a container.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        print(f"Application startup failed: {e}")