Spaces:
Running
Running
| import os | |
| import json | |
| import torch | |
| import re | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| from models.base import BaseModelWrapper | |
| from .arch import ExplainableModel | |
| from .preprocessing import LiveFeatureExtractor, parse_cha_header, parse_cha_transcript | |
class ModelV2Wrapper(BaseModelWrapper):
    """Wrapper around the v2 hybrid Alzheimer's classifier (sentence-level BERT
    encodings + document-level linguistic features).

    Responsibilities:
      * locate/download ``model_config.json`` and the model weights from the
        Hugging Face Hub (``cracker0935/adtrackv2``) when not present locally;
      * build the tokenizer, ``ExplainableModel`` and ``LiveFeatureExtractor``;
      * run inference on a CHAT (``.cha``) transcript and return a
        JSON-serializable result dict.
    """

    def __init__(self):
        # Config lives next to this module; load() falls back to the Hub.
        self.config_path = os.path.join(os.path.dirname(__file__), "model_config.json")
        self.weights_file = "final_alzheimer_hybrid_model.pth"
        self.hf_repo_id = "cracker0935/adtrackv2"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated lazily by load().
        self.model = None
        self.tokenizer = None
        self.extractor = None
        self.config = None

    def load(self):
        """Load config, tokenizer, architecture, weights and feature extractor.

        Raises:
            FileNotFoundError: if the config exists neither locally nor on the
                Hub (the ``open()`` below surfaces it after the best-effort
                download attempt).
        """
        # --- Config ---
        if not os.path.exists(self.config_path):
            try:
                print(f"Downloading config from {self.hf_repo_id}...")
                self.config_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename="model_config.json"
                )
            except Exception as e:
                # Best-effort: if the download failed, the open() below will
                # raise with the original (missing) local path.
                print(f"Could not download config: {e}")
        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        # --- Tokenizer ---
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])

        # --- Model architecture ---
        self.model = ExplainableModel(
            model_name=self.config['model_name'],
            feature_dim=self.config['feature_dim']
        )

        # --- Weights: prefer a local file, else pull from the Hub ---
        if os.path.exists(self.weights_file):
            weights_path = self.weights_file
        else:
            try:
                weights_path = hf_hub_download(
                    repo_id=self.hf_repo_id,
                    filename=self.weights_file
                )
            except Exception:
                # Last resort: let torch.load raise with the original
                # filename so the error message names the expected file.
                weights_path = self.weights_file
        # NOTE(review): consider weights_only=True once the checkpoint is
        # confirmed to be a plain state_dict (safer against pickle payloads).
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # --- Linguistic feature extractor ---
        self.extractor = LiveFeatureExtractor()

    def predict(self, file_content: bytes, filename: str, audio_content=None, segmentation_content=None) -> dict:
        """Classify a CHAT transcript as Dementia vs Control.

        Args:
            file_content: raw bytes of a ``.cha`` transcript (UTF-8).
            filename: original filename, echoed back in the result.
            audio_content: unused by this model version (kept for interface
                compatibility).
            segmentation_content: unused by this model version.

        Returns:
            dict with prediction, dementia probability, speaker metadata,
            linguistic feature values and the top attention-weighted
            sentences ("key_segments").

        Raises:
            ValueError: if no participant (*PAR) speech is found, or the
                transcript cleans down to zero sentences.
        """
        content_str = file_content.decode('utf-8')
        final_age, final_gender = parse_cha_header(content_str)
        raw_text = parse_cha_transcript(content_str)
        if not raw_text:
            raise ValueError("No participant speech (*PAR) found in file")

        clean_text = self.extractor.clean_for_bert(raw_text)
        ling_feats = self.extractor.get_vector(raw_text)

        # Split on sentence-final punctuation; drop whitespace-only pieces.
        sentences = [s for s in re.split(r'[.?!]\s+', clean_text) if s.strip()]
        # FIX: guard against a transcript that cleans down to nothing — the
        # tokenizer would otherwise fail on an empty batch.
        if not sentences:
            raise ValueError("No participant speech (*PAR) found in file")
        if len(sentences) > self.config['max_seq_len']:
            # Keep the most recent sentences (tail of the conversation).
            sentences = sentences[-self.config['max_seq_len']:]

        encoding = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_word_len'],
            return_tensors='pt'
        )
        # Add a batch dimension: (1, num_sentences, max_word_len).
        input_ids = encoding['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoding['attention_mask'].unsqueeze(0).to(self.device)

        # Document-level linguistic features, repeated once per sentence.
        feats_tensor = torch.tensor(ling_feats, dtype=torch.float32).repeat(len(sentences), 1)

        # Zero-pad the sentence dimension up to max_seq_len.
        pad_len = self.config['max_seq_len'] - len(sentences)
        if pad_len > 0:
            feats_tensor = torch.cat([feats_tensor, torch.zeros(pad_len, self.config['feature_dim'])])
            pad_ids = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            pad_mask = torch.zeros(pad_len, self.config['max_word_len'], dtype=torch.long).unsqueeze(0).to(self.device)
            input_ids = torch.cat([input_ids, pad_ids], dim=1)
            attention_mask = torch.cat([attention_mask, pad_mask], dim=1)

        feats_tensor = feats_tensor.unsqueeze(0).to(self.device)
        lengths = torch.tensor([len(sentences)]).to(self.device)
        # Age normalized to roughly [0, 1]; gender encoded 1 = Male.
        age_tensor = torch.tensor([final_age / 100.0], dtype=torch.float32).to(self.device)
        gender_tensor = torch.tensor([final_gender], dtype=torch.float32).to(self.device)

        with torch.no_grad():
            logits, attn_weights = self.model(
                input_ids,
                attention_mask,
                feats_tensor,
                lengths,
                age_tensor,
                gender_tensor
            )
            probs = torch.nn.functional.softmax(logits, dim=1)
            dementia_prob = probs[0, 1].item()

        predicted_class = "Dementia" if dementia_prob > 0.5 else "Control"

        # FIX: flatten the attention to a flat Python list regardless of its
        # rank. The old .tolist() yielded a nested list when a batch dim was
        # present, so the [:len(sentences)] slice indexed the wrong axis (and
        # "importance" became a list). reshape(-1) also covers the old 0-d
        # scalar special case.
        attn_list = attn_weights.detach().reshape(-1).cpu().tolist()

        # Top-3 sentences by attention weight, as explanation snippets.
        top_sentences = []
        indexed_attn = sorted(
            enumerate(attn_list[:len(sentences)]),
            key=lambda pair: pair[1],
            reverse=True
        )
        for idx, weight in indexed_attn[:3]:
            top_sentences.append({
                "text": sentences[idx],
                "importance": weight
            })

        return {
            "filename": filename,
            "prediction": predicted_class,
            "probability_dementia": round(dementia_prob, 4),
            "metadata": {
                "age": final_age,
                "gender": "Male" if final_gender == 1 else "Female",
                "sentence_count": len(sentences)
            },
            # FIX: cast to float so the payload stays JSON-serializable even
            # if the extractor returns numpy scalars.
            "linguistic_features": {
                "TTR": float(ling_feats[0]),
                "fillers_ratio": float(ling_feats[1]),
                "repetitions_ratio": float(ling_feats[2]),
                "retracing_ratio": float(ling_feats[3]),
                "incomplete_ratio": float(ling_feats[4]),
                "pauses_ratio": float(ling_feats[5])
            },
            "key_segments": top_sentences,
            "model_used": "Model v2"
        }