File size: 3,959 Bytes
97ea4f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e824b96
97ea4f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b061f6
97ea4f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os

from models.base import BaseModelWrapper
from .arch import ResearchHybridModel
from .preprocessing import ChaParser

class HybridDebertaWrapper(BaseModelWrapper):
    """Inference wrapper around ``ResearchHybridModel`` built on DeBERTa.

    A transcript is parsed with ``ChaParser``; presumably this is CHAT
    (``.cha``) format, given the parser name and the ``*PAR`` check. Each
    parsed sentence is tokenized, and the model returns logits plus
    per-sentence attention weights. The softmax probability of class index 1
    is compared with ``config['threshold']`` to label the transcript as
    "DEMENTIA" or "HEALTHY CONTROL".
    """

    def __init__(self):
        self.config = {
            'model_name': 'microsoft/deberta-base',
            # Maximum number of sentences fed to the model (most recent kept).
            'max_seq_len': 64,
            # Token length each sentence is padded/truncated to.
            'max_word_len': 40,
            'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            # Decision threshold on the class-1 probability.
            'threshold': 0.20,
            'hf_repo_id': 'cracker0935/bilstm_debert_v1',
            'weights_file': 'best_alzheimer_model.pth'
        }
        self.model = None
        self.tokenizer = None

    def load(self):
        """Build the tokenizer and model, then load the trained weights.

        Weights come from ``config['weights_file']`` when that file exists
        locally. Otherwise they are downloaded from ``config['hf_repo_id']``
        on the Hugging Face Hub.

        Raises:
            FileNotFoundError: if the weights are neither on disk nor
                downloadable. The download error is chained as the cause.
        """
        device = self.config['device']
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ResearchHybridModel(model_name=self.config['model_name'])

        weights_path = self._resolve_weights_path()
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch
        # version supports it.
        state_dict = torch.load(weights_path, map_location=device)
        state_dict = self._strip_parallel_prefix(state_dict)

        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

    def _resolve_weights_path(self):
        """Return a local path to the weights file, downloading it if needed."""
        weights_file = self.config['weights_file']
        if os.path.isfile(weights_file):
            return weights_file
        try:
            return hf_hub_download(
                repo_id=self.config['hf_repo_id'],
                filename=weights_file,
            )
        except Exception as err:
            # Keep the original error (network, auth, missing file...) as the
            # cause so the real failure is visible in the traceback.
            raise FileNotFoundError(
                "Model weights not found locally or on Hugging Face."
            ) from err

    @staticmethod
    def _strip_parallel_prefix(state_dict):
        """Remove a leading ``module.`` from state-dict keys that carry it.

        That prefix is typically added when a model is saved while wrapped
        in ``nn.DataParallel``/DDP. Keys without the prefix are left intact.
        An empty dict, or one without any prefixed key, is returned
        unchanged (same object).
        """
        prefix = 'module.'
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Classify one transcript and report per-sentence attention scores.

        Args:
            file_content: Raw transcript contents. They are split into lines
                and passed to ``ChaParser.parse``.
            filename: Echoed back in the result.
            audio_content: Not used by this model. Accepted for interface
                compatibility.
            segmentation_content: Not used by this model.

        Returns:
            A dict with these keys:

            - ``filename``
            - ``prediction``: "DEMENTIA" or "HEALTHY CONTROL".
            - ``confidence``: softmax probability of class index 1.
            - ``is_dementia``: whether ``confidence >= config['threshold']``.
            - ``attention_map``: a list of ``{"sentence", "attention_score"}``.
              Scores are min-max scaled to [0, 1] unless they are all equal.
            - ``model_used``

        Raises:
            RuntimeError: if ``load()`` has not been called yet.
            ValueError: if the parser returns no sentences.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call load() before predict().")

        device = self.config['device']
        max_seq_len = self.config['max_seq_len']

        # NOTE(review): for bytes input these lines are bytes objects;
        # confirm ChaParser.parse expects that.
        lines = file_content.splitlines()
        parser = ChaParser()
        sentences, features, _ = parser.parse(lines)

        if not sentences:
            raise ValueError("No *PAR lines found in file")

        # Keep only the last max_seq_len sentences (and their features).
        if len(sentences) > max_seq_len:
            sentences = sentences[-max_seq_len:]
            features = features[-max_seq_len:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt',
        )

        # unsqueeze(0): the whole transcript is a single batch item.
        ids = encoding['input_ids'].unsqueeze(0).to(device)
        mask = encoding['attention_mask'].unsqueeze(0).to(device)
        feats = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
        # Deliberately left on the CPU. Presumably the model feeds it to
        # pack_padded_sequence, which requires CPU lengths; confirm in arch.
        lengths = torch.tensor([len(sentences)])

        with torch.no_grad():
            logits, attn_weights_tensor = self.model(ids, mask, feats, lengths)
            prob = F.softmax(logits, dim=1)[:, 1].item()

        # One weight per real (non-padded) sentence.
        attn_weights = attn_weights_tensor.cpu().numpy().flatten()[:len(sentences)]

        # Min-max scale for display. Constant weights are left as-is to
        # avoid dividing by zero.
        if len(attn_weights) > 0:
            w_min, w_max = attn_weights.min(), attn_weights.max()
            if w_max - w_min > 0:
                attn_weights = (attn_weights - w_min) / (w_max - w_min)

        is_dementia = prob >= self.config['threshold']
        attention_map = [
            {"sentence": sent, "attention_score": float(score)}
            for sent, score in zip(sentences, attn_weights)
        ]

        return {
            "filename": filename,
            "prediction": "DEMENTIA" if is_dementia else "HEALTHY CONTROL",
            "confidence": prob,
            "is_dementia": is_dementia,
            "attention_map": attention_map,
            "model_used": "hybrid deberta"
        }