import os
import re
import tempfile
import traceback
from typing import Optional

import torch
import torch.nn.functional as F
import whisper
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from models.base import BaseModelWrapper
from .config import (
    HF_REPO_ID,
    LOCAL_WEIGHTS_PATH,
    MAX_LEN,
    SEGMENTATION_ROOT_PATH,
    TEXT_MODEL_NAME,
    WEIGHTS_FILENAME,
)
from .model import MultimodalFusion
from .processor import (
    AudioProcessor,
    LinguisticFeatureExtractor,
    apply_chat_rules,
    parse_segmentation_csv,
)


def find_segmentation_file(audio_filename: str) -> Optional[bytes]:
    """
    Search for a segmentation CSV matching *audio_filename*.

    Looks for "<basename>.csv" in SEGMENTATION_ROOT_PATH/AD/ and
    SEGMENTATION_ROOT_PATH/Control/.

    Returns the CSV content as bytes if found, None otherwise (including
    when SEGMENTATION_ROOT_PATH is unset or does not exist).
    """
    if not SEGMENTATION_ROOT_PATH or not os.path.exists(SEGMENTATION_ROOT_PATH):
        return None

    # Match on the base filename, ignoring the audio extension.
    base_name = os.path.splitext(os.path.basename(audio_filename))[0]

    for subfolder in ('AD', 'Control'):
        csv_path = os.path.join(SEGMENTATION_ROOT_PATH, subfolder, f"{base_name}.csv")
        if os.path.exists(csv_path):
            print(f"  Found segmentation file: {csv_path}")
            with open(csv_path, 'rb') as f:
                return f.read()
    return None


class MultimodalWrapper(BaseModelWrapper):
    """Wrapper around the V3 multimodal (text + audio + linguistic) fusion model."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Heavy components are deferred to load().
        self.model = None       # MultimodalFusion
        self.tokenizer = None   # HF tokenizer for TEXT_MODEL_NAME
        self.asr_model = None   # Whisper model for audio-only inference
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self):
        """Load the tokenizer, fusion-model weights (local or HF Hub) and Whisper.

        Raises:
            FileNotFoundError: if weights are neither available locally nor
                downloadable from the Hugging Face Hub.
        """
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)

        # Weights: prefer a local checkpoint, fall back to the Hugging Face Hub.
        if os.path.exists(LOCAL_WEIGHTS_PATH):
            weights_path = LOCAL_WEIGHTS_PATH
            print(f"Loading weights from local: {weights_path}")
        else:
            try:
                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
                weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=WEIGHTS_FILENAME)
            except Exception as e:
                raise FileNotFoundError(
                    f"Model weights not found locally or on Hugging Face: {e}"
                ) from e

        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Whisper "base" model, matching the training notebook.
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str,
                audio_content: Optional[bytes] = None,
                segmentation_content: Optional[bytes] = None) -> dict:
        """
        Run inference, dispatching on which inputs were provided:

        1. CHA only:                 text + linguistic features, zero audio tensor
        2. CHA + audio:              text from CHA, CHA timestamps slice the audio
        3. Audio + segmentation CSV: audio sliced by CSV intervals, Whisper ASR
        4. Audio only:               full audio assumed participant-only, Whisper ASR

        Returns a dict with the predicted label, confidence, modalities used
        and visualization payloads. Exceptions are logged and re-raised.
        """
        try:
            return self._predict_internal(file_content, filename, audio_content, segmentation_content)
        except Exception as e:
            print(f"[Model V3 ERROR] {e}")
            traceback.print_exc()
            raise

    def _predict_internal(self, file_content: bytes, filename: str,
                          audio_content: Optional[bytes] = None,
                          segmentation_content: Optional[bytes] = None) -> dict:
        is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
        has_audio = audio_content is not None and len(audio_content) > 0

        if not is_cha_provided and has_audio:
            # Modes 3 & 4: no CHA transcript — generate one with Whisper.
            processed_text, raw_text_for_segments, ling_features, audio_tensor = \
                self._process_audio_modes(filename, audio_content, segmentation_content)
        else:
            # Modes 1 & 2: CHA transcript provided (optionally with audio).
            processed_text, raw_text_for_segments, ling_features, audio_tensor = \
                self._process_cha_modes(file_content, audio_content if has_audio else None)

        # --- Common inference steps ---
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)
            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()
            # Class index 1 = AD, 0 = Control (see label_map below).
            prob_ad = probs[0][1].item()
            prob_control = probs[0][0].item()

        label_map = {0: 'Control', 1: 'AD'}

        # --- Build visualizations ---
        visualizations = {
            "probabilities": {
                "AD": round(prob_ad, 4),
                "Control": round(prob_control, 4),
            }
        }

        # Key contribution segments from text analysis: denser color = more
        # contribution to the prediction (score normalized by max marker count).
        raw_key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
        if raw_key_segments:
            max_count = max(seg['marker_count'] for seg in raw_key_segments)
            key_segments = [
                {
                    "text": seg['text'],
                    "contribution_score": round(seg['marker_count'] / max(max_count, 1), 2),
                }
                for seg in raw_key_segments
            ]
            visualizations["key_contribution_segments"] = {
                "note": "Denser color indicates higher contribution to prediction",
                "segments": key_segments,
            }

        # Audio-specific visualizations intentionally omitted (per requirements).
        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "visualizations": visualizations,
        }

    def _process_audio_modes(self, filename: str, audio_content: bytes,
                             segmentation_content: Optional[bytes]):
        """Modes 3 & 4: build spectrogram tensor + Whisper transcript from raw audio.

        Returns (processed_text, raw_text_for_segments, ling_features, audio_tensor).
        """
        # Whisper / librosa need a real file on disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
            tmp_audio.write(audio_content)
            tmp_path = tmp_audio.name

        try:
            # Segmentation: 1) user-provided, 2) auto-discovered from the dataset.
            seg_content = segmentation_content or find_segmentation_file(filename)

            if seg_content:
                # Mode 3: slice audio by the PAR intervals from the CSV.
                print("Processing Mode: Audio + Segmentation CSV (8.2 style)")
                intervals = parse_segmentation_csv(seg_content)
                print(f"  Found {len(intervals)} PAR intervals from CSV")
            else:
                # Mode 4: assume the entire recording is participant speech.
                print("Processing Mode: Audio Only - Participant Audio (8.3 style)")
                intervals = None

            audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)

            # Whisper handles the full file well; no pre-slicing needed for ASR.
            result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
            chat_transcript = apply_chat_rules(result)
        finally:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)

        ling_features = self._features_from_transcript(chat_transcript)
        # The generated transcript serves as both model input and segment source.
        return chat_transcript, chat_transcript, ling_features, audio_tensor

    def _features_from_transcript(self, chat_transcript: str) -> torch.Tensor:
        """Compute the (1, 6) linguistic feature tensor from a CHAT-style transcript."""
        stats = self.ling_extractor.get_features(chat_transcript)
        pause_count = chat_transcript.count("[PAUSE]")
        repetition_count = chat_transcript.count("[/]")

        # Type-token ratio over the marker-free, punctuation-free text.
        clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
        clean_t = re.sub(r'[^\w\s]', '', clean_t)
        words = clean_t.lower().split()
        n = len(words) or 1  # avoid division by zero on an empty transcript
        ttr = len(set(words)) / n

        ling_vec = [
            ttr,
            stats['filler_count'] / n,
            repetition_count / n,
            stats['retracing_count'] / n,
            stats['error_count'] / n,
            pause_count / n,
        ]
        return torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)

    def _process_cha_modes(self, file_content: bytes, audio_content: Optional[bytes]):
        """Modes 1 & 2: parse *PAR lines from a CHA transcript (audio optional).

        Returns (processed_text, raw_text_for_segments, ling_features, audio_tensor).
        """
        text_str = file_content.decode('utf-8', errors='replace')
        par_lines = []
        full_text_for_intervals = ""
        for line in text_str.splitlines():
            if line.startswith('*PAR:'):
                content = line[5:].strip()
                par_lines.append(content)
                full_text_for_intervals += content + " "

        raw_text = " ".join(par_lines)
        processed_text = self.ling_extractor.clean_for_bert(raw_text)

        feats = self.ling_extractor.get_feature_vector(raw_text)
        ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)

        if audio_content:
            # Mode 2: CHAT time markers (\x15start_end\x15) slice the PAR audio.
            print("Processing Mode: CHA + Audio (Segmentation)")
            found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
            intervals = [(int(s), int(e)) for s, e in found_intervals]

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name
            try:
                audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        else:
            # Mode 1: text only — feed a zero image to the audio branch.
            print("Processing Mode: CHA Only")
            audio_tensor = torch.zeros((1, 3, 224, 224))

        return processed_text, raw_text, ling_features, audio_tensor