adtrack-v2 / models /model_v1 /wrapper.py
cracker0935's picture
add mode to model 3
e824b96
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
from models.base import BaseModelWrapper
from .arch import ResearchHybridModel
from .preprocessing import ChaParser
class HybridDebertaWrapper(BaseModelWrapper):
    """Inference wrapper for the hybrid DeBERTa + BiLSTM dementia classifier.

    Loads weights from a local file when present, otherwise from the
    Hugging Face Hub, and exposes ``predict`` over raw CHAT (.cha)
    transcript content.
    """

    def __init__(self):
        # Static inference configuration. 'threshold' is the decision cutoff
        # applied to the DEMENTIA-class probability in predict().
        self.config = {
            'model_name': 'microsoft/deberta-base',
            'max_seq_len': 64,    # max number of utterances kept per transcript
            'max_word_len': 40,   # max tokens per utterance for the tokenizer
            'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            'threshold': 0.20,
            'hf_repo_id': 'cracker0935/bilstm_debert_v1',
            'weights_file': 'best_alzheimer_model.pth'
        }
        self.model = None      # set by load()
        self.tokenizer = None  # set by load()

    def load(self):
        """Instantiate tokenizer and model, then load checkpoint weights.

        Weight resolution order: local file named by config['weights_file'],
        then download from config['hf_repo_id'] on the Hugging Face Hub.

        Raises:
            FileNotFoundError: if the weights exist neither locally nor on the Hub.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ResearchHybridModel(model_name=self.config['model_name'])
        if os.path.exists(self.config['weights_file']):
            weights_path = self.config['weights_file']
        else:
            try:
                weights_path = hf_hub_download(
                    repo_id=self.config['hf_repo_id'],
                    filename=self.config['weights_file']
                )
            except Exception as err:
                # Chain the original download error instead of discarding it.
                raise FileNotFoundError("Model weights not found locally or on Hugging Face.") from err
        # NOTE(review): torch.load unpickles arbitrary objects; if the
        # checkpoint source is not fully trusted, pass weights_only=True
        # (available in torch >= 1.13).
        state_dict = torch.load(weights_path, map_location=self.config['device'])
        # Strip the 'module.' prefix left by DataParallel training. Guard
        # against an empty state dict (the original indexed keys()[0] blindly,
        # which raises IndexError on an empty checkpoint).
        if state_dict and next(iter(state_dict)).startswith('module.'):
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        self.model.load_state_dict(state_dict)
        self.model.to(self.config['device'])
        self.model.eval()

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Classify one CHAT transcript as DEMENTIA vs HEALTHY CONTROL.

        Args:
            file_content: raw .cha transcript bytes (str also accepted).
            filename: echoed back in the result dict.
            audio_content: unused by this model; kept for interface parity.
            segmentation_content: unused by this model; kept for interface parity.

        Returns:
            dict with prediction label, probability, per-sentence attention
            map (min-max normalised scores), and model identifier.

        Raises:
            ValueError: if the transcript contains no *PAR speaker lines.
        """
        # Decode at the boundary: the HF tokenizer requires str input, and the
        # API hands us raw bytes (presumably ChaParser also expects str lines).
        if isinstance(file_content, bytes):
            file_content = file_content.decode('utf-8', errors='replace')
        lines = file_content.splitlines()
        parser = ChaParser()
        sentences, features, _ = parser.parse(lines)
        if not sentences:
            raise ValueError("No *PAR lines found in file")
        # Keep the most recent utterances when the transcript exceeds the
        # sequence budget (tail, not head, per the original behaviour).
        max_seq = self.config['max_seq_len']
        if len(sentences) > max_seq:
            sentences = sentences[-max_seq:]
            features = features[-max_seq:]
        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt'
        )
        device = self.config['device']
        # unsqueeze(0): add the batch dimension expected by the model.
        ids = encoding['input_ids'].unsqueeze(0).to(device)
        mask = encoding['attention_mask'].unsqueeze(0).to(device)
        feats = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
        # Sequence lengths stay on CPU (pack_padded_sequence convention).
        lengths = torch.tensor([len(sentences)])
        with torch.no_grad():
            logits, attn_weights_tensor = self.model(ids, mask, feats, lengths)
            prob = F.softmax(logits, dim=1)[:, 1].item()
        attn_weights = attn_weights_tensor.cpu().numpy().flatten()
        attn_weights = attn_weights[:len(sentences)]
        # Min-max normalise attention scores to [0, 1] for display; skip when
        # the weights are constant (zero range) or empty.
        if len(attn_weights) > 0:
            w_min, w_max = attn_weights.min(), attn_weights.max()
            if w_max - w_min > 0:
                attn_weights = (attn_weights - w_min) / (w_max - w_min)
        prediction_label = "DEMENTIA" if prob >= self.config['threshold'] else "HEALTHY CONTROL"
        attention_map = [
            {"sentence": sent, "attention_score": float(score)}
            for sent, score in zip(sentences, attn_weights)
        ]
        return {
            "filename": filename,
            "prediction": prediction_label,
            "confidence": prob,
            "is_dementia": prob >= self.config['threshold'],
            "attention_map": attention_map,
            "model_used": "hybrid deberta"
        }