ryu34 committed on
Commit 42aa088 · verified · 1 Parent(s): 257885d

Upload app.py with huggingface_hub
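
The commit message indicates the file was pushed through the `huggingface_hub` client. A minimal sketch of such an upload follows, assuming the target is a Space; the helper name `upload_app`, the `repo_type="space"` default and any repo id passed to it are illustrative assumptions, not taken from this commit. `HfApi.upload_file` is the real library call, and the explicit commit message mirrors the one shown above.

```python
from pathlib import Path


def upload_app(local_path, repo_id, api=None, repo_type="space"):
    """Upload a single file to a Hub repo and return the client's result.

    `api` is any object exposing `upload_file(...)`. When omitted, a real
    `huggingface_hub.HfApi` client is created (needs the package and a token).
    """
    if api is None:
        # Imported lazily so the sketch can be loaded without the package.
        from huggingface_hub import HfApi
        api = HfApi()
    local_path = Path(local_path)
    return api.upload_file(
        path_or_fileobj=str(local_path),   # file on disk
        path_in_repo=local_path.name,      # destination name inside the repo
        repo_id=repo_id,                   # hypothetical, e.g. "user/my-space"
        repo_type=repo_type,               # assumption: the app lives in a Space
        commit_message=f"Upload {local_path.name} with huggingface_hub",
    )
```

Passing the client in as a parameter keeps the helper easy to exercise without network access; in normal use `upload_app("app.py", "<user>/<space>")` would rely on a locally stored token.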

Files changed (1)
  1. app.py +1123 -0
app.py ADDED
@@ -0,0 +1,1123 @@
+"""
+Multimodal Brain Encoder - Gradio Application
+=============================================
+Full end-to-end system:
+Input → CLIP Features → Brain Prediction → ROI Analysis → LLM Q&A → Visualization
+
+Uses real trained weights from NSD dataset.
+LLM is an INTERPRETER only - grounded in model predictions, not independent.
+"""
+
+import os
+import sys
+import json
+import time
+import logging
+import pickle
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from pathlib import Path
+from datetime import datetime
+from collections import OrderedDict
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# ============================================================
+# Configuration (must match training)
+# ============================================================
+MODEL_REPO = os.environ.get("MODEL_REPO", "ryu34/multimodal-brain-encoder")
+
+ROI_NAMES = {
+    1: "V1v", 2: "V1d", 3: "V2v", 4: "V2d", 5: "V3v", 6: "V3d", 7: "hV4",
+    8: "EBA", 9: "FBA-1", 10: "FBA-2", 11: "mTL-bodies",
+    12: "OFA", 13: "FFA-1", 14: "FFA-2", 15: "mTL-faces", 16: "aTL-faces",
+    17: "OPA", 18: "PPA", 19: "RSC",
+    20: "OWFA", 21: "VWFA-1", 22: "VWFA-2", 23: "mfs-words", 24: "mTL-words",
+}
+
+FUNCTIONAL_NETWORKS = {
+    "early_visual": [1, 2, 3, 4, 5, 6, 7],
+    "body_selective": [8, 9, 10, 11],
+    "face_selective": [12, 13, 14, 15, 16],
+    "place_selective": [17, 18, 19],
+    "word_selective": [20, 21, 22, 23, 24],
+}
+
+# Known neuroscience associations for grounded Q&A
+ROI_FUNCTIONS = {
+    "V1v": "Primary visual cortex (ventral); processes basic visual features: edges, orientations, spatial frequencies",
+    "V1d": "Primary visual cortex (dorsal); processes basic visual features with dorsal visual stream emphasis",
+    "V2v": "Secondary visual cortex (ventral); processes texture, figure-ground segregation",
+    "V2d": "Secondary visual cortex (dorsal); processes contour and border ownership",
+    "V3v": "Third visual area (ventral); contributes to form perception and shape processing",
+    "V3d": "Third visual area (dorsal); processes dynamic form and motion boundaries",
+    "hV4": "Human V4; processes color, pattern, moderate object features, texture discrimination",
+    "EBA": "Extrastriate Body Area; selectively responds to bodies and body parts",
+    "FBA-1": "Fusiform Body Area 1; body processing in ventral temporal cortex",
+    "FBA-2": "Fusiform Body Area 2; complementary body processing region",
+    "mTL-bodies": "Medial temporal lobe body area; body recognition with memory component",
+    "OFA": "Occipital Face Area; early face-selective processing, face parts detection",
+    "FFA-1": "Fusiform Face Area 1; core face recognition and identity processing",
+    "FFA-2": "Fusiform Face Area 2; complementary face processing, holistic face representation",
+    "mTL-faces": "Medial temporal lobe face area; face recognition with episodic memory",
+    "aTL-faces": "Anterior temporal lobe face area; person identity and semantic knowledge",
+    "OPA": "Occipital Place Area; processes local scene elements and spatial boundaries",
+    "PPA": "Parahippocampal Place Area; processes scenes, buildings, spatial layouts",
+    "RSC": "Retrosplenial Cortex; spatial navigation, scene-to-map coordinate transformation",
+    "OWFA": "Occipital Word Form Area; early visual word processing",
+    "VWFA-1": "Visual Word Form Area 1; processes written words and letter strings",
+    "VWFA-2": "Visual Word Form Area 2; higher-level word form processing",
+    "mfs-words": "Mid-fusiform sulcus word area; intermediate word processing",
+    "mTL-words": "Medial temporal lobe word area; word recognition with memory",
+}
+
+NETWORK_FUNCTIONS = {
+    "early_visual": "Early visual processing: edges, orientations, spatial frequencies, textures, colors. Active for all visual stimuli.",
+    "body_selective": "Body-selective cortex: responds to human bodies, body parts, biological motion. Key for person perception.",
+    "face_selective": "Face-selective cortex: responds to faces, facial features, identity. Critical for social perception.",
+    "place_selective": "Place/scene-selective cortex: responds to spatial layouts, buildings, scenes, navigation cues.",
+    "word_selective": "Word/reading-selective cortex: responds to written text, letter strings, word forms.",
+}
+
+
+# ============================================================
+# BrainEncoder model (must match training architecture exactly)
+# ============================================================
+class BrainEncoder(nn.Module):
+    def __init__(self, input_dim=4096, n_voxels=15724, hidden_dims=None, dropout=0.3, n_rois=24):
+        super().__init__()
+        if hidden_dims is None:
+            hidden_dims = [2048, 2048, 1024]
+
+        self.input_dim = input_dim
+        self.n_voxels = n_voxels
+        self.n_rois = n_rois
+
+        layers = []
+        prev_dim = input_dim
+        for h_dim in hidden_dims:
+            layers.extend([
+                nn.Linear(prev_dim, h_dim),
+                nn.BatchNorm1d(h_dim),
+                nn.GELU(),
+                nn.Dropout(dropout),
+            ])
+            prev_dim = h_dim
+        self.backbone = nn.Sequential(*layers)
+
+        self.general_head = nn.Linear(hidden_dims[-1], n_voxels)
+
+        self.roi_attention = nn.ModuleDict()
+        self.roi_heads = nn.ModuleDict()
+        self.network_names = ["early_visual", "body_selective", "face_selective",
+                              "place_selective", "word_selective"]
+
+        for net_name in self.network_names:
+            self.roi_attention[net_name] = nn.Sequential(
+                nn.Linear(hidden_dims[-1], 256),
+                nn.GELU(),
+                nn.Linear(256, hidden_dims[-1]),
+                nn.Sigmoid(),
+            )
+            self.roi_heads[net_name] = nn.Linear(hidden_dims[-1], n_voxels)
+
+        self.register_buffer('roi_mask', torch.zeros(n_voxels, dtype=torch.long))
+
+    def set_roi_assignments(self, annot):
+        for net_idx, (net_name, roi_ids) in enumerate(FUNCTIONAL_NETWORKS.items()):
+            for roi_id in roi_ids:
+                mask = (annot == roi_id)
+                if len(mask) <= self.n_voxels:
+                    self.roi_mask[:len(mask)][mask[:self.n_voxels]] = net_idx + 1
+
+    def forward(self, x, return_intermediates=False):
+        intermediates = {}
+        backbone_out = self.backbone(x)
+        intermediates['backbone'] = backbone_out.detach()
+
+        pred = self.general_head(backbone_out)
+        # Clone so the in-place ROI overwrite below does not alter this snapshot
+        intermediates['general_pred'] = pred.detach().clone()
+
+        for net_idx, net_name in enumerate(self.network_names):
+            if net_name in self.roi_attention:
+                attn = self.roi_attention[net_name](backbone_out)
+                weighted = backbone_out * attn
+                roi_pred = self.roi_heads[net_name](weighted)
+                mask = (self.roi_mask == net_idx + 1)
+                if mask.any():
+                    pred[:, mask] = roi_pred[:, mask]
+                intermediates[f'roi_{net_name}'] = roi_pred.detach()
+
+        if return_intermediates:
+            return pred, intermediates
+        return pred
+
+
+# ============================================================
+# Model Manager - loads and caches models
+# ============================================================
+class ModelManager:
+    def __init__(self):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.brain_encoder = None
+        self.ridge_model = None
+        self.clip_model = None
+        self.clip_processor = None
+        self.roi_annotations = None
+        self.config = None
+        self._loaded = False
+
+    def load(self):
+        if self._loaded:
+            return
+
+        from huggingface_hub import hf_hub_download
+
+        logger.info(f"Loading models from {MODEL_REPO}...")
+
+        # Load config
+        try:
+            config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json")
+            with open(config_path) as f:
+                self.config = json.load(f)
+            logger.info(f"Config loaded: {self.config.get('architecture', 'unknown')}")
+        except Exception as e:
+            logger.warning(f"Config load failed: {e}")
+            self.config = {}
+
+        # Load ROI annotations
+        try:
+            annot_path = hf_hub_download(repo_id=MODEL_REPO, filename="roi_annotations.npy")
+            self.roi_annotations = np.load(annot_path).flatten()
+            logger.info(f"ROI annotations: {self.roi_annotations.shape}")
+        except Exception as e:
+            logger.warning(f"ROI annotations load failed: {e}")
+
+        # Load brain encoder
+        try:
+            model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
+            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+
+            model_config = checkpoint.get('config', {})
+            self.brain_encoder = BrainEncoder(
+                input_dim=model_config.get('input_dim', 4096),
+                n_voxels=model_config.get('n_voxels', 15724),
+                hidden_dims=model_config.get('hidden_dims', [2048, 2048, 1024]),
+                dropout=model_config.get('dropout', 0.3),
+            )
+            self.brain_encoder.load_state_dict(checkpoint['model_state_dict'])
+            self.brain_encoder.to(self.device).eval()
+
+            if self.roi_annotations is not None:
+                self.brain_encoder.set_roi_assignments(self.roi_annotations)
+
+            logger.info("Brain encoder loaded successfully")
+        except Exception as e:
+            logger.error(f"Brain encoder load failed: {e}")
+            raise
+
+        # Load ridge model
+        try:
+            ridge_path = hf_hub_download(repo_id=MODEL_REPO, filename="ridge_model.pkl")
+            with open(ridge_path, 'rb') as f:
+                self.ridge_model = pickle.load(f)
+            logger.info("Ridge model loaded successfully")
+        except Exception as e:
+            logger.warning(f"Ridge model load failed: {e}")
+
+        # Load CLIP
+        try:
+            from transformers import CLIPModel, CLIPProcessor
+            self.clip_model = CLIPModel.from_pretrained(
+                "openai/clip-vit-large-patch14",
+                torch_dtype=torch.float32,
+            ).to(self.device).eval()
+            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            logger.info("CLIP model loaded")
+        except Exception as e:
+            logger.error(f"CLIP load failed: {e}")
+            raise
+
+        self._loaded = True
+        logger.info("All models loaded successfully!")
+
+    def extract_features(self, image=None, text=None, audio=None):
+        """Extract multimodal CLIP features."""
+        features_dict = {}
+
+        if image is not None:
+            if isinstance(image, np.ndarray):
+                image = Image.fromarray(image)
+
+            inputs = self.clip_processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+
+            with torch.no_grad():
+                outputs = self.clip_model.vision_model(**inputs, output_hidden_states=True)
+
+            cls_features = outputs.last_hidden_state[:, 0, :]
+            projected = self.clip_model.visual_projection(cls_features)
+
+            hidden_concat = []
+            for layer_idx in [6, 12, 18, 23]:
+                h = outputs.hidden_states[layer_idx][:, 0, :]
+                hidden_concat.append(h)
+            multi_layer = torch.cat(hidden_concat, dim=-1)
+
+            features_dict['image_projected'] = projected.cpu().float()
+            features_dict['image_multi_layer'] = multi_layer.cpu().float()
+            features_dict['modality'] = 'image'
+
+        if text is not None and text.strip():
+            inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
+            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+
+            with torch.no_grad():
+                text_outputs = self.clip_model.text_model(**inputs)
+                pooled = text_outputs.pooler_output
+                projected = self.clip_model.text_projection(pooled)
+
+            # For text, repeat the projected features to match multi-layer dim
+            # Text goes through the same brain encoder by tiling to 4096
+            text_multi = projected.repeat(1, 4096 // projected.shape[1] + 1)[:, :4096]
+
+            features_dict['text_projected'] = projected.cpu().float()
+            features_dict['text_multi_layer'] = text_multi.cpu().float()
+            if 'modality' not in features_dict:
+                features_dict['modality'] = 'text'
+            else:
+                features_dict['modality'] = 'image+text'
+
+        if audio is not None:
+            # Convert audio to spectrogram image for CLIP processing
+            sr, audio_data = audio if isinstance(audio, tuple) else (16000, audio)
+
+            if len(audio_data.shape) > 1:
+                audio_data = audio_data.mean(axis=1)
+            audio_data = audio_data.astype(np.float32)
+
+            # Create spectrogram visualization (linear-frequency STFT magnitude, not mel-scaled)
+            import matplotlib
+            matplotlib.use('Agg')
+            import matplotlib.pyplot as plt
+
+            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
+
+            # Simple spectrogram using STFT
+            n_fft = min(1024, len(audio_data))
+            hop_length = n_fft // 4
+
+            if len(audio_data) > n_fft:
+                # Manual STFT
+                n_frames = (len(audio_data) - n_fft) // hop_length + 1
+                spec = np.zeros((n_fft // 2 + 1, n_frames))
+                window = np.hanning(n_fft)
+
+                for i in range(n_frames):
+                    start = i * hop_length
+                    frame = audio_data[start:start + n_fft] * window
+                    fft = np.fft.rfft(frame)
+                    spec[:, i] = np.abs(fft)
+
+                spec_db = 20 * np.log10(spec + 1e-10)
+                ax.imshow(spec_db, aspect='auto', origin='lower', cmap='viridis')
+            else:
+                ax.plot(audio_data[:1000])
+
+            ax.set_title("Audio Spectrogram")
+            ax.axis('off')
+
+            fig.canvas.draw()
+
+            # Convert to image
+            buf = fig.canvas.buffer_rgba()
+            spec_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf).convert('RGB')
+            plt.close(fig)
+
+            # Process through CLIP as image
+            inputs = self.clip_processor(images=spec_img, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+
+            with torch.no_grad():
+                outputs = self.clip_model.vision_model(**inputs, output_hidden_states=True)
+                cls_features = outputs.last_hidden_state[:, 0, :]
+                projected = self.clip_model.visual_projection(cls_features)
+
+            hidden_concat = []
+            for layer_idx in [6, 12, 18, 23]:
+                h = outputs.hidden_states[layer_idx][:, 0, :]
+                hidden_concat.append(h)
+            multi_layer = torch.cat(hidden_concat, dim=-1)
+
+            features_dict['audio_projected'] = projected.cpu().float()
+            features_dict['audio_multi_layer'] = multi_layer.cpu().float()
+
+            if features_dict.get('modality') is None:
+                features_dict['modality'] = 'audio'
+            else:
+                features_dict['modality'] = features_dict['modality'] + '+audio'
+
+        return features_dict
+
+    def predict_brain_activity(self, features_dict):
+        """Run brain encoder forward pass."""
+        # Determine which features to use
+        if 'image_multi_layer' in features_dict:
+            input_features = features_dict['image_multi_layer']
+        elif 'text_multi_layer' in features_dict:
+            input_features = features_dict['text_multi_layer']
+        elif 'audio_multi_layer' in features_dict:
+            input_features = features_dict['audio_multi_layer']
+        else:
+            raise ValueError("No features available for prediction")
+
+        # If multimodal, average features
+        all_modality_features = []
+        for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
+            if key in features_dict:
+                all_modality_features.append(features_dict[key])
+
+        if len(all_modality_features) > 1:
+            input_features = torch.mean(torch.stack(all_modality_features), dim=0)
+
+        input_features = input_features.to(self.device)
+
+        # Forward pass through brain encoder
+        with torch.no_grad():
+            predictions, intermediates = self.brain_encoder(input_features, return_intermediates=True)
+
+        pred_np = predictions.cpu().numpy().flatten()
+
+        # Compute modality contributions
+        modality_contributions = {}
+        for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
+            if key in features_dict:
+                modality_name = key.split('_')[0]
+                feat = features_dict[key].to(self.device)
+                with torch.no_grad():
+                    single_pred = self.brain_encoder(feat)
+                modality_contributions[modality_name] = single_pred.cpu().numpy().flatten()
+
+        # Compute uncertainty via MC dropout. Only the Dropout layers are switched
+        # to train mode: calling .train() on the whole model would also put
+        # BatchNorm1d in train mode, which raises on a single-sample batch and
+        # would update its running statistics.
+        for module in self.brain_encoder.modules():
+            if isinstance(module, nn.Dropout):
+                module.train()
+        mc_predictions = []
+        for _ in range(10):
+            with torch.no_grad():
+                mc_pred = self.brain_encoder(input_features)
+            mc_predictions.append(mc_pred.cpu().numpy().flatten())
+        self.brain_encoder.eval()
+
+        mc_predictions = np.array(mc_predictions)
+        uncertainty = np.std(mc_predictions, axis=0)
+
+        # Compute ROI summaries
+        roi_summary = self._compute_roi_summary(pred_np, uncertainty)
+
+        # Validation checks
+        warnings = self._validate_predictions(pred_np)
+
+        result = {
+            'predictions': pred_np,
+            'uncertainty': uncertainty,
+            'roi_summary': roi_summary,
+            'modality_contributions': modality_contributions,
+            'modality': features_dict.get('modality', 'unknown'),
+            'intermediates': {k: v.cpu().numpy() if torch.is_tensor(v) else v
+                              for k, v in intermediates.items()},
+            'warnings': warnings,
+            'timestamp': datetime.now().isoformat(),
+        }
+
+        return result
+
+    def _compute_roi_summary(self, predictions, uncertainty):
+        """Compute per-ROI activation summaries."""
+        if self.roi_annotations is None:
+            return {}
+
+        annot = self.roi_annotations
+        n_voxels = len(predictions)
+
+        roi_summary = {}
+
+        for roi_id, roi_name in ROI_NAMES.items():
+            mask = (annot[:n_voxels] == roi_id) if len(annot) >= n_voxels else np.zeros(n_voxels, dtype=bool)
+
+            if mask.sum() == 0:
+                continue
+
+            roi_activations = predictions[mask]
+            roi_uncertainty = uncertainty[mask]
+
+            roi_summary[roi_name] = {
+                'mean_activation': float(np.mean(roi_activations)),
+                'max_activation': float(np.max(roi_activations)),
+                'min_activation': float(np.min(roi_activations)),
+                'std_activation': float(np.std(roi_activations)),
+                'mean_uncertainty': float(np.mean(roi_uncertainty)),
+                'n_voxels': int(mask.sum()),
+                'abs_mean': float(np.mean(np.abs(roi_activations))),
+                'known_function': ROI_FUNCTIONS.get(roi_name, "Unknown"),
+            }
+
+        return roi_summary
+
+    def _validate_predictions(self, predictions):
+        """Validation safeguards."""
+        warnings = []
+
+        if np.std(predictions) < 1e-6:
+            warnings.append("⚠️ CONSTANT OUTPUT DETECTED: All voxels have near-identical values")
+
+        if np.any(np.isnan(predictions)):
+            warnings.append("⚠️ NaN VALUES DETECTED in predictions")
+
+        if np.any(np.isinf(predictions)):
+            warnings.append("⚠️ Infinite VALUES DETECTED in predictions")
+
+        if np.max(np.abs(predictions)) > 50:
+            warnings.append(f"⚠️ Unusually large activations detected (max |activation| = {np.max(np.abs(predictions)):.2f})")
+
+        return warnings
+
+
+# ============================================================
+# Grounded Q&A System
+# ============================================================
+class GroundedQA:
+    """
+    RAG-grounded Q&A system.
+    The LLM is an INTERPRETER - it only explains model predictions.
+    It does NOT generate independent neuroscience claims.
+    """
+
+    def __init__(self):
+        self.inference_client = None
+        self._init_client()
+
+    def _init_client(self):
+        try:
+            from huggingface_hub import InferenceClient
+            self.inference_client = InferenceClient(
+                provider="hf-inference",
+                api_key=os.environ.get("HF_TOKEN", ""),
+            )
+            logger.info("HF Inference Client initialized")
+        except Exception as e:
+            logger.warning(f"Inference client init failed: {e}")
+
+    def build_context(self, brain_result):
+        """Build structured context from model predictions for LLM grounding."""
+
+        roi_summary = brain_result.get('roi_summary', {})
+        modality = brain_result.get('modality', 'unknown')
+        warnings = brain_result.get('warnings', [])
+        modality_contributions = brain_result.get('modality_contributions', {})
+
+        # Sort ROIs by absolute mean activation
+        sorted_rois = sorted(
+            roi_summary.items(),
+            key=lambda x: abs(x[1]['abs_mean']),
+            reverse=True
+        )
+
+        # Top activated regions
+        top_regions = []
+        for roi_name, data in sorted_rois[:10]:
+            top_regions.append(
+                f"- {roi_name}: mean_activation={data['mean_activation']:.4f}, "
+                f"abs_mean={data['abs_mean']:.4f}, uncertainty={data['mean_uncertainty']:.4f}, "
+                f"n_voxels={data['n_voxels']}"
+            )
+
+        # Network-level summaries
+        network_summaries = {}
+        for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
+            roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
+            activations = []
+            for rn in roi_names_in_net:
+                if rn in roi_summary:
+                    activations.append(roi_summary[rn]['abs_mean'])
+
+            if activations:
+                network_summaries[net_name] = {
+                    'mean_abs_activation': np.mean(activations),
+                    'max_abs_activation': np.max(activations),
+                    'function': NETWORK_FUNCTIONS.get(net_name, ""),
+                }
+
+        sorted_networks = sorted(
+            network_summaries.items(),
+            key=lambda x: x[1]['mean_abs_activation'],
+            reverse=True
+        )
+
+        # Modality contributions
+        modality_info = ""
+        if modality_contributions:
+            modality_info = "\n## Modality Contributions\n"
+            for mod_name, mod_pred in modality_contributions.items():
+                modality_info += f"- {mod_name}: mean_abs_activation={np.mean(np.abs(mod_pred)):.4f}, std={np.std(mod_pred):.4f}\n"
+
+        # Global prediction stats
+        predictions = brain_result['predictions']
+        global_stats = (
+            f"- Total voxels predicted: {len(predictions)}\n"
+            f"- Global mean activation: {np.mean(predictions):.4f}\n"
+            f"- Global std: {np.std(predictions):.4f}\n"
+            f"- Global range: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]\n"
+            f"- Mean uncertainty: {np.mean(brain_result['uncertainty']):.4f}\n"
+        )
+
+        context = f"""## Brain Activity Prediction Summary
+Input modality: {modality}
+
+## Global Statistics
+{global_stats}
+
+## Top 10 Activated Brain Regions (by absolute activation strength)
+{chr(10).join(top_regions)}
+
+## Functional Network Activations (ranked by strength)
+"""
+        for net_name, net_data in sorted_networks:
+            context += (
+                f"- {net_name}: mean_abs={net_data['mean_abs_activation']:.4f}, "
+                f"max_abs={net_data['max_abs_activation']:.4f}\n"
+                f"  Known function: {net_data['function']}\n"
+            )
+
+        context += modality_info
+
+        if warnings:
+            context += "\n## Warnings\n"
+            for w in warnings:
+                context += f"- {w}\n"
+
+        # ROI functional labels
+        context += "\n## ROI Functional Reference\n"
+        for roi_name in [r[0] for r in sorted_rois[:10]]:
+            if roi_name in ROI_FUNCTIONS:
+                context += f"- {roi_name}: {ROI_FUNCTIONS[roi_name]}\n"
+
+        return context
+
+    def answer(self, question, brain_result):
+        """Answer a question grounded in model predictions."""
+
+        context = self.build_context(brain_result)
+
+        system_prompt = """You are a neuroscience interpreter for a brain encoding model.
+Your role is STRICTLY to interpret and explain the model's predicted brain activity patterns.
+
+CRITICAL RULES:
+1. ONLY reference data provided in the context below. Never invent neuroscience claims.
+2. Always distinguish between:
+   - "Predicted activation" (what the model outputs)
+   - "Known neuroscience association" (established findings about brain regions)
+   - "Possible interpretation" (your inference connecting the two)
+3. Include uncertainty statements. Use phrases like "the model predicts", "this is consistent with", "one possible interpretation is"
+4. NEVER make definitive claims about emotions, consciousness, or behavior from brain activity alone.
+5. Always cite specific regions, activation values, and confidence levels from the context.
+6. If the question cannot be answered from the provided data, say so explicitly.
+7. Keep answers concise but informative (2-4 paragraphs max).
+
+You are an INTERPRETER of model outputs, not an independent neuroscience oracle."""
+
+        user_prompt = f"""## Model Prediction Context
+{context}
+
+## User Question
+{question}
+
+Please answer based ONLY on the model prediction data above. Cite specific regions and values."""
+
+        if self.inference_client is None:
+            return self._fallback_answer(question, brain_result, context)
+
+        try:
+            response = self.inference_client.chat.completions.create(
+                model="Qwen/Qwen2.5-72B-Instruct",
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                max_tokens=800,
+                temperature=0.3,
+            )
+            answer = response.choices[0].message.content
+
+            # Add grounding footer
+            answer += "\n\n---\n*This interpretation is based on model predictions with "
+            mean_unc = np.mean(brain_result['uncertainty'])
+            answer += f"mean uncertainty={mean_unc:.4f}. "
+            answer += "Predictions are from a brain encoder trained on NSD (Natural Scenes Dataset) fMRI data.*"
+
+            return answer
+
+        except Exception as e:
+            logger.warning(f"LLM inference failed: {e}")
+            return self._fallback_answer(question, brain_result, context)
+
+    def _fallback_answer(self, question, brain_result, context):
+        """Structured fallback when LLM is unavailable."""
+        roi_summary = brain_result.get('roi_summary', {})
+        sorted_rois = sorted(
+            roi_summary.items(),
+            key=lambda x: abs(x[1]['abs_mean']),
+            reverse=True
+        )
+
+        answer = "## Brain Activity Interpretation\n\n"
+        answer += f"**Input modality:** {brain_result.get('modality', 'unknown')}\n\n"
+
+        answer += "### Top Activated Regions\n"
+        for roi_name, data in sorted_rois[:5]:
+            answer += (
+                f"- **{roi_name}** (activation={data['mean_activation']:.4f}, "
+                f"uncertainty={data['mean_uncertainty']:.4f}): "
+                f"{ROI_FUNCTIONS.get(roi_name, 'Unknown function')}\n"
+            )
+
+        answer += "\n### Network-Level Summary\n"
+        for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
+            roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
+            activations = [roi_summary[rn]['abs_mean'] for rn in roi_names_in_net if rn in roi_summary]
+            if activations:
+                mean_act = np.mean(activations)
+                answer += f"- **{net_name}**: mean_abs_activation={mean_act:.4f} - {NETWORK_FUNCTIONS.get(net_name, '')}\n"
+
+        answer += f"\n*Note: LLM interpretation unavailable. Showing structured prediction summary. "
+        answer += f"Mean uncertainty: {np.mean(brain_result['uncertainty']):.4f}*"
+
+        return answer
+
+
+# ============================================================
+# Transparency Logger
+# ============================================================
+class TransparencyLogger:
+    """Logs all inputs, intermediates, and outputs for traceability."""
+
+    def __init__(self):
+        self.logs = []
+
+    def log_inference(self, inputs, features_dict, brain_result, qa_answer=None):
+        # Normalize text once so a None value cannot break .strip() or slicing
+        text_value = inputs.get('text') or ''
+        entry = {
+            'timestamp': datetime.now().isoformat(),
+            'inputs': {
+                'has_image': inputs.get('image') is not None,
+                'has_text': text_value.strip() != '',
+                'has_audio': inputs.get('audio') is not None,
+                'text_content': text_value[:200],
+            },
+            'features': {
+                'modality': features_dict.get('modality', 'unknown'),
+                'feature_norms': {},
+            },
+            'predictions': {
+                'n_voxels': len(brain_result['predictions']),
+                'pred_mean': float(np.mean(brain_result['predictions'])),
+                'pred_std': float(np.std(brain_result['predictions'])),
+                'pred_range': [float(np.min(brain_result['predictions'])),
+                               float(np.max(brain_result['predictions']))],
+                'uncertainty_mean': float(np.mean(brain_result['uncertainty'])),
+            },
+            'roi_summary_sent_to_llm': list(brain_result.get('roi_summary', {}).keys()),
+            'warnings': brain_result.get('warnings', []),
+            'qa_answer_length': len(qa_answer) if qa_answer else 0,
+        }
+
+        # Feature norms
+        for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
+            if key in features_dict:
+                entry['features']['feature_norms'][key] = float(features_dict[key].norm().item())
+
+        self.logs.append(entry)
+        return entry
+
+    def get_log_text(self):
+        return json.dumps(self.logs[-5:], indent=2, default=str)
+
+
+ # ============================================================
+ # Visualization helpers
+ # ============================================================
+ def create_brain_activation_plot(brain_result, roi_annotations):
+     """Create brain activation visualization."""
+     import plotly.graph_objects as go
+     from plotly.subplots import make_subplots
+
+     roi_summary = brain_result.get('roi_summary', {})
+
+     if not roi_summary:
+         fig = go.Figure()
+         # Paper coordinates center the message regardless of axis ranges
+         fig.add_annotation(text="No ROI data available", x=0.5, y=0.5,
+                            xref="paper", yref="paper", showarrow=False)
+         return fig
+
+     # One color per functional network, shared by panels 1 and 2
+     color_map = {
+         "early_visual": "#4CAF50",
+         "body_selective": "#FF9800",
+         "face_selective": "#E91E63",
+         "place_selective": "#2196F3",
+         "word_selective": "#9C27B0",
+     }
+
+     # Create multi-panel figure
+     fig = make_subplots(
+         rows=2, cols=2,
+         subplot_titles=(
+             "ROI Activation Strengths",
+             "Functional Network Summary",
+             "Activation Uncertainty",
+             "Activation Distribution",
+         ),
+         specs=[
+             [{"type": "bar"}, {"type": "bar"}],
+             [{"type": "bar"}, {"type": "histogram"}],
+         ]
+     )
+
+     # Panel 1: ROI activations (top 15 by absolute mean activation)
+     sorted_rois = sorted(roi_summary.items(), key=lambda x: abs(x[1]['abs_mean']), reverse=True)[:15]
+     roi_names = [r[0] for r in sorted_rois]
+     roi_activations = [r[1]['mean_activation'] for r in sorted_rois]
+     roi_colors = []
+     for r in sorted_rois:
+         name = r[0]
+         for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
+             roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
+             if name in roi_names_in_net:
+                 roi_colors.append(color_map.get(net_name, "#666"))
+                 break
+         else:
+             # ROI does not belong to any known network
+             roi_colors.append("#666")
+
+     fig.add_trace(
+         go.Bar(x=roi_names, y=roi_activations, marker_color=roi_colors, name="Activation"),
+         row=1, col=1
+     )
+
+     # Panel 2: Network summary (mean of per-ROI absolute activations)
+     net_names = []
+     net_activations = []
+     net_colors_list = []
+
+     for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
+         roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
+         activations = [roi_summary[rn]['abs_mean'] for rn in roi_names_in_net if rn in roi_summary]
+         if activations:
+             net_names.append(net_name.replace("_", " ").title())
+             net_activations.append(np.mean(activations))
+             net_colors_list.append(color_map.get(net_name, "#666"))
+
+     fig.add_trace(
+         go.Bar(x=net_names, y=net_activations, marker_color=net_colors_list, name="Network"),
+         row=1, col=2
+     )
+
+     # Panel 3: Uncertainty
+     roi_uncertainty = [r[1]['mean_uncertainty'] for r in sorted_rois]
+     fig.add_trace(
+         go.Bar(x=roi_names, y=roi_uncertainty, marker_color='rgba(255,0,0,0.5)', name="Uncertainty"),
+         row=2, col=1
+     )
+
+     # Panel 4: Distribution (every 10th voxel, to keep the plot light)
+     predictions = brain_result['predictions']
+     fig.add_trace(
+         go.Histogram(x=predictions[::10], nbinsx=50, name="Activations", marker_color='#4CAF50'),
+         row=2, col=2
+     )
+
+     fig.update_layout(
+         height=700,
+         showlegend=False,
+         title_text="Brain Activity Predictions",
+         template="plotly_white",
+     )
+
+     return fig
+
+
+ def create_modality_contribution_plot(brain_result):
+     """Create modality contribution visualization."""
+     import plotly.graph_objects as go
+
+     contributions = brain_result.get('modality_contributions', {})
+
+     if len(contributions) <= 1:
+         fig = go.Figure()
+         # Paper coordinates center the message regardless of axis ranges
+         fig.add_annotation(text="Single modality input - no comparison available", x=0.5, y=0.5,
+                            xref="paper", yref="paper", showarrow=False)
+         return fig
+
+     fig = go.Figure()
+
+     for mod_name, mod_pred in contributions.items():
+         # Show distribution of activations per modality (every 10th voxel)
+         fig.add_trace(go.Histogram(
+             x=mod_pred[::10],
+             name=mod_name.capitalize(),
+             opacity=0.6,
+             nbinsx=50,
+         ))
+
+     fig.update_layout(
+         title="Modality Contributions to Brain Activity",
+         xaxis_title="Predicted Activation",
+         yaxis_title="Count",
+         barmode='overlay',
+         template="plotly_white",
+         height=400,
+     )
+
+     return fig
+
+
+ # ============================================================
+ # Gradio Application
+ # ============================================================
+ def build_gradio_app():
+     import gradio as gr
+
+     # Global state
+     manager = ModelManager()
+     qa_system = GroundedQA()
+     transparency_log = TransparencyLogger()
+
+     current_result = {"value": None}
+
+     def initialize():
+         try:
+             manager.load()
+             return "✅ Models loaded successfully!"
+         except Exception as e:
+             return f"❌ Error loading models: {e}"
+
+     def process_input(image, text, audio):
+         """Main inference pipeline."""
+         if not manager._loaded:
+             manager.load()
+
+         if image is None and (text is None or text.strip() == '') and audio is None:
+             return "Please provide at least one input (image, text, or audio).", None, None, ""
+
+         try:
+             # Step 1: Extract features
+             features = manager.extract_features(image=image, text=text, audio=audio)
+
+             # Step 2: Predict brain activity
+             result = manager.predict_brain_activity(features)
+             current_result["value"] = result
+
+             # Step 3: Create visualizations
+             brain_plot = create_brain_activation_plot(result, manager.roi_annotations)
+             modality_plot = create_modality_contribution_plot(result)
+
+             # Step 4: Log for transparency
+             log_entry = transparency_log.log_inference(
+                 {'image': image, 'text': text, 'audio': audio},
+                 features, result
+             )
+
+             # Summary text
+             roi_summary = result.get('roi_summary', {})
+             sorted_rois = sorted(roi_summary.items(), key=lambda x: abs(x[1]['abs_mean']), reverse=True)
+
+             summary = f"**Modality:** {result['modality']}\n"
+             summary += f"**Voxels predicted:** {len(result['predictions'])}\n"
+             summary += f"**Mean uncertainty:** {np.mean(result['uncertainty']):.4f}\n\n"
+             summary += "**Top 5 Activated Regions:**\n"
+             for roi_name, data in sorted_rois[:5]:
+                 summary += f"- {roi_name}: {data['mean_activation']:.4f} (±{data['mean_uncertainty']:.4f})\n"
+
+             # 'warnings' is optional, matching how the logger reads it
+             if result.get('warnings'):
+                 summary += "\n**Warnings:**\n"
+                 for w in result['warnings']:
+                     summary += f"- {w}\n"
+
+             return summary, brain_plot, modality_plot, json.dumps(log_entry, indent=2, default=str)
+
+         except Exception as e:
+             import traceback
+             return f"Error: {e}\n{traceback.format_exc()}", None, None, ""
+
+     def ask_question(question, history):
+         """Q&A with grounded interpretation."""
+         if current_result["value"] is None:
+             history = history or []
+             history.append({"role": "user", "content": question})
+             history.append({"role": "assistant", "content": "Please run an inference first (provide an input in the Stimulus tab) before asking questions."})
+             return history, ""
+
+         history = history or []
+         history.append({"role": "user", "content": question})
+
+         answer = qa_system.answer(question, current_result["value"])
+         history.append({"role": "assistant", "content": answer})
+
+         # Log Q&A
+         transparency_log.log_inference(
+             {'text': question},
+             {'modality': 'qa'},
+             current_result["value"],
+             qa_answer=answer,
+         )
+
+         return history, ""
+
+     def get_transparency_log():
+         return transparency_log.get_log_text()
+
+     # Build UI
+     with gr.Blocks(title="Multimodal Brain Encoder") as demo:
+         gr.Markdown("""
+         # 🧠 Multimodal Brain Encoder
+
+         **A real brain encoding model trained on the Natural Scenes Dataset (NSD)**
+
+         This system predicts brain activity (fMRI voxel responses) from multimodal inputs using:
+         - **CLIP ViT-L/14** for feature extraction (multi-layer: layers 6, 12, 18, 24)
+         - **Deep Brain Encoder** with ROI-specific attention heads (trained on NSD subj01)
+         - **Ridge Regression** baseline (Algonauts 2023 recipe)
+         - **Grounded LLM Q&A** that only interprets model predictions
+
+         All predictions are from real model forward passes with learned weights.
+         """)
+
+         status = gr.Textbox(label="Status", value="Click 'Load Models' to initialize")
+         load_btn = gr.Button("🔄 Load Models", variant="primary")
+         load_btn.click(fn=initialize, outputs=status)
+
+         with gr.Tabs():
+             # Tab 1: Input & Prediction
+             with gr.Tab("🎯 Stimulus Input & Brain Prediction"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         image_input = gr.Image(type="pil", label="Visual Stimulus (Image)")
+                         text_input = gr.Textbox(
+                             label="Text Input",
+                             placeholder="Enter a description or sentence...",
+                             lines=3,
+                         )
+                         audio_input = gr.Audio(type="numpy", label="Audio Input")
+                         predict_btn = gr.Button("🧠 Predict Brain Activity", variant="primary", size="lg")
+
+                     with gr.Column(scale=2):
+                         summary_output = gr.Markdown(label="Prediction Summary")
+                         brain_plot = gr.Plot(label="Brain Activity Visualization")
+                         modality_plot = gr.Plot(label="Modality Contributions")
+
+                 predict_btn.click(
+                     fn=process_input,
+                     inputs=[image_input, text_input, audio_input],
+                     outputs=[summary_output, brain_plot, modality_plot, gr.Textbox(visible=False)],
+                 )
+
+             # Tab 2: Q&A
+             with gr.Tab("💬 Grounded Q&A"):
+                 gr.Markdown("""
+                 ### Ask questions about the predicted brain activity
+
+                 The LLM interpreter will answer based ONLY on:
+                 - Predicted activation maps and ROI summaries
+                 - Known functional labels from brain atlases
+                 - Modality attribution outputs
+                 - Uncertainty estimates
+
+                 It will NOT make independent neuroscience claims.
+                 """)
+
+                 chatbot = gr.Chatbot(
+                     type="messages",
+                     label="Brain Activity Q&A",
+                     height=400,
+                 )
+
+                 with gr.Row():
+                     question_input = gr.Textbox(
+                         label="Your Question",
+                         placeholder="e.g., Which brain regions are most activated? What does the face-selective network response mean?",
+                         scale=4,
+                     )
+                     ask_btn = gr.Button("Ask", variant="primary", scale=1)
+
+                 ask_btn.click(
+                     fn=ask_question,
+                     inputs=[question_input, chatbot],
+                     outputs=[chatbot, question_input],
+                 )
+                 question_input.submit(
+                     fn=ask_question,
+                     inputs=[question_input, chatbot],
+                     outputs=[chatbot, question_input],
+                 )
+
+                 gr.Markdown("""
+                 **Example questions:**
+                 - "What are the most activated brain regions for this input?"
+                 - "Is the face-selective network responding? What might that mean?"
+                 - "How confident is the model in these predictions?"
+                 - "How does the visual input differ from the text input in brain response?"
+                 - "What does high PPA activation suggest about this image?"
+                 """)
+
+             # Tab 3: Transparency Log
+             with gr.Tab("📋 Transparency Log"):
+                 gr.Markdown("### Full inference traceability log")
+                 gr.Markdown("Every inference is logged with inputs, features, predictions, and Q&A answers.")
+
+                 log_output = gr.Code(language="json", label="Recent Logs")
+                 refresh_log_btn = gr.Button("🔄 Refresh Log")
+                 refresh_log_btn.click(fn=get_transparency_log, outputs=log_output)
+
+             # Tab 4: Model Info
+             with gr.Tab("ℹ️ Model Information"):
+                 gr.Markdown(f"""
+                 ### Architecture Details
+
+                 | Component | Details |
+                 |-----------|---------|
+                 | Feature Extractor | CLIP ViT-L/14 (openai/clip-vit-large-patch14) |
+                 | Feature Layers | Layers 6, 12, 18, 24 (CLS tokens concatenated = 4096-dim) |
+                 | Brain Encoder | 4096 → 2048 → 2048 → 1024 → N_voxels |
+                 | Activations | GELU + BatchNorm + Dropout(0.3) |
+                 | ROI Heads | 5 functional network heads with learned attention |
+                 | Ridge Baseline | sklearn RidgeCV with 17 alphas (1e-2 to 1e6) |
+                 | Training Data | NSD subj01 (~8,859 train, ~300 val images) |
+                 | fMRI Resolution | 7T, ~15,724 voxels (NSD general cortical mask) |
+                 | Uncertainty | MC Dropout (10 forward passes) |
+
+                 ### Brain Regions (24 ROIs from NSD)
+
+                 | Network | Regions | Function |
+                 |---------|---------|----------|
+                 | Early Visual | V1v, V1d, V2v, V2d, V3v, V3d, hV4 | Basic visual processing |
+                 | Body Selective | EBA, FBA-1, FBA-2, mTL-bodies | Body/person perception |
+                 | Face Selective | OFA, FFA-1, FFA-2, mTL-faces, aTL-faces | Face recognition |
+                 | Place Selective | OPA, PPA, RSC | Scene/navigation |
+                 | Word Selective | OWFA, VWFA-1, VWFA-2, mfs-words, mTL-words | Reading/text |
+
+                 ### References
+                 - Natural Scenes Dataset: Allen et al. 2022, Nature Neuroscience
+                 - Algonauts 2023: Gifford et al. 2023
+                 - CLIP: Radford et al. 2021
+                 - Model repo: [{MODEL_REPO}](https://huggingface.co/{MODEL_REPO})
+                 """)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = build_gradio_app()
+     demo.launch(server_name="0.0.0.0", server_port=7860)
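
The "Top 5 Activated Regions" summary and the first plot panel both rank ROIs by absolute mean activation. A minimal standalone sketch of that ranking follows. It assumes `roi_summary` maps ROI names to dicts with an `abs_mean` key, as the code in the diff does; the helper name `top_rois` and the sample values are hypothetical and for illustration only.

```python
def top_rois(roi_summary, k=5):
    """Return the k (name, data) pairs with the largest absolute mean activation."""
    ranked = sorted(
        roi_summary.items(),
        key=lambda item: abs(item[1]["abs_mean"]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical ROI summaries, not real model output
example_summary = {
    "V1v": {"abs_mean": 0.12, "mean_activation": -0.12, "mean_uncertainty": 0.03},
    "FFA-1": {"abs_mean": 0.40, "mean_activation": 0.40, "mean_uncertainty": 0.05},
    "PPA": {"abs_mean": 0.25, "mean_activation": 0.25, "mean_uncertainty": 0.04},
}

top_names = [name for name, _ in top_rois(example_summary, k=2)]
print(top_names)
```

Slicing after the sort means an empty summary simply yields an empty list, so callers do not need a separate guard.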