Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| from typing import Optional | |
| import os | |
| import tempfile | |
| import whisper | |
| import re | |
| import traceback | |
| from models.base import BaseModelWrapper | |
| from .model import MultimodalFusion | |
| from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules, parse_segmentation_csv | |
| from .config import HF_REPO_ID, WEIGHTS_FILENAME, LOCAL_WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN, SEGMENTATION_ROOT_PATH | |
def find_segmentation_file(audio_filename: str) -> Optional[bytes]:
    """Locate the segmentation CSV that matches an audio file.

    Probes ``<SEGMENTATION_ROOT_PATH>/AD/<stem>.csv`` and
    ``<SEGMENTATION_ROOT_PATH>/Control/<stem>.csv`` (stem = audio filename
    without directory or extension) and returns the raw bytes of the first
    match. Returns ``None`` when the root path is unset/absent or no
    candidate file exists.
    """
    if not SEGMENTATION_ROOT_PATH or not os.path.exists(SEGMENTATION_ROOT_PATH):
        return None

    # Base filename without directory or extension.
    stem = os.path.splitext(os.path.basename(audio_filename))[0]

    # Candidate locations: the AD cohort folder first, then Control.
    candidates = (
        os.path.join(SEGMENTATION_ROOT_PATH, cohort, f"{stem}.csv")
        for cohort in ('AD', 'Control')
    )
    for csv_path in candidates:
        if not os.path.exists(csv_path):
            continue
        print(f" Found segmentation file: {csv_path}")
        with open(csv_path, 'rb') as fh:
            return fh.read()
    return None
class MultimodalWrapper(BaseModelWrapper):
    """Multimodal (text + spectrogram + linguistic) AD-vs-Control classifier.

    Fuses three inputs through ``MultimodalFusion``:
      * tokenized participant transcript (tokenizer for ``TEXT_MODEL_NAME``),
      * a spectrogram image tensor built from participant-only audio
        (a zero tensor of shape (1, 3, 224, 224) when no audio is given),
      * a 6-dimensional linguistic feature vector.

    Call :meth:`load` once before :meth:`predict`.
    """

    def __init__(self) -> None:
        # Weights are mapped onto this device in load(); prefer GPU when present.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None       # MultimodalFusion instance, populated by load()
        self.tokenizer = None   # HF tokenizer for TEXT_MODEL_NAME, populated by load()
        self.asr_model = None   # Whisper model for audio-only modes, populated by load()
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self) -> None:
        """Load the tokenizer, fusion-model weights and the Whisper ASR model.

        Weights are taken from ``LOCAL_WEIGHTS_PATH`` when it exists, otherwise
        downloaded from the Hugging Face Hub (``HF_REPO_ID``/``WEIGHTS_FILENAME``).

        Raises:
            FileNotFoundError: if weights are found neither locally nor on the Hub.
        """
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)

        # Load weights - try local first, then Hugging Face.
        if os.path.exists(LOCAL_WEIGHTS_PATH):
            weights_path = LOCAL_WEIGHTS_PATH
            print(f"Loading weights from local: {weights_path}")
        else:
            try:
                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
                weights_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=WEIGHTS_FILENAME
                )
            except Exception as e:
                # Chain the original cause so download failures stay diagnosable.
                raise FileNotFoundError(
                    f"Model weights not found locally or on Hugging Face: {e}"
                ) from e

        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Whisper "base" generates transcripts when no .cha file is supplied.
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str,
                audio_content: Optional[bytes] = None,
                segmentation_content: Optional[bytes] = None) -> dict:
        """Classify one sample; returns a result dict with visualizations.

        Handles 4 scenarios with mode-specific visualizations:
        1. CHA only: Text + Linguistic features, zero audio
        2. CHA + Audio: Text from CHA, timestamps from CHA for slicing
        3. Audio + Segmentation CSV: Slice audio using CSV intervals, Whisper ASR
        4. Audio only (participant-only): Full audio with Whisper ASR
        """
        try:
            return self._predict_internal(file_content, filename, audio_content, segmentation_content)
        except Exception as e:
            # Log the full traceback server-side, then propagate to the caller.
            print(f"[Model V3 ERROR] {e}")
            traceback.print_exc()
            raise

    def _ling_from_transcript(self, chat_transcript: str):
        """Build linguistic features from a CHAT-style transcript string.

        Returns ``(ling_vec, ling_features)``: a 6-element list of normalized
        rates (TTR, fillers, repetitions, retracings, errors, pauses) and the
        same values as a (1, 6) float32 tensor.
        """
        stats = self.ling_extractor.get_features(chat_transcript)
        pause_count = chat_transcript.count("[PAUSE]")
        repetition_count = chat_transcript.count("[/]")

        # Type-token ratio over the transcript with CHAT markers and
        # punctuation stripped out.
        clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
        clean_t = re.sub(r'[^\w\s]', '', clean_t)
        words = clean_t.lower().split()
        n = len(words) or 1  # guard against division by zero on empty transcripts
        ttr = len(set(words)) / n

        ling_vec = [
            ttr,
            stats['filler_count'] / n,
            repetition_count / n,
            stats['retracing_count'] / n,
            stats['error_count'] / n,
            pause_count / n,
        ]
        ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)
        return ling_vec, ling_features

    def _predict_internal(self, file_content: bytes, filename: str,
                          audio_content: Optional[bytes] = None,
                          segmentation_content: Optional[bytes] = None) -> dict:
        # Determine scenario. bool() also guards against None filename/content
        # (filename may be falsy — see the "audio_upload" fallback in the
        # returned dict — and the original .endswith call crashed on None).
        is_cha_provided = bool(filename) and filename.endswith('.cha') and bool(file_content)
        has_audio = bool(audio_content)

        processed_text = ""
        raw_text_for_segments = ""
        ling_features = None
        ling_vec = None
        audio_tensor = None
        intervals = []
        tmp_path = None

        # --- MODES 3 & 4: AUDIO WITHOUT CHA (transcript generated via Whisper) ---
        if not is_cha_provided and has_audio:
            # Whisper/librosa need a real file path, so spill the bytes to a temp wav.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name
            try:
                # Segmentation source: user upload first, then dataset auto-discovery.
                seg_content = segmentation_content or find_segmentation_file(filename)

                if seg_content:
                    # MODE 3: Audio + Segmentation CSV (8.2 style) — slice
                    # participant audio using the CSV intervals.
                    print("Processing Mode: Audio + Segmentation CSV (8.2 style)")
                    intervals = parse_segmentation_csv(seg_content)
                    print(f" Found {len(intervals)} PAR intervals from CSV")
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                else:
                    # MODE 4: Audio only (8.3 style) — whole file is assumed to
                    # be participant speech, no slicing.
                    print("Processing Mode: Audio Only - Participant Audio (8.3 style)")
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)

                # Both modes transcribe the full file; Whisper handles it well
                # (pre-slicing would also work but is unnecessary here).
                result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                chat_transcript = apply_chat_rules(result)

                processed_text = chat_transcript
                raw_text_for_segments = chat_transcript
                ling_vec, ling_features = self._ling_from_transcript(chat_transcript)
            finally:
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)

        # --- SCENARIOS 1 & 2: CHA FILE PROVIDED ---
        else:
            # Pull the participant (*PAR:) tier out of the CHAT transcript.
            text_str = (file_content or b"").decode('utf-8', errors='replace')
            par_lines = []
            full_text_for_intervals = ""
            for line in text_str.splitlines():
                if line.startswith('*PAR:'):
                    content = line[5:].strip()
                    par_lines.append(content)
                    full_text_for_intervals += content + " "
            raw_text = " ".join(par_lines)
            raw_text_for_segments = raw_text
            processed_text = self.ling_extractor.clean_for_bert(raw_text)

            # Linguistic feature vector straight from the raw PAR text.
            feats = self.ling_extractor.get_feature_vector(raw_text)
            ling_vec = feats.tolist()
            ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)

            if has_audio:
                # SCENARIO 2: CHA + audio — CHAT bullet timestamps
                # (\x15start_end\x15, milliseconds) locate the PAR segments.
                print("Processing Mode: CHA + Audio (Segmentation)")
                found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                intervals = [(int(s), int(e)) for s, e in found_intervals]
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                    tmp_audio.write(audio_content)
                    tmp_path = tmp_audio.name
                try:
                    # Slice the specific PAR intervals out of the audio.
                    # (The unused base64 spectrogram render was dropped: audio
                    # visualizations were removed from the response.)
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            else:
                # SCENARIO 1: CHA only — zero image stands in for the audio modality.
                print("Processing Mode: CHA Only")
                audio_tensor = torch.zeros((1, 3, 224, 224))

        # --- COMMON INFERENCE STEPS ---
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)
            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            probs = F.softmax(outputs, dim=1)

        # Index 0 = Control, 1 = AD (training label convention; matches label_map).
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()
        prob_ad = probs[0][1].item()
        prob_control = probs[0][0].item()
        label_map = {0: 'Control', 1: 'AD'}

        # --- BUILD VISUALIZATIONS ---
        visualizations = {
            "probabilities": {
                "AD": round(prob_ad, 4),
                "Control": round(prob_control, 4)
            }
        }

        # Key contribution segments from text analysis:
        # denser color = more contribution to the prediction.
        raw_key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
        if raw_key_segments:
            # Normalize marker counts into [0, 1] contribution scores
            # (markers themselves are not exposed to the client).
            max_count = max(seg['marker_count'] for seg in raw_key_segments)
            key_segments = [
                {
                    "text": seg['text'],
                    "contribution_score": round(seg['marker_count'] / max(max_count, 1), 2)
                }
                for seg in raw_key_segments
            ]
            visualizations["key_contribution_segments"] = {
                "note": "Denser color indicates higher contribution to prediction",
                "segments": key_segments
            }

        # Audio-specific visualizations removed as per requirements.
        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "visualizations": visualizations
        }