"""Inference wrapper for the hybrid DeBERTa + BiLSTM Alzheimer's classifier."""

import os

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from models.base import BaseModelWrapper

from .arch import ResearchHybridModel
from .preprocessing import ChaParser


class HybridDebertaWrapper(BaseModelWrapper):
    """Load the hybrid DeBERTa model and classify .cha transcripts as dementia vs. healthy control."""

    def __init__(self):
        # Static inference configuration; the heavy objects (tokenizer, model)
        # are created lazily in load().
        self.config = {
            'model_name': 'microsoft/deberta-base',
            'max_seq_len': 64,   # max number of utterances fed to the model
            'max_word_len': 40,  # max tokens per utterance
            'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            'threshold': 0.20,   # decision threshold on the DEMENTIA probability
            'hf_repo_id': 'cracker0935/bilstm_debert_v1',
            'weights_file': 'best_alzheimer_model.pth',
        }
        self.model = None
        self.tokenizer = None

    def load(self):
        """Instantiate tokenizer and model, then load weights.

        Tries a local weights file first, falling back to a Hugging Face
        Hub download.

        Raises:
            FileNotFoundError: if the weights are neither on disk nor
                downloadable from the Hub.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ResearchHybridModel(model_name=self.config['model_name'])

        if os.path.exists(self.config['weights_file']):
            weights_path = self.config['weights_file']
        else:
            try:
                weights_path = hf_hub_download(
                    repo_id=self.config['hf_repo_id'],
                    filename=self.config['weights_file'],
                )
            except Exception as e:
                # Chain the original download error so the root cause
                # (network, auth, missing repo, ...) is not lost.
                raise FileNotFoundError(
                    "Model weights not found locally or on Hugging Face."
                ) from e

        # NOTE(review): consider torch.load(..., weights_only=True) once the
        # checkpoint is confirmed to be a plain state dict — safer unpickling.
        state_dict = torch.load(weights_path, map_location=self.config['device'])

        # Strip the 'module.' prefix left by DataParallel checkpoints.
        # next(iter(...), '') avoids an IndexError on an empty state dict,
        # unlike indexing list(state_dict.keys())[0].
        if next(iter(state_dict), '').startswith('module.'):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        self.model.load_state_dict(state_dict)
        self.model.to(self.config['device'])
        self.model.eval()

    def predict(self, file_content: bytes, filename: str, audio_content=None,
                segmentation_content=None) -> dict:
        """Run inference on a single .cha transcript.

        Args:
            file_content: raw transcript bytes.
            filename: original file name, echoed back in the result.
            audio_content: unused; accepted for interface compatibility.
            segmentation_content: unused; accepted for interface compatibility.

        Returns:
            dict with prediction label, confidence, per-sentence attention
            map, and model identifier.

        Raises:
            ValueError: if the transcript contains no *PAR lines.
        """
        # NOTE(review): splitlines() on bytes yields bytes lines; assumes
        # ChaParser.parse accepts bytes (or decodes internally) — TODO confirm.
        lines = file_content.splitlines()
        parser = ChaParser()
        sentences, features, _ = parser.parse(lines)
        if not sentences:
            raise ValueError("No *PAR lines found in file")

        # Keep only the most recent utterances if the transcript is too long.
        max_len = self.config['max_seq_len']
        if len(sentences) > max_len:
            sentences = sentences[-max_len:]
            features = features[-max_len:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt',
        )
        device = self.config['device']
        # unsqueeze(0) adds the batch dimension (batch size 1).
        ids = encoding['input_ids'].unsqueeze(0).to(device)
        mask = encoding['attention_mask'].unsqueeze(0).to(device)
        feats = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
        # Lengths intentionally stay on CPU (as in the original code; required
        # if the model packs sequences downstream — TODO confirm).
        lengths = torch.tensor([len(sentences)])

        with torch.no_grad():
            logits, attn_weights_tensor = self.model(ids, mask, feats, lengths)
            # Probability of the positive (DEMENTIA) class.
            prob = F.softmax(logits, dim=1)[:, 1].item()

        # Min-max normalise the per-sentence attention weights for display.
        attn_weights = attn_weights_tensor.cpu().numpy().flatten()[:len(sentences)]
        if len(attn_weights) > 0:
            w_min, w_max = attn_weights.min(), attn_weights.max()
            if w_max - w_min > 0:
                attn_weights = (attn_weights - w_min) / (w_max - w_min)

        # Compute the decision once so 'prediction' and 'is_dementia' can
        # never disagree.
        is_dementia = prob >= self.config['threshold']
        prediction_label = "DEMENTIA" if is_dementia else "HEALTHY CONTROL"

        attention_map = [
            {"sentence": sent, "attention_score": float(score)}
            for sent, score in zip(sentences, attn_weights)
        ]

        return {
            "filename": filename,
            "prediction": prediction_label,
            "confidence": prob,
            "is_dementia": is_dementia,
            "attention_map": attention_map,
            "model_used": "hybrid deberta",
        }