# NOTE: Hugging Face Spaces page residue (Space status: "Sleeping") captured
# by the scrape — not part of the module source.
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| from models.base import BaseModelWrapper | |
| from .arch import ResearchHybridModel | |
| from .preprocessing import ChaParser | |
class HybridDebertaWrapper(BaseModelWrapper):
    """Inference wrapper for the hybrid DeBERTa + BiLSTM dementia classifier.

    Loads model weights from a local file when present, otherwise from the
    Hugging Face Hub, and exposes :meth:`predict` to score a CHAT (``.cha``)
    transcript, returning a label, probability, and per-sentence attention map.
    """

    def __init__(self):
        # NOTE(review): assumes BaseModelWrapper.__init__ requires no call —
        # confirm against models/base.py.
        self.config = {
            'model_name': 'microsoft/deberta-base',
            'max_seq_len': 64,    # max number of utterances kept per transcript
            'max_word_len': 40,   # max subword tokens per utterance
            'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            'threshold': 0.20,    # decision threshold on P(dementia)
            'hf_repo_id': 'cracker0935/bilstm_debert_v1',
            'weights_file': 'best_alzheimer_model.pth',
        }
        self.model = None      # populated by load()
        self.tokenizer = None  # populated by load()

    def load(self):
        """Instantiate the tokenizer/model and load trained weights.

        Resolution order for weights: local file named by
        ``config['weights_file']`` first, then download from the Hub repo.

        Raises:
            FileNotFoundError: if weights are found neither locally nor on
                the Hugging Face Hub.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ResearchHybridModel(model_name=self.config['model_name'])

        if os.path.exists(self.config['weights_file']):
            weights_path = self.config['weights_file']
        else:
            try:
                weights_path = hf_hub_download(
                    repo_id=self.config['hf_repo_id'],
                    filename=self.config['weights_file'],
                )
            except Exception as err:
                # Chain the cause so the underlying Hub error is not lost.
                raise FileNotFoundError(
                    "Model weights not found locally or on Hugging Face."
                ) from err

        # weights_only=True restricts the pickle payload to tensors, guarding
        # against arbitrary code execution from an untrusted checkpoint
        # (available since torch 1.13; the file is a plain state dict, which
        # load_state_dict below requires anyway).
        state_dict = torch.load(
            weights_path,
            map_location=self.config['device'],
            weights_only=True,
        )
        # Strip the 'module.' prefix left by DataParallel training, if present.
        # (Guard against an empty dict before peeking at the first key.)
        if state_dict and next(iter(state_dict)).startswith('module.'):
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        self.model.load_state_dict(state_dict)
        self.model.to(self.config['device'])
        self.model.eval()

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Score a CHAT transcript and return label, probability, attention map.

        Args:
            file_content: raw bytes of a ``.cha`` transcript file.
            filename: original filename, echoed back in the result.
            audio_content: unused by this model (accepted for interface parity).
            segmentation_content: unused by this model.

        Returns:
            dict with keys ``filename``, ``prediction``, ``confidence``,
            ``is_dementia``, ``attention_map`` (list of
            ``{"sentence", "attention_score"}``), ``model_used``.

        Raises:
            RuntimeError: if called before :meth:`load`.
            ValueError: if the transcript contains no ``*PAR`` lines.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call load() before predict().")

        # file_content arrives as bytes: decode to text before line-based CHAT
        # parsing — bytes.splitlines() would hand bytes objects to the parser,
        # whose str-marker checks (e.g. startswith('*PAR')) raise TypeError.
        # (assumes UTF-8 transcripts; errors='replace' keeps parsing best-effort)
        lines = file_content.decode('utf-8', errors='replace').splitlines()
        parser = ChaParser()
        sentences, features, _ = parser.parse(lines)
        if not sentences:
            raise ValueError("No *PAR lines found in file")

        # Keep only the most recent utterances when the transcript is too long.
        max_len = self.config['max_seq_len']
        if len(sentences) > max_len:
            sentences = sentences[-max_len:]
            features = features[-max_len:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt',
        )
        device = self.config['device']
        # unsqueeze(0) adds the batch dimension expected by the model.
        ids = encoding['input_ids'].unsqueeze(0).to(device)
        mask = encoding['attention_mask'].unsqueeze(0).to(device)
        feats = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
        lengths = torch.tensor([len(sentences)])

        with torch.no_grad():
            logits, attn_weights_tensor = self.model(ids, mask, feats, lengths)
            prob = F.softmax(logits, dim=1)[:, 1].item()  # P(class 1) = dementia
            attn_weights = attn_weights_tensor.cpu().numpy().flatten()

        # Min-max normalise attention over the real (unpadded) utterances so
        # scores are comparable across transcripts; constant weights are left
        # as-is to avoid a divide-by-zero.
        attn_weights = attn_weights[:len(sentences)]
        if len(attn_weights) > 0:
            w_min, w_max = attn_weights.min(), attn_weights.max()
            if w_max - w_min > 0:
                attn_weights = (attn_weights - w_min) / (w_max - w_min)

        is_dementia = prob >= self.config['threshold']
        attention_map = [
            {"sentence": sent, "attention_score": float(score)}
            for sent, score in zip(sentences, attn_weights)
        ]
        return {
            "filename": filename,
            "prediction": "DEMENTIA" if is_dementia else "HEALTHY CONTROL",
            "confidence": prob,
            "is_dementia": is_dementia,
            "attention_map": attention_map,
            "model_used": "hybrid deberta",
        }