# adtrack-v2 — models/model_v2/wrapper.py
# Source: Hugging Face Space cracker0935/adtrackv2, commit e824b96 ("add mode to model 3")
import os
import json
import torch
import re
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from models.base import BaseModelWrapper
from .arch import ExplainableModel
from .preprocessing import LiveFeatureExtractor, parse_cha_header, parse_cha_transcript
class ModelV2Wrapper(BaseModelWrapper):
    """Wrapper around the v2 explainable Alzheimer's hybrid model.

    Loads the model config, tokenizer, architecture and weights (from local
    files or the Hugging Face Hub as a fallback), then runs inference on CHA
    transcript files, returning a prediction together with sentence-level
    attention explanations and hand-crafted linguistic features.
    """

    def __init__(self):
        # Local config path; replaced by an HF-downloaded copy in load() if missing.
        self.config_path = os.path.join(os.path.dirname(__file__), "model_config.json")
        self.weights_file = "final_alzheimer_hybrid_model.pth"
        self.hf_repo_id = "cracker0935/adtrackv2"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # All populated by load(); None until then.
        self.model = None
        self.tokenizer = None
        self.extractor = None
        self.config = None

    def load(self):
        """Load config, tokenizer, model architecture, weights and extractor.

        Raises:
            RuntimeError: if the config can neither be found locally nor
                downloaded from the Hub.
        """
        # --- Config: prefer the local file, fall back to the HF Hub ---
        if not os.path.exists(self.config_path):
            try:
                print(f"Downloading config from {self.hf_repo_id}...")
                self.config_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename="model_config.json"
                )
            except Exception as e:
                # Previously this failure was swallowed (print only) and the
                # open() below crashed with a confusing FileNotFoundError.
                # Fail loudly with the real cause instead.
                raise RuntimeError(
                    f"model_config.json not found locally and download from "
                    f"{self.hf_repo_id} failed"
                ) from e
        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        # --- Tokenizer and architecture ---
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ExplainableModel(
            model_name=self.config['model_name'],
            feature_dim=self.config['feature_dim']
        )

        # --- Weights: local file first, then HF Hub ---
        if os.path.exists(self.weights_file):
            weights_path = self.weights_file
        else:
            try:
                # Fallback to HF if local file missing
                weights_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename=self.weights_file
                )
            except Exception:
                # Best-effort: keep the bare filename and let torch.load raise
                # a clear error if the file truly does not exist anywhere.
                weights_path = self.weights_file
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # weights from trusted sources. weights_only=True would be safer on
        # torch >= 2.0; confirm compatibility before enabling.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # --- Linguistic feature extractor ---
        self.extractor = LiveFeatureExtractor()

    def _top_segments(self, sentences, attn_list, k=3):
        """Return the k sentences with the highest attention weight.

        Args:
            sentences: list of sentence strings fed to the model.
            attn_list: flat list of per-sentence attention weights.
            k: number of top segments to return (default 3, as before).

        Returns:
            List of {"text", "importance"} dicts, highest weight first.
        """
        if not sentences:
            return []
        ranked = sorted(
            enumerate(attn_list[:len(sentences)]),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [
            {"text": sentences[idx], "importance": weight}
            for idx, weight in ranked[:k]
        ]

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Run inference on one CHA transcript file.

        Args:
            file_content: raw bytes of the .cha file (UTF-8 encoded).
            filename: original file name, echoed back in the result.
            audio_content: unused by this model (kept for interface parity).
            segmentation_content: unused by this model.

        Returns:
            dict with the prediction, dementia probability, speaker metadata,
            linguistic feature values and the top attention-weighted sentences.

        Raises:
            ValueError: if the transcript contains no participant (*PAR) speech.
        """
        content_str = file_content.decode('utf-8')
        final_age, final_gender = parse_cha_header(content_str)
        raw_text = parse_cha_transcript(content_str)
        if not raw_text:
            raise ValueError("No participant speech (*PAR) found in file")

        # Text cleanup for BERT plus the hand-crafted linguistic feature vector.
        clean_text = self.extractor.clean_for_bert(raw_text)
        ling_feats = self.extractor.get_vector(raw_text)

        # Sentence-split; keep only the most recent max_seq_len sentences.
        sentences = [s for s in re.split(r'[.?!]\s+', clean_text) if s.strip()]
        if len(sentences) > self.config['max_seq_len']:
            sentences = sentences[-self.config['max_seq_len']:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt'
        )
        # Shapes: (1, num_sentences, max_word_len) after the batch unsqueeze.
        input_ids = encoding['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoding['attention_mask'].unsqueeze(0).to(self.device)

        # The same document-level feature vector is attached to every sentence.
        feats_tensor = torch.tensor(ling_feats, dtype=torch.float32).repeat(len(sentences), 1)

        # Zero-pad the sentence axis up to max_seq_len.
        pad_len = self.config['max_seq_len'] - len(sentences)
        if pad_len > 0:
            feats_tensor = torch.cat([feats_tensor, torch.zeros(pad_len, self.config['feature_dim'])])
            pad_ids = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            pad_mask = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            input_ids = torch.cat([input_ids, pad_ids], dim=1)
            attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        feats_tensor = feats_tensor.unsqueeze(0).to(self.device)
        lengths = torch.tensor([len(sentences)]).to(self.device)
        # Age normalized to roughly [0, 1]; gender is a 0/1 flag.
        age_tensor = torch.tensor([final_age / 100.0], dtype=torch.float32).to(self.device)
        gender_tensor = torch.tensor([final_gender], dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits, attn_weights = self.model(
                input_ids,
                attention_mask,
                feats_tensor,
                lengths,
                age_tensor,
                gender_tensor
            )

        probs = torch.nn.functional.softmax(logits, dim=1)
        dementia_prob = probs[0, 1].item()
        predicted_class = "Dementia" if dementia_prob > 0.5 else "Control"

        # Flatten so the weights index per-sentence regardless of whether the
        # model returns (seq,), (1, seq) or a 0-d scalar. The old tolist()
        # produced a nested list for batched output, which mis-sliced below.
        attn_list = attn_weights.cpu().numpy().ravel().tolist()
        top_sentences = self._top_segments(sentences, attn_list)

        return {
            "filename": filename,
            "prediction": predicted_class,
            "probability_dementia": round(dementia_prob, 4),
            "metadata": {
                "age": final_age,
                "gender": "Male" if final_gender == 1 else "Female",
                "sentence_count": len(sentences)
            },
            "linguistic_features": {
                "TTR": ling_feats[0],
                "fillers_ratio": ling_feats[1],
                "repetitions_ratio": ling_feats[2],
                "retracing_ratio": ling_feats[3],
                "incomplete_ratio": ling_feats[4],
                "pauses_ratio": ling_feats[5]
            },
            "key_segments": top_sentences,
            "model_used": "Model v2"
        }