import os
import json
import re

import torch
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

from models.base import BaseModelWrapper
from .arch import ExplainableModel
from .preprocessing import LiveFeatureExtractor, parse_cha_header, parse_cha_transcript


class ModelV2Wrapper(BaseModelWrapper):
    """Inference wrapper for the v2 hybrid (BERT + linguistic features) model.

    Loads the config and weights from disk when present, falling back to the
    Hugging Face Hub repo, and runs prediction on CHAT (.cha) transcripts.
    """

    def __init__(self):
        # Config sits next to this module; load() falls back to the HF Hub.
        self.config_path = os.path.join(os.path.dirname(__file__), "model_config.json")
        self.weights_file = "final_alzheimer_hybrid_model.pth"
        self.hf_repo_id = "cracker0935/adtrackv2"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by load().
        self.model = None
        self.tokenizer = None
        self.extractor = None
        self.config = None

    def load(self):
        """Load config, tokenizer, architecture, weights, and feature extractor."""
        # --- Config: prefer the local file, fall back to the HF Hub ---
        if not os.path.exists(self.config_path):
            try:
                print(f"Downloading config from {self.hf_repo_id}...")
                self.config_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename="model_config.json"
                )
            except Exception as e:
                # Best effort: the open() below will surface the missing file.
                print(f"Could not download config: {e}")
        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        # --- Tokenizer and model architecture ---
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        self.model = ExplainableModel(
            model_name=self.config['model_name'],
            feature_dim=self.config['feature_dim']
        )

        # --- Weights: local file first, then the HF Hub ---
        if os.path.exists(self.weights_file):
            weights_path = self.weights_file
        else:
            try:
                weights_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename=self.weights_file
                )
            except Exception:
                # Best effort: let torch.load surface the missing-file error.
                weights_path = self.weights_file

        # The checkpoint is a plain state dict, so weights_only=True is safe and
        # avoids arbitrary code execution via pickle in a downloaded file.
        try:
            state_dict = torch.load(weights_path, map_location=self.device,
                                    weights_only=True)
        except TypeError:
            # torch < 1.13 has no weights_only kwarg.
            state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # --- Linguistic feature extractor ---
        self.extractor = LiveFeatureExtractor()

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Run inference on one CHAT (.cha) transcript.

        Args:
            file_content: Raw bytes of the .cha file (UTF-8 encoded).
            filename: Original filename, echoed back in the result.
            audio_content: Unused; accepted for interface compatibility.
            segmentation_content: Unused; accepted for interface compatibility.

        Returns:
            dict with the predicted class, dementia probability, speaker
            metadata, linguistic feature values, and the top attention-weighted
            sentence segments.

        Raises:
            ValueError: If no participant (*PAR) speech is found in the file.
        """
        content_str = file_content.decode('utf-8')
        final_age, final_gender = parse_cha_header(content_str)

        raw_text = parse_cha_transcript(content_str)
        if not raw_text:
            raise ValueError("No participant speech (*PAR) found in file")

        clean_text = self.extractor.clean_for_bert(raw_text)
        ling_feats = self.extractor.get_vector(raw_text)

        # Sentence-split; keep at most the LAST max_seq_len sentences.
        sentences = re.split(r'[.?!]\s+', clean_text)
        sentences = [s for s in sentences if s.strip()]
        if not sentences:
            # e.g. cleaning left only punctuation; tokenizing [] would fail
            # with an opaque error downstream.
            raise ValueError("No participant speech (*PAR) found in file")
        if len(sentences) > self.config['max_seq_len']:
            sentences = sentences[-self.config['max_seq_len']:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt'
        )
        # Add a batch dimension: (1, num_sentences, max_word_len).
        input_ids = encoding['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoding['attention_mask'].unsqueeze(0).to(self.device)

        # Document-level linguistic features, repeated once per sentence.
        feats_tensor = torch.tensor(ling_feats, dtype=torch.float32).repeat(len(sentences), 1)

        # Zero-pad the sentence axis up to max_seq_len.
        pad_len = self.config['max_seq_len'] - len(sentences)
        if pad_len > 0:
            feats_tensor = torch.cat([feats_tensor, torch.zeros(pad_len, self.config['feature_dim'])])
            pad_ids = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            pad_mask = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            input_ids = torch.cat([input_ids, pad_ids], dim=1)
            attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        feats_tensor = feats_tensor.unsqueeze(0).to(self.device)
        lengths = torch.tensor([len(sentences)]).to(self.device)
        # Age normalized by /100 — presumably matching training preprocessing.
        age_tensor = torch.tensor([final_age / 100.0], dtype=torch.float32).to(self.device)
        gender_tensor = torch.tensor([final_gender], dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits, attn_weights = self.model(
                input_ids, attention_mask, feats_tensor, lengths, age_tensor, gender_tensor
            )

        probs = torch.nn.functional.softmax(logits, dim=1)
        dementia_prob = probs[0, 1].item()  # class index 1 == Dementia
        predicted_class = "Dementia" if dementia_prob > 0.5 else "Control"

        # Per-sentence attention weights drive the explainability output.
        attn_list = attn_weights.cpu().numpy().tolist()
        if isinstance(attn_list, float):
            # A 0-dim tensor collapses to a bare float via tolist().
            attn_list = [attn_list]

        # Pick the 3 most attention-weighted sentences as key segments.
        top_sentences = []
        if len(sentences) > 0:
            indexed_attn = list(enumerate(attn_list[:len(sentences)]))
            indexed_attn.sort(key=lambda x: x[1], reverse=True)
            top_3_indices = [x[0] for x in indexed_attn[:3]]
            for idx in top_3_indices:
                top_sentences.append({
                    "text": sentences[idx],
                    "importance": attn_list[idx]
                })

        return {
            "filename": filename,
            "prediction": predicted_class,
            "probability_dementia": round(dementia_prob, 4),
            "metadata": {
                "age": final_age,
                "gender": "Male" if final_gender == 1 else "Female",
                "sentence_count": len(sentences)
            },
            "linguistic_features": {
                # Order matches LiveFeatureExtractor.get_vector's output —
                # TODO(review): confirm against that module.
                "TTR": ling_feats[0],
                "fillers_ratio": ling_feats[1],
                "repetitions_ratio": ling_feats[2],
                "retracing_ratio": ling_feats[3],
                "incomplete_ratio": ling_feats[4],
                "pauses_ratio": ling_feats[5]
            },
            "key_segments": top_sentences,
            "model_used": "Model v2"
        }