File size: 12,475 Bytes
d5f8ae0
 
 
f141fc2
d5f8ae0
 
 
 
 
f141fc2
d5f8ae0
 
 
e824b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5f8ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
 
 
 
d5f8ae0
f141fc2
 
 
 
 
 
 
 
 
 
d5f8ae0
 
 
 
 
 
 
 
e824b96
d5f8ae0
e824b96
 
 
 
 
d5f8ae0
f141fc2
e824b96
f141fc2
 
 
 
 
e824b96
d5f8ae0
 
 
e824b96
d5f8ae0
 
f141fc2
d5f8ae0
f141fc2
d5f8ae0
 
f141fc2
d5f8ae0
e824b96
d5f8ae0
 
 
 
 
 
 
e824b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
 
d5f8ae0
e824b96
d5f8ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
 
d5f8ae0
e824b96
d5f8ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
d5f8ae0
 
 
 
f141fc2
d5f8ae0
 
 
 
 
f141fc2
d5f8ae0
 
 
 
 
 
 
 
 
 
f141fc2
 
d5f8ae0
f141fc2
 
d5f8ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
 
 
d5f8ae0
 
 
f141fc2
 
 
 
 
 
 
 
4fcfcbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141fc2
4fcfcbc
 
f141fc2
d5f8ae0
 
 
 
 
 
f141fc2
d5f8ae0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from typing import Optional
import os
import tempfile
import whisper
import re
import traceback

from models.base import BaseModelWrapper
from .model import MultimodalFusion
from .processor import LinguisticFeatureExtractor, AudioProcessor, apply_chat_rules, parse_segmentation_csv
from .config import HF_REPO_ID, WEIGHTS_FILENAME, LOCAL_WEIGHTS_PATH, TEXT_MODEL_NAME, MAX_LEN, SEGMENTATION_ROOT_PATH


def find_segmentation_file(audio_filename: str) -> Optional[bytes]:
    """
    Locate a segmentation CSV that matches an audio file's base name.

    Checks ``SEGMENTATION_ROOT_PATH/AD/<name>.csv`` then
    ``SEGMENTATION_ROOT_PATH/Control/<name>.csv``.
    Returns the CSV content as bytes when a match exists, otherwise None.
    """
    # Bail out early when no segmentation dataset is configured/present.
    if not SEGMENTATION_ROOT_PATH or not os.path.exists(SEGMENTATION_ROOT_PATH):
        return None

    # Base filename with directory and extension stripped.
    stem = os.path.splitext(os.path.basename(audio_filename))[0]

    # Candidate CSV paths, probed in AD-then-Control order.
    candidates = (
        os.path.join(SEGMENTATION_ROOT_PATH, group, f"{stem}.csv")
        for group in ('AD', 'Control')
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            print(f"  Found segmentation file: {candidate}")
            with open(candidate, 'rb') as fh:
                return fh.read()

    return None

class MultimodalWrapper(BaseModelWrapper):
    """Inference wrapper fusing text, linguistic features, and audio.

    All four input scenarios (see ``predict``) converge on a single
    tokenizer + MultimodalFusion forward pass; audio-less runs feed a
    zero spectrogram tensor so the model interface stays uniform.
    """

    def __init__(self) -> None:
        # Prefer GPU when available; every tensor is moved to this device at inference.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Heavy components are deferred to load() so construction stays cheap.
        self.model = None       # MultimodalFusion, populated by load()
        self.tokenizer = None   # HF AutoTokenizer, populated by load()
        self.asr_model = None   # Whisper model, populated by load()
        self.ling_extractor = LinguisticFeatureExtractor()
        self.audio_processor = AudioProcessor()

    def load(self) -> None:
        """Load the tokenizer, fusion-model weights, and the Whisper ASR model.

        Weight resolution order: LOCAL_WEIGHTS_PATH on disk first, then a
        download from the HF_REPO_ID repository on the Hugging Face Hub.

        Raises:
            FileNotFoundError: if weights are found neither locally nor on the Hub.
        """
        print("Loading Model V3 components...")
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.model = MultimodalFusion(TEXT_MODEL_NAME)
        
        # Load Weights - Try local first, then Hugging Face
        if os.path.exists(LOCAL_WEIGHTS_PATH):
            weights_path = LOCAL_WEIGHTS_PATH
            print(f"Loading weights from local: {weights_path}")
        else:
            try:
                print(f"Downloading weights from Hugging Face: {HF_REPO_ID}")
                weights_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=WEIGHTS_FILENAME
                )
            except Exception as e:
                # Broad catch is deliberate: any Hub failure (network, auth,
                # missing repo) is surfaced as a single actionable error.
                raise FileNotFoundError(f"Model weights not found locally or on Hugging Face: {e}")
        
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects — fine for trusted checkpoints; confirm the source is trusted.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Load Whisper (Base model as per notebook)
        print("Loading Whisper for Audio-Only Inference...")
        self.asr_model = whisper.load_model("base")

    def predict(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """
        Handles 4 scenarios with mode-specific visualizations:
        1. CHA only: Text + Linguistic features, zero audio
        2. CHA + Audio: Text from CHA, timestamps from CHA for slicing
        3. Audio + Segmentation CSV: Slice audio using CSV intervals, Whisper ASR
        4. Audio only (participant-only): Full audio with Whisper ASR

        Args:
            file_content: Raw bytes of the uploaded primary file (.cha or audio).
            filename: Original filename; its extension drives scenario selection.
            audio_content: Optional raw audio bytes (wav assumed, see below).
            segmentation_content: Optional segmentation CSV bytes.

        Returns:
            dict with predicted label, confidence, modalities used, and
            visualization payloads.

        Raises:
            Re-raises any exception from the internal pipeline after logging
            a traceback (caller is expected to translate into an HTTP error).
        """
        try:
            return self._predict_internal(file_content, filename, audio_content, segmentation_content)
        except Exception as e:
            # Log full traceback server-side, then propagate unchanged.
            print(f"[Model V3 ERROR] {e}")
            traceback.print_exc()
            raise

    def _predict_internal(self, file_content: bytes, filename: str, audio_content: Optional[bytes] = None, segmentation_content: Optional[bytes] = None) -> dict:
        """Core pipeline: scenario dispatch -> feature extraction -> fused inference."""
        # Determine Scenario
        is_cha_provided = filename.endswith('.cha') and len(file_content) > 0
        has_audio = audio_content is not None and len(audio_content) > 0
        has_segmentation = segmentation_content is not None and len(segmentation_content) > 0
        
        processed_text = ""            # cleaned text fed to the tokenizer
        raw_text_for_segments = ""     # uncleaned text used for key-segment extraction
        ling_features = None           # (1, 6) float tensor of linguistic features
        ling_vec = None                # same features as a plain list
        audio_tensor = None            # (1, 3, 224, 224) spectrogram image tensor
        intervals = []                 # PAR (start_ms, end_ms) slices, when known
        tmp_path = None                # temp wav path, cleaned up in finally blocks
        
        # --- MODE 3 & 4: AUDIO WITHOUT CHA (generate transcript via Whisper) ---
        if not is_cha_provided and has_audio:
            # Save audio to temp file for Whisper/Librosa
            # NOTE(review): suffix is always ".wav" regardless of the uploaded
            # format — confirm non-wav uploads are handled upstream.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                tmp_audio.write(audio_content)
                tmp_path = tmp_audio.name
            
            try:
                # Try to get segmentation: 1) User provided, 2) Auto-discover from dataset
                seg_content = segmentation_content
                if not seg_content:
                    seg_content = find_segmentation_file(filename)
                
                # MODE 3: Audio + Segmentation CSV (8.2 style)
                if seg_content:
                    print("Processing Mode: Audio + Segmentation CSV (8.2 style)")
                    intervals = parse_segmentation_csv(seg_content)
                    print(f"  Found {len(intervals)} PAR intervals from CSV")
                    
                    # Slice audio and create spectrogram with intervals
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                    
                    # For transcription, we transcribe the full audio and let Whisper handle it
                    # (In production, you could slice audio first, but Whisper handles full file well)
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)
                    
                # MODE 4: Audio only - participant-only audio (8.3 style)
                else:
                    print("Processing Mode: Audio Only - Participant Audio (8.3 style)")
                    # No slicing needed - assume entire audio is participant
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=None)
                    
                    # Transcribe full audio
                    result = self.asr_model.transcribe(tmp_path, word_timestamps=False)
                    chat_transcript = apply_chat_rules(result)
                
                processed_text = chat_transcript
                raw_text_for_segments = chat_transcript
                
                # Extract Features from generated text
                # apply_chat_rules presumably injects CHAT-style markers
                # ([PAUSE], [/]) that are counted here — TODO confirm.
                stats = self.ling_extractor.get_features(chat_transcript)
                pause_count = chat_transcript.count("[PAUSE]")
                repetition_count = chat_transcript.count("[/]")
                
                # TTR Calc
                # Strip bracketed markers, then punctuation, before tokenizing.
                clean_t = re.sub(r'\[.*?\]', '', chat_transcript)
                clean_t = re.sub(r'[^\w\s]', '', clean_t)
                words = clean_t.lower().split()
                # Guard divisor: empty transcript would otherwise divide by zero.
                n = len(words) if len(words) > 0 else 1
                ttr = len(set(words)) / n
                
                # Six per-word-normalized features; order must match training.
                ling_vec = [
                    ttr,
                    stats['filler_count'] / n,
                    repetition_count / n,
                    stats['retracing_count'] / n,
                    stats['error_count'] / n,
                    pause_count / n
                ]
                ling_features = torch.tensor(ling_vec, dtype=torch.float32).unsqueeze(0)
                
            finally:
                # Always remove the temp wav, even on ASR/feature failure.
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)


        # --- SCENARIO 1 & 2: CHA FILE PROVIDED ---
        # NOTE(review): this branch also handles "no CHA and no audio" — the
        # file_content is then decoded as CHA text anyway; confirm the API
        # layer rejects that combination before it reaches here.
        else:
            # Parse Text from CHA
            text_str = file_content.decode('utf-8', errors='replace')
            par_lines = []
            full_text_for_intervals = ""
            
            # Keep only participant (*PAR:) tiers from the CHAT transcript.
            for line in text_str.splitlines():
                if line.startswith('*PAR:'):
                    content = line[5:].strip()
                    par_lines.append(content)
                    full_text_for_intervals += content + " "

            raw_text = " ".join(par_lines)
            raw_text_for_segments = raw_text
            processed_text = self.ling_extractor.clean_for_bert(raw_text)
            
            # Extract Features
            feats = self.ling_extractor.get_feature_vector(raw_text)
            ling_vec = feats.tolist()
            ling_features = torch.tensor(feats, dtype=torch.float32).unsqueeze(0)

            # --- SCENARIO 2: CHA + AUDIO (Segmentation) ---
            if has_audio:
                print("Processing Mode: CHA + Audio (Segmentation)")
                # Extract intervals from the raw text
                # \x15 is the CHAT bullet delimiter around "start_end" ms timestamps.
                found_intervals = re.findall(r'\x15(\d+)_(\d+)\x15', full_text_for_intervals)
                intervals = [(int(s), int(e)) for s, e in found_intervals]
                
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                    tmp_audio.write(audio_content)
                    tmp_path = tmp_audio.name
                
                try:
                    # Pass intervals to slice specific PAR audio
                    audio_tensor = self.audio_processor.create_spectrogram_tensor(tmp_path, intervals=intervals)
                    # Generate spectrogram for visualization
                    # NOTE(review): spectrogram_b64 is never used or returned —
                    # confirm whether it should be added to `visualizations`
                    # below or this call removed.
                    spectrogram_b64 = self.audio_processor.create_spectrogram_base64(tmp_path, intervals=intervals)
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            
            # --- SCENARIO 1: CHA ONLY ---
            else:
                print("Processing Mode: CHA Only")
                # Zero image stands in for the missing audio modality.
                audio_tensor = torch.zeros((1, 3, 224, 224))

        # --- COMMON INFERENCE STEPS ---
        encoding = self.tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        with torch.no_grad():
            input_ids = encoding['input_ids'].to(self.device)
            mask = encoding['attention_mask'].to(self.device)
            pixel_values = audio_tensor.to(self.device)
            ling_input = ling_features.to(self.device)

            outputs = self.model(input_ids, mask, pixel_values, ling_input)
            probs = F.softmax(outputs, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_idx].item()
            
            # Class index convention: 0 = Control, 1 = AD (see label_map below).
            prob_ad = probs[0][1].item()
            prob_control = probs[0][0].item()

        label_map = {0: 'Control', 1: 'AD'}
        
        # --- BUILD VISUALIZATIONS ---
        visualizations = {
            "probabilities": {
                "AD": round(prob_ad, 4),
                "Control": round(prob_control, 4)
            }
        }
        
        # Key Contribution Segments (from text analysis)
        # Denser color = more contribution to prediction
        raw_key_segments = self.ling_extractor.extract_key_segments(raw_text_for_segments)
        if raw_key_segments:
            # Transform to contribution-based format (remove markers, use normalized score)
            # Scores are marker counts normalized by the max count across segments.
            max_count = max(seg['marker_count'] for seg in raw_key_segments) if raw_key_segments else 1
            key_segments = [
                {
                    "text": seg['text'],
                    "contribution_score": round(seg['marker_count'] / max(max_count, 1), 2)
                }
                for seg in raw_key_segments
            ]
            visualizations["key_contribution_segments"] = {
                "note": "Denser color indicates higher contribution to prediction",
                "segments": key_segments
            }
        
        # Audio-specific visualizations removed as per requirements
        
        return {
            "model_version": "v3_multimodal",
            "filename": filename if filename else "audio_upload",
            "predicted_label": label_map[pred_idx],
            "confidence": round(confidence, 4),
            "modalities_used": ["text", "linguistic"] + (["audio"] if has_audio else []),
            "visualizations": visualizations
        }