"""
Multimodal Brain Encoder - Gradio Application
=============================================
Full end-to-end system:
Input → CLIP Features → Brain Prediction → ROI Analysis → LLM Q&A → Visualization
Uses real trained weights from the NSD dataset.
LLM is an INTERPRETER only - grounded in model predictions, not independent.
"""
import os
import sys
import json
import time
import logging
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from datetime import datetime
from collections import OrderedDict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================
# Configuration (must match training)
# ============================================================
MODEL_REPO = os.environ.get("MODEL_REPO", "ryu34/multimodal-brain-encoder")

ROI_NAMES = {
    1: "V1v", 2: "V1d", 3: "V2v", 4: "V2d", 5: "V3v", 6: "V3d", 7: "hV4",
    8: "EBA", 9: "FBA-1", 10: "FBA-2", 11: "mTL-bodies",
    12: "OFA", 13: "FFA-1", 14: "FFA-2", 15: "mTL-faces", 16: "aTL-faces",
    17: "OPA", 18: "PPA", 19: "RSC",
    20: "OWFA", 21: "VWFA-1", 22: "VWFA-2", 23: "mfs-words", 24: "mTL-words",
}

FUNCTIONAL_NETWORKS = {
    "early_visual": [1, 2, 3, 4, 5, 6, 7],
    "body_selective": [8, 9, 10, 11],
    "face_selective": [12, 13, 14, 15, 16],
    "place_selective": [17, 18, 19],
    "word_selective": [20, 21, 22, 23, 24],
}

# Known neuroscience associations for grounded Q&A
ROI_FUNCTIONS = {
    "V1v": "Primary visual cortex (ventral); processes basic visual features: edges, orientations, spatial frequencies",
    "V1d": "Primary visual cortex (dorsal); processes basic visual features with dorsal visual stream emphasis",
    "V2v": "Secondary visual cortex (ventral); processes texture, figure-ground segregation",
    "V2d": "Secondary visual cortex (dorsal); processes contour and border ownership",
    "V3v": "Third visual area (ventral); contributes to form perception and shape processing",
    "V3d": "Third visual area (dorsal); processes dynamic form and motion boundaries",
    "hV4": "Human V4; processes color, pattern, moderate object features, texture discrimination",
    "EBA": "Extrastriate Body Area; selectively responds to bodies and body parts",
    "FBA-1": "Fusiform Body Area 1; body processing in ventral temporal cortex",
    "FBA-2": "Fusiform Body Area 2; complementary body processing region",
    "mTL-bodies": "Medial temporal lobe body area; body recognition with memory component",
    "OFA": "Occipital Face Area; early face-selective processing, face parts detection",
    "FFA-1": "Fusiform Face Area 1; core face recognition and identity processing",
    "FFA-2": "Fusiform Face Area 2; complementary face processing, holistic face representation",
    "mTL-faces": "Medial temporal lobe face area; face recognition with episodic memory",
    "aTL-faces": "Anterior temporal lobe face area; person identity and semantic knowledge",
    "OPA": "Occipital Place Area; processes local scene elements and spatial boundaries",
    "PPA": "Parahippocampal Place Area; processes scenes, buildings, spatial layouts",
    "RSC": "Retrosplenial Cortex; spatial navigation, scene-to-map coordinate transformation",
    "OWFA": "Occipital Word Form Area; early visual word processing",
    "VWFA-1": "Visual Word Form Area 1; processes written words and letter strings",
    "VWFA-2": "Visual Word Form Area 2; higher-level word form processing",
    "mfs-words": "Mid-fusiform sulcus word area; intermediate word processing",
    "mTL-words": "Medial temporal lobe word area; word recognition with memory",
}

NETWORK_FUNCTIONS = {
    "early_visual": "Early visual processing: edges, orientations, spatial frequencies, textures, colors. Active for all visual stimuli.",
    "body_selective": "Body-selective cortex: responds to human bodies, body parts, biological motion. Key for person perception.",
    "face_selective": "Face-selective cortex: responds to faces, facial features, identity. Critical for social perception.",
    "place_selective": "Place/scene-selective cortex: responds to spatial layouts, buildings, scenes, navigation cues.",
    "word_selective": "Word/reading-selective cortex: responds to written text, letter strings, word forms.",
}
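
# Illustrative sketch (not used by the app): collapsing a per-voxel ROI
# annotation vector into functional-network labels, mirroring how
# FUNCTIONAL_NETWORKS groups the ROI ids above. `annot` is a hypothetical
# 1-D numpy array of ROI ids (0 = unassigned), like roi_annotations.npy.
def _network_labels_example(annot):
    labels = np.full(annot.shape, "unassigned", dtype=object)
    for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
        labels[np.isin(annot, roi_ids)] = net_name
    return labels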

# ============================================================
# Helper: enable only Dropout for MC sampling (keep BatchNorm in eval)
# ============================================================
def enable_dropout_only(model):
    """Enable Dropout layers while keeping BatchNorm in eval mode.

    This is needed for MC Dropout uncertainty estimation with batch_size=1,
    because BatchNorm1d requires batch_size > 1 in training mode.
    """
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
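
# Illustrative sketch (not called by the app): the MC Dropout sampling pattern
# that enable_dropout_only() supports, as used in predict_brain_activity below.
# `model` and `x` are hypothetical - a loaded BrainEncoder and a (1, 4096) tensor.
def _mc_dropout_example(model, x, n_samples=10):
    model.eval()                      # keep BatchNorm frozen (batch_size may be 1)
    enable_dropout_only(model)        # re-enable only the stochastic Dropout layers
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(n_samples)])
    model.eval()                      # restore full eval mode
    return samples.mean(dim=0), samples.std(dim=0)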

# ============================================================
# BrainEncoder model (must match training architecture exactly)
# ============================================================
class BrainEncoder(nn.Module):
    def __init__(self, input_dim=4096, n_voxels=15724, hidden_dims=None, dropout=0.3, n_rois=24):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [2048, 2048, 1024]
        self.input_dim = input_dim
        self.n_voxels = n_voxels
        self.n_rois = n_rois
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            prev_dim = h_dim
        self.backbone = nn.Sequential(*layers)
        self.general_head = nn.Linear(hidden_dims[-1], n_voxels)
        self.roi_attention = nn.ModuleDict()
        self.roi_heads = nn.ModuleDict()
        self.network_names = ["early_visual", "body_selective", "face_selective",
                              "place_selective", "word_selective"]
        for net_name in self.network_names:
            self.roi_attention[net_name] = nn.Sequential(
                nn.Linear(hidden_dims[-1], 256),
                nn.GELU(),
                nn.Linear(256, hidden_dims[-1]),
                nn.Sigmoid(),
            )
            self.roi_heads[net_name] = nn.Linear(hidden_dims[-1], n_voxels)
        self.register_buffer('roi_mask', torch.zeros(n_voxels, dtype=torch.long))

    def set_roi_assignments(self, annot):
        # annot: 1-D array of per-voxel ROI ids. Each voxel whose ROI belongs to a
        # functional network gets that network's 1-based index; 0 = general head only.
        for net_idx, (net_name, roi_ids) in enumerate(FUNCTIONAL_NETWORKS.items()):
            for roi_id in roi_ids:
                mask = (annot == roi_id)
                if len(mask) <= self.n_voxels:
                    self.roi_mask[:len(mask)][mask[:self.n_voxels]] = net_idx + 1

    def forward(self, x, return_intermediates=False):
        intermediates = {}
        backbone_out = self.backbone(x)
        intermediates['backbone'] = backbone_out.detach()
        pred = self.general_head(backbone_out)
        intermediates['general_pred'] = pred.detach()
        for net_idx, net_name in enumerate(self.network_names):
            if net_name in self.roi_attention:
                attn = self.roi_attention[net_name](backbone_out)
                weighted = backbone_out * attn
                roi_pred = self.roi_heads[net_name](weighted)
                mask = (self.roi_mask == net_idx + 1)
                if mask.any():
                    pred[:, mask] = roi_pred[:, mask]
                intermediates[f'roi_{net_name}'] = roi_pred.detach()
        if return_intermediates:
            return pred, intermediates
        return pred
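
# Illustrative sketch (not called by the app): a shape check of BrainEncoder on
# random CLIP-like features. The 4096-dim input and 15724-voxel output are the
# defaults above; the random tensor is purely hypothetical input.
def _brain_encoder_smoke_test():
    model = BrainEncoder(input_dim=4096, n_voxels=15724)
    model.eval()  # BatchNorm1d would fail in train mode with batch_size=1
    with torch.no_grad():
        pred, inter = model(torch.randn(1, 4096), return_intermediates=True)
    assert pred.shape == (1, 15724)
    return pred, sorted(inter.keys())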

# ============================================================
# Model Manager - loads and caches models
# ============================================================
class ModelManager:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.brain_encoder = None
        self.ridge_model = None
        self.clip_model = None
        self.clip_processor = None
        self.roi_annotations = None
        self.config = None
        self._loaded = False

    def load(self):
        if self._loaded:
            return
        from huggingface_hub import hf_hub_download
        logger.info(f"Loading models from {MODEL_REPO}...")
        # Load config
        try:
            config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json")
            with open(config_path) as f:
                self.config = json.load(f)
            logger.info(f"Config loaded: {self.config.get('architecture', 'unknown')}")
        except Exception as e:
            logger.warning(f"Config load failed: {e}")
            self.config = {}
        # Load ROI annotations
        try:
            annot_path = hf_hub_download(repo_id=MODEL_REPO, filename="roi_annotations.npy")
            self.roi_annotations = np.load(annot_path).flatten()
            logger.info(f"ROI annotations: {self.roi_annotations.shape}")
        except Exception as e:
            logger.warning(f"ROI annotations load failed: {e}")
        # Load brain encoder (optional - ridge is primary)
        try:
            model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
            model_config = checkpoint.get('config', {})
            self.brain_encoder = BrainEncoder(
                input_dim=model_config.get('input_dim', 4096),
                n_voxels=model_config.get('n_voxels', 15724),
                hidden_dims=model_config.get('hidden_dims', [2048, 2048, 1024]),
                dropout=model_config.get('dropout', 0.3),
            )
            self.brain_encoder.load_state_dict(checkpoint['model_state_dict'])
            self.brain_encoder.to(self.device).eval()
            if self.roi_annotations is not None:
                self.brain_encoder.set_roi_assignments(self.roi_annotations)
            # Free checkpoint memory
            del checkpoint
            logger.info("Brain encoder loaded successfully")
        except Exception as e:
            logger.warning(f"Brain encoder load failed (will use ridge only): {e}")
            self.brain_encoder = None
        # Load ridge model
        try:
            ridge_path = hf_hub_download(repo_id=MODEL_REPO, filename="ridge_model.pkl")
            with open(ridge_path, 'rb') as f:
                self.ridge_model = pickle.load(f)
            logger.info("Ridge model loaded successfully")
        except Exception as e:
            logger.warning(f"Ridge model load failed: {e}")
        # Load CLIP
        try:
            from transformers import CLIPModel, CLIPProcessor
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-large-patch14",
                torch_dtype=torch.float32,
            ).to(self.device).eval()
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
            logger.info("CLIP model loaded")
        except Exception as e:
            logger.error(f"CLIP load failed: {e}")
            raise
        self._loaded = True
        logger.info("All models loaded successfully!")
    def extract_features(self, image=None, text=None, audio=None):
        """Extract multimodal CLIP features."""
        features_dict = {}
        if image is not None:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            inputs = self.clip_processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
            with torch.no_grad():
                outputs = self.clip_model.vision_model(**inputs, output_hidden_states=True)
                cls_features = outputs.last_hidden_state[:, 0, :]
                projected = self.clip_model.visual_projection(cls_features)
                hidden_concat = []
                for layer_idx in [6, 12, 18, 23]:
                    h = outputs.hidden_states[layer_idx][:, 0, :]
                    hidden_concat.append(h)
                multi_layer = torch.cat(hidden_concat, dim=-1)
            features_dict['image_projected'] = projected.cpu().float()
            features_dict['image_multi_layer'] = multi_layer.cpu().float()
            features_dict['modality'] = 'image'
        if text is not None and text.strip():
            inputs = self.clip_processor(text=[text], return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
            with torch.no_grad():
                text_outputs = self.clip_model.text_model(**inputs)
                pooled = text_outputs.pooler_output
                projected = self.clip_model.text_projection(pooled)
            # For text, repeat the projected features to match the multi-layer dim:
            # text goes through the same brain encoder by tiling to 4096
            text_multi = projected.repeat(1, 4096 // projected.shape[1] + 1)[:, :4096]
            features_dict['text_projected'] = projected.cpu().float()
            features_dict['text_multi_layer'] = text_multi.cpu().float()
            if 'modality' not in features_dict:
                features_dict['modality'] = 'text'
            else:
                features_dict['modality'] = 'image+text'
        if audio is not None:
            # Convert audio to a spectrogram image for CLIP processing
            sr, audio_data = audio if isinstance(audio, tuple) else (16000, audio)
            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)
            audio_data = audio_data.astype(np.float32)
            # Create a spectrogram visualization (simple magnitude STFT, not a true mel spectrogram)
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            # Simple spectrogram using STFT
            n_fft = min(1024, len(audio_data))
            hop_length = n_fft // 4
            if len(audio_data) > n_fft:
                # Manual STFT
                n_frames = (len(audio_data) - n_fft) // hop_length + 1
                spec = np.zeros((n_fft // 2 + 1, n_frames))
                window = np.hanning(n_fft)
                for i in range(n_frames):
                    start = i * hop_length
                    frame = audio_data[start:start + n_fft] * window
                    fft = np.fft.rfft(frame)
                    spec[:, i] = np.abs(fft)
                spec_db = 20 * np.log10(spec + 1e-10)
                ax.imshow(spec_db, aspect='auto', origin='lower', cmap='viridis')
            else:
                ax.plot(audio_data[:1000])
            ax.set_title("Audio Spectrogram")
            ax.axis('off')
            fig.canvas.draw()
            # Convert to image
            buf = fig.canvas.buffer_rgba()
            spec_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf).convert('RGB')
            plt.close(fig)
            # Process through CLIP as an image
            inputs = self.clip_processor(images=spec_img, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
            with torch.no_grad():
                outputs = self.clip_model.vision_model(**inputs, output_hidden_states=True)
                cls_features = outputs.last_hidden_state[:, 0, :]
                projected = self.clip_model.visual_projection(cls_features)
                hidden_concat = []
                for layer_idx in [6, 12, 18, 23]:
                    h = outputs.hidden_states[layer_idx][:, 0, :]
                    hidden_concat.append(h)
                multi_layer = torch.cat(hidden_concat, dim=-1)
            features_dict['audio_projected'] = projected.cpu().float()
            features_dict['audio_multi_layer'] = multi_layer.cpu().float()
            if features_dict.get('modality') is None:
                features_dict['modality'] = 'audio'
            else:
                features_dict['modality'] = features_dict['modality'] + '+audio'
        return features_dict
    def predict_brain_activity(self, features_dict):
        """Run brain encoder forward pass using BOTH ridge and deep models."""
        # Determine which features to use
        if 'image_multi_layer' in features_dict:
            input_features = features_dict['image_multi_layer']
        elif 'text_multi_layer' in features_dict:
            input_features = features_dict['text_multi_layer']
        elif 'audio_multi_layer' in features_dict:
            input_features = features_dict['audio_multi_layer']
        else:
            raise ValueError("No features available for prediction")
        # If multimodal, average features
        all_modality_features = []
        for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
            if key in features_dict:
                all_modality_features.append(features_dict[key])
        if len(all_modality_features) > 1:
            input_features = torch.mean(torch.stack(all_modality_features), dim=0)
        input_features_np = input_features.cpu().numpy()
        input_features = input_features.to(self.device)
        # ── Primary: Ridge Model (proven baseline from Algonauts 2023) ──
        if self.ridge_model is not None:
            ridge = self.ridge_model
            X_norm = (input_features_np - ridge['feat_mean']) / ridge['feat_std']
            pred_z = ridge['model'].predict(X_norm)
            pred_np = (pred_z * ridge['fmri_std'] + ridge['fmri_mean']).flatten()
            # Clip extreme values for better visualization (keep 99.5th percentile)
            clip_val = np.percentile(np.abs(pred_np), 99.5)
            pred_np = np.clip(pred_np, -clip_val, clip_val)
        else:
            # Fallback to deep encoder
            with torch.no_grad():
                predictions, _ = self.brain_encoder(input_features, return_intermediates=True)
            pred_np = predictions.cpu().numpy().flatten()
        # ── Deep encoder for intermediates and uncertainty ──
        intermediates = {}
        if self.brain_encoder is not None:
            with torch.no_grad():
                deep_pred, intermediates = self.brain_encoder(input_features, return_intermediates=True)
            # Compute uncertainty via MC Dropout.
            # IMPORTANT: Only enable Dropout layers, keep BatchNorm in eval mode.
            # BatchNorm1d requires batch_size > 1 in training mode, but we have batch_size=1.
            self.brain_encoder.eval()                # Ensure everything is in eval mode first
            enable_dropout_only(self.brain_encoder)  # Selectively enable only Dropout
            mc_predictions = []
            for _ in range(10):
                with torch.no_grad():
                    mc_pred = self.brain_encoder(input_features)
                mc_predictions.append(mc_pred.cpu().numpy().flatten())
            self.brain_encoder.eval()  # Restore full eval mode
            mc_predictions = np.array(mc_predictions)
            uncertainty = np.std(mc_predictions, axis=0)
        else:
            # Estimate uncertainty from ridge prediction variance across feature perturbation
            ridge = self.ridge_model
            mc_predictions = []
            for _ in range(10):
                noise = np.random.normal(0, 0.01, size=input_features_np.shape)
                X_noisy = (input_features_np + noise - ridge['feat_mean']) / ridge['feat_std']
                mp = ridge['model'].predict(X_noisy).flatten()
                mc_predictions.append(mp)
            mc_predictions = np.array(mc_predictions)
            uncertainty = np.std(mc_predictions, axis=0)
        # Compute modality contributions using ridge
        modality_contributions = {}
        if self.ridge_model is not None:
            ridge = self.ridge_model
            for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
                if key in features_dict:
                    modality_name = key.split('_')[0]
                    feat_np = features_dict[key].cpu().numpy()
                    X_n = (feat_np - ridge['feat_mean']) / ridge['feat_std']
                    mp = (ridge['model'].predict(X_n) * ridge['fmri_std'] + ridge['fmri_mean']).flatten()
                    clip_val_mod = np.percentile(np.abs(mp), 99.5)
                    mp = np.clip(mp, -clip_val_mod, clip_val_mod)
                    modality_contributions[modality_name] = mp
        # Compute ROI summaries using z-scored per-voxel predictions.
        # This shows which regions are MORE or LESS activated compared to baseline.
        if self.ridge_model is not None:
            baseline_mean = self.ridge_model['fmri_mean']
            baseline_std = self.ridge_model['fmri_std']
            # Z-score predictions relative to the training distribution
            n_v = min(len(pred_np), len(baseline_mean))
            pred_z = (pred_np[:n_v] - baseline_mean[:n_v]) / (baseline_std[:n_v] + 1e-8)
        else:
            pred_z = pred_np
        roi_summary = self._compute_roi_summary(pred_z, uncertainty)
        # Validation checks
        warnings = self._validate_predictions(pred_np)
        result = {
            'predictions': pred_np,
            'uncertainty': uncertainty,
            'roi_summary': roi_summary,
            'modality_contributions': modality_contributions,
            'modality': features_dict.get('modality', 'unknown'),
            'intermediates': {k: v.cpu().numpy() if torch.is_tensor(v) else v
                              for k, v in intermediates.items()},
            'warnings': warnings,
            'timestamp': datetime.now().isoformat(),
        }
        return result
    def _compute_roi_summary(self, predictions, uncertainty):
        """Compute per-ROI activation summaries."""
        if self.roi_annotations is None:
            return {}
        annot = self.roi_annotations
        n_voxels = len(predictions)
        roi_summary = {}
        for roi_id, roi_name in ROI_NAMES.items():
            mask = (annot[:n_voxels] == roi_id) if len(annot) >= n_voxels else np.zeros(n_voxels, dtype=bool)
            if mask.sum() == 0:
                continue
            roi_activations = predictions[mask]
            roi_uncertainty = uncertainty[mask]
            roi_summary[roi_name] = {
                'mean_activation': float(np.mean(roi_activations)),
                'max_activation': float(np.max(roi_activations)),
                'min_activation': float(np.min(roi_activations)),
                'std_activation': float(np.std(roi_activations)),
                'mean_uncertainty': float(np.mean(roi_uncertainty)),
                'n_voxels': int(mask.sum()),
                'abs_mean': float(np.mean(np.abs(roi_activations))),
                'known_function': ROI_FUNCTIONS.get(roi_name, "Unknown"),
            }
        return roi_summary

    def _validate_predictions(self, predictions):
        """Validation safeguards."""
        warnings = []
        if np.std(predictions) < 1e-6:
            warnings.append("⚠️ CONSTANT OUTPUT DETECTED: All voxels have near-identical values")
        if np.any(np.isnan(predictions)):
            warnings.append("⚠️ NaN VALUES DETECTED in predictions")
        if np.any(np.isinf(predictions)):
            warnings.append("⚠️ Infinite VALUES DETECTED in predictions")
        if np.max(np.abs(predictions)) > 50:
            warnings.append(f"⚠️ Unusually large activations detected (max |activation| = {np.max(np.abs(predictions)):.2f})")
        return warnings
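
# Illustrative sketch (not called by the app): the end-to-end programmatic path
# that the Gradio handlers below follow. "example.jpg" is a hypothetical local
# file; network access to MODEL_REPO and the CLIP weights is assumed.
def _pipeline_example():
    manager = ModelManager()
    manager.load()
    feats = manager.extract_features(image=Image.open("example.jpg"), text="a dog on a beach")
    result = manager.predict_brain_activity(feats)
    top5 = sorted(result['roi_summary'].items(),
                  key=lambda kv: kv[1]['abs_mean'], reverse=True)[:5]
    return top5, result['warnings']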

# ============================================================
# Grounded Q&A System
# ============================================================
class GroundedQA:
    """
    RAG-grounded Q&A system.
    The LLM is an INTERPRETER - it only explains model predictions.
    It does NOT generate independent neuroscience claims.
    """
    def __init__(self):
        self.inference_client = None
        self._init_client()

    def _init_client(self):
        try:
            from huggingface_hub import InferenceClient
            self.inference_client = InferenceClient(
                provider="hf-inference",
                api_key=os.environ.get("HF_TOKEN", ""),
            )
            logger.info("HF Inference Client initialized")
        except Exception as e:
            logger.warning(f"Inference client init failed: {e}")

    def build_context(self, brain_result):
        """Build structured context from model predictions for LLM grounding."""
        roi_summary = brain_result.get('roi_summary', {})
        modality = brain_result.get('modality', 'unknown')
        warnings = brain_result.get('warnings', [])
        modality_contributions = brain_result.get('modality_contributions', {})
        # Sort ROIs by absolute mean activation
        sorted_rois = sorted(
            roi_summary.items(),
            key=lambda x: abs(x[1]['abs_mean']),
            reverse=True
        )
        # Top activated regions
        top_regions = []
        for roi_name, data in sorted_rois[:10]:
            top_regions.append(
                f"- {roi_name}: mean_activation={data['mean_activation']:.4f}, "
                f"abs_mean={data['abs_mean']:.4f}, uncertainty={data['mean_uncertainty']:.4f}, "
                f"n_voxels={data['n_voxels']}"
            )
        # Network-level summaries
        network_summaries = {}
        for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
            roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
            activations = []
            for rn in roi_names_in_net:
                if rn in roi_summary:
                    activations.append(roi_summary[rn]['abs_mean'])
            if activations:
                network_summaries[net_name] = {
                    'mean_abs_activation': np.mean(activations),
                    'max_abs_activation': np.max(activations),
                    'function': NETWORK_FUNCTIONS.get(net_name, ""),
                }
        sorted_networks = sorted(
            network_summaries.items(),
            key=lambda x: x[1]['mean_abs_activation'],
            reverse=True
        )
        # Modality contributions
        modality_info = ""
        if modality_contributions:
            modality_info = "\n## Modality Contributions\n"
            for mod_name, mod_pred in modality_contributions.items():
                modality_info += f"- {mod_name}: mean_abs_activation={np.mean(np.abs(mod_pred)):.4f}, std={np.std(mod_pred):.4f}\n"
        # Global prediction stats
        predictions = brain_result['predictions']
        global_stats = (
            f"- Total voxels predicted: {len(predictions)}\n"
            f"- Global mean activation: {np.mean(predictions):.4f}\n"
            f"- Global std: {np.std(predictions):.4f}\n"
            f"- Global range: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]\n"
            f"- Mean uncertainty: {np.mean(brain_result['uncertainty']):.4f}\n"
        )
        context = f"""## Brain Activity Prediction Summary
Input modality: {modality}
## Global Statistics
{global_stats}
## Top 10 Activated Brain Regions (by absolute activation strength)
{chr(10).join(top_regions)}
## Functional Network Activations (ranked by strength)
"""
        for net_name, net_data in sorted_networks:
            context += (
                f"- {net_name}: mean_abs={net_data['mean_abs_activation']:.4f}, "
                f"max_abs={net_data['max_abs_activation']:.4f}\n"
                f"  Known function: {net_data['function']}\n"
            )
        context += modality_info
        if warnings:
            context += "\n## Warnings\n"
            for w in warnings:
                context += f"- {w}\n"
        # ROI functional labels
        context += "\n## ROI Functional Reference\n"
        for roi_name in [r[0] for r in sorted_rois[:10]]:
            if roi_name in ROI_FUNCTIONS:
                context += f"- {roi_name}: {ROI_FUNCTIONS[roi_name]}\n"
        return context
    def answer(self, question, brain_result):
        """Answer a question grounded in model predictions."""
        context = self.build_context(brain_result)
        system_prompt = """You are a neuroscience interpreter for a brain encoding model.
Your role is STRICTLY to interpret and explain the model's predicted brain activity patterns.
CRITICAL RULES:
1. ONLY reference data provided in the context below. Never invent neuroscience claims.
2. Always distinguish between:
   - "Predicted activation" (what the model outputs)
   - "Known neuroscience association" (established findings about brain regions)
   - "Possible interpretation" (your inference connecting the two)
3. Include uncertainty statements. Use phrases like "the model predicts", "this is consistent with", "one possible interpretation is"
4. NEVER make definitive claims about emotions, consciousness, or behavior from brain activity alone.
5. Always cite specific regions, activation values, and confidence levels from the context.
6. If the question cannot be answered from the provided data, say so explicitly.
7. Keep answers concise but informative (2-4 paragraphs max).
You are an INTERPRETER of model outputs, not an independent neuroscience oracle."""
        user_prompt = f"""## Model Prediction Context
{context}
## User Question
{question}
Please answer based ONLY on the model prediction data above. Cite specific regions and values."""
        if self.inference_client is None:
            return self._fallback_answer(question, brain_result, context)
        try:
            response = self.inference_client.chat.completions.create(
                model="Qwen/Qwen2.5-72B-Instruct",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=800,
                temperature=0.3,
            )
            answer = response.choices[0].message.content
            # Add grounding footer
            answer += "\n\n---\n*This interpretation is based on model predictions with "
            mean_unc = np.mean(brain_result['uncertainty'])
            answer += f"mean uncertainty={mean_unc:.4f}. "
            answer += "Predictions are from a brain encoder trained on NSD (Natural Scenes Dataset) fMRI data.*"
            return answer
        except Exception as e:
            logger.warning(f"LLM inference failed: {e}")
            return self._fallback_answer(question, brain_result, context)

    def _fallback_answer(self, question, brain_result, context):
        """Structured fallback when LLM is unavailable."""
        roi_summary = brain_result.get('roi_summary', {})
        sorted_rois = sorted(
            roi_summary.items(),
            key=lambda x: abs(x[1]['abs_mean']),
            reverse=True
        )
        answer = "## Brain Activity Interpretation\n\n"
        answer += f"**Input modality:** {brain_result.get('modality', 'unknown')}\n\n"
        answer += "### Top Activated Regions\n"
        for roi_name, data in sorted_rois[:5]:
            answer += (
                f"- **{roi_name}** (activation={data['mean_activation']:.4f}, "
                f"uncertainty={data['mean_uncertainty']:.4f}): "
                f"{ROI_FUNCTIONS.get(roi_name, 'Unknown function')}\n"
            )
        answer += "\n### Network-Level Summary\n"
        for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
            roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
            activations = [roi_summary[rn]['abs_mean'] for rn in roi_names_in_net if rn in roi_summary]
            if activations:
                mean_act = np.mean(activations)
                answer += f"- **{net_name}**: mean_abs_activation={mean_act:.4f} → {NETWORK_FUNCTIONS.get(net_name, '')}\n"
        answer += "\n*Note: LLM interpretation unavailable. Showing structured prediction summary. "
        answer += f"Mean uncertainty: {np.mean(brain_result['uncertainty']):.4f}*"
        return answer
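
# Illustrative sketch (not called by the app): grounded Q&A over a prediction
# result. `result` is assumed to come from ModelManager.predict_brain_activity;
# without a valid HF_TOKEN the structured fallback answer is returned instead.
def _qa_example(result):
    qa = GroundedQA()
    context = qa.build_context(result)  # the only evidence the LLM is shown
    answer = qa.answer("Which brain regions are most activated?", result)
    return context, answer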

# ============================================================
# Transparency Logger
# ============================================================
class TransparencyLogger:
    """Logs all inputs, intermediates, and outputs for traceability."""
    def __init__(self):
        self.logs = []

    def log_inference(self, inputs, features_dict, brain_result, qa_answer=None):
        entry = {
            'timestamp': datetime.now().isoformat(),
            'inputs': {
                'has_image': inputs.get('image') is not None,
                'has_text': inputs.get('text') is not None and inputs.get('text', '').strip() != '',
                'has_audio': inputs.get('audio') is not None,
                # Guard against a None text value so slicing never fails
                'text_content': (inputs.get('text') or '')[:200],
            },
            'features': {
                'modality': features_dict.get('modality', 'unknown'),
                'feature_norms': {},
            },
            'predictions': {
                'n_voxels': len(brain_result['predictions']),
                'pred_mean': float(np.mean(brain_result['predictions'])),
                'pred_std': float(np.std(brain_result['predictions'])),
                'pred_range': [float(np.min(brain_result['predictions'])),
                               float(np.max(brain_result['predictions']))],
                'uncertainty_mean': float(np.mean(brain_result['uncertainty'])),
            },
            'roi_summary_sent_to_llm': list(brain_result.get('roi_summary', {}).keys()),
            'warnings': brain_result.get('warnings', []),
            'qa_answer_length': len(qa_answer) if qa_answer else 0,
        }
        # Feature norms
        for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
            if key in features_dict:
                entry['features']['feature_norms'][key] = float(features_dict[key].norm().item())
        self.logs.append(entry)
        return entry

    def get_log_text(self):
        return json.dumps(self.logs[-5:], indent=2, default=str)
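
# Illustrative sketch (not called by the app): logging one inference.
# `features` and `result` are assumed to come from ModelManager, as in
# _pipeline_example above.
def _logging_example(features, result):
    tlog = TransparencyLogger()
    tlog.log_inference({'image': None, 'text': "a dog on a beach", 'audio': None}, features, result)
    return tlog.get_log_text()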

# ============================================================
# Visualization helpers
# ============================================================
def create_brain_activation_plot(brain_result, roi_annotations):
    """Create brain activation visualization."""
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    roi_summary = brain_result.get('roi_summary', {})
    if not roi_summary:
        fig = go.Figure()
        fig.add_annotation(text="No ROI data available", x=0.5, y=0.5)
        return fig
    # Create multi-panel figure
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "ROI Activation Strengths",
            "Functional Network Summary",
            "Activation Uncertainty",
            "Activation Distribution",
        ),
        specs=[
            [{"type": "bar"}, {"type": "bar"}],
            [{"type": "bar"}, {"type": "histogram"}],
        ]
    )
    # Panel 1: ROI activations
    sorted_rois = sorted(roi_summary.items(), key=lambda x: abs(x[1]['abs_mean']), reverse=True)[:15]
    roi_names = [r[0] for r in sorted_rois]
    roi_activations = [r[1]['mean_activation'] for r in sorted_rois]
    roi_colors = []
    for r in sorted_rois:
        name = r[0]
        for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
            roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
            if name in roi_names_in_net:
                color_map = {
                    "early_visual": "#4CAF50",
                    "body_selective": "#FF9800",
                    "face_selective": "#E91E63",
                    "place_selective": "#2196F3",
                    "word_selective": "#9C27B0",
                }
                roi_colors.append(color_map.get(net_name, "#666"))
                break
        else:
            roi_colors.append("#666")
    fig.add_trace(
        go.Bar(x=roi_names, y=roi_activations, marker_color=roi_colors, name="Activation"),
        row=1, col=1
    )
    # Panel 2: Network summary
    net_names = []
    net_activations = []
    net_colors_list = []
    color_map = {
        "early_visual": "#4CAF50",
        "body_selective": "#FF9800",
        "face_selective": "#E91E63",
        "place_selective": "#2196F3",
        "word_selective": "#9C27B0",
    }
    for net_name, roi_ids in FUNCTIONAL_NETWORKS.items():
        roi_names_in_net = [ROI_NAMES[rid] for rid in roi_ids if rid in ROI_NAMES]
        activations = [roi_summary[rn]['abs_mean'] for rn in roi_names_in_net if rn in roi_summary]
        if activations:
            net_names.append(net_name.replace("_", " ").title())
            net_activations.append(np.mean(activations))
            net_colors_list.append(color_map.get(net_name, "#666"))
    fig.add_trace(
        go.Bar(x=net_names, y=net_activations, marker_color=net_colors_list, name="Network"),
        row=1, col=2
    )
    # Panel 3: Uncertainty
    roi_uncertainty = [r[1]['mean_uncertainty'] for r in sorted_rois]
    fig.add_trace(
        go.Bar(x=roi_names, y=roi_uncertainty, marker_color='rgba(255,0,0,0.5)', name="Uncertainty"),
        row=2, col=1
    )
    # Panel 4: Distribution
    predictions = brain_result['predictions']
    fig.add_trace(
        go.Histogram(x=predictions[::10], nbinsx=50, name="Activations", marker_color='#4CAF50'),
        row=2, col=2
    )
    fig.update_layout(
        height=700,
        showlegend=False,
        title_text="Brain Activity Predictions",
        template="plotly_white",
    )
    return fig

def create_modality_contribution_plot(brain_result):
    """Create modality contribution visualization."""
    import plotly.graph_objects as go
    contributions = brain_result.get('modality_contributions', {})
    if len(contributions) <= 1:
        fig = go.Figure()
        fig.add_annotation(text="Single modality input - no comparison available", x=0.5, y=0.5)
        return fig
    fig = go.Figure()
    for mod_name, mod_pred in contributions.items():
        # Show the distribution of activations per modality
        fig.add_trace(go.Histogram(
            x=mod_pred[::10],
            name=mod_name.capitalize(),
            opacity=0.6,
            nbinsx=50,
        ))
    fig.update_layout(
        title="Modality Contributions to Brain Activity",
        xaxis_title="Predicted Activation",
        yaxis_title="Count",
        barmode='overlay',
        template="plotly_white",
        height=400,
    )
    return fig
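
# Illustrative sketch (not called by the app): rendering the figures outside
# Gradio. Assumes `result` comes from predict_brain_activity; the output file
# names are hypothetical.
def _save_plots_example(result, roi_annotations):
    create_brain_activation_plot(result, roi_annotations).write_html("brain_activation.html")
    create_modality_contribution_plot(result).write_html("modality_contributions.html")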

# ============================================================
# Gradio Application
# ============================================================
def build_gradio_app():
    import gradio as gr
    # Global state
    manager = ModelManager()
    qa_system = GroundedQA()
    transparency_log = TransparencyLogger()
    current_result = {"value": None}

    def initialize():
        try:
            manager.load()
            return "✅ Models loaded successfully!"
        except Exception as e:
            return f"❌ Error loading models: {e}"

    def process_input(image, text, audio):
        """Main inference pipeline."""
        if not manager._loaded:
            manager.load()
        if image is None and (text is None or text.strip() == '') and audio is None:
            return "Please provide at least one input (image, text, or audio).", None, None, ""
        try:
            # Step 1: Extract features
            features = manager.extract_features(image=image, text=text, audio=audio)
            # Step 2: Predict brain activity
            result = manager.predict_brain_activity(features)
            current_result["value"] = result
            # Step 3: Create visualizations
            brain_plot = create_brain_activation_plot(result, manager.roi_annotations)
            modality_plot = create_modality_contribution_plot(result)
            # Step 4: Log for transparency
            log_entry = transparency_log.log_inference(
                {'image': image, 'text': text, 'audio': audio},
                features, result
            )
            # Summary text
            roi_summary = result.get('roi_summary', {})
            sorted_rois = sorted(roi_summary.items(), key=lambda x: abs(x[1]['abs_mean']), reverse=True)
            summary = f"**Modality:** {result['modality']}\n"
            summary += f"**Voxels predicted:** {len(result['predictions'])}\n"
            summary += f"**Mean uncertainty:** {np.mean(result['uncertainty']):.4f}\n\n"
            summary += "**Top 5 Activated Regions:**\n"
            for roi_name, data in sorted_rois[:5]:
                summary += f"- {roi_name}: {data['mean_activation']:.4f} (±{data['mean_uncertainty']:.4f})\n"
            if result['warnings']:
                summary += "\n**Warnings:**\n"
                for w in result['warnings']:
                    summary += f"- {w}\n"
            return summary, brain_plot, modality_plot, json.dumps(log_entry, indent=2, default=str)
        except Exception as e:
            import traceback
            return f"Error: {e}\n{traceback.format_exc()}", None, None, ""

    def ask_question(question, history):
        """Q&A with grounded interpretation."""
        if current_result["value"] is None:
            history = history or []
            history.append({"role": "user", "content": question})
            history.append({"role": "assistant", "content": "Please run an inference first (provide an input in the Stimulus tab) before asking questions."})
            return history, ""
        history = history or []
        history.append({"role": "user", "content": question})
        answer = qa_system.answer(question, current_result["value"])
        history.append({"role": "assistant", "content": answer})
        # Log Q&A
        transparency_log.log_inference(
            {'text': question},
            {'modality': 'qa'},
            current_result["value"],
            qa_answer=answer,
        )
        return history, ""

    def get_transparency_log():
        return transparency_log.get_log_text()
    # Build UI
    with gr.Blocks(title="Multimodal Brain Encoder") as demo:
        gr.Markdown("""
# 🧠 Multimodal Brain Encoder
**A real brain encoding model trained on the Natural Scenes Dataset (NSD)**
This system predicts brain activity (fMRI voxel responses) from multimodal inputs using:
- **CLIP ViT-L/14** for feature extraction (multi-layer: hidden states 6, 12, 18, 23)
- **Deep Brain Encoder** with ROI-specific attention heads (trained on NSD subj01)
- **Ridge Regression** baseline (Algonauts 2023 recipe)
- **Grounded LLM Q&A** that only interprets model predictions
All predictions are from real model forward passes with learned weights.
""")
        status = gr.Textbox(label="Status", value="Click 'Load Models' to initialize")
        load_btn = gr.Button("🚀 Load Models", variant="primary")
        load_btn.click(fn=initialize, outputs=status)
        with gr.Tabs():
            # Tab 1: Input & Prediction
            with gr.Tab("🎯 Stimulus Input & Brain Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(type="pil", label="Visual Stimulus (Image)")
                        text_input = gr.Textbox(
                            label="Text Input",
                            placeholder="Enter a description or sentence...",
                            lines=3,
                        )
                        audio_input = gr.Audio(type="numpy", label="Audio Input")
                        predict_btn = gr.Button("🧠 Predict Brain Activity", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        summary_output = gr.Markdown(label="Prediction Summary")
                        brain_plot = gr.Plot(label="Brain Activity Visualization")
                        modality_plot = gr.Plot(label="Modality Contributions")
                predict_btn.click(
                    fn=process_input,
                    inputs=[image_input, text_input, audio_input],
                    outputs=[summary_output, brain_plot, modality_plot, gr.Textbox(visible=False)],
                )
            # Tab 2: Q&A
            with gr.Tab("💬 Grounded Q&A"):
                gr.Markdown("""
### Ask questions about the predicted brain activity
The LLM interpreter will answer based ONLY on:
- Predicted activation maps and ROI summaries
- Known functional labels from brain atlases
- Modality attribution outputs
- Uncertainty estimates
It will NOT make independent neuroscience claims.
""")
                chatbot = gr.Chatbot(
                    type="messages",
                    label="Brain Activity Q&A",
                    height=400,
                )
                with gr.Row():
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., Which brain regions are most activated? What does the face-selective network response mean?",
                        scale=4,
                    )
                    ask_btn = gr.Button("Ask", variant="primary", scale=1)
                ask_btn.click(
                    fn=ask_question,
                    inputs=[question_input, chatbot],
                    outputs=[chatbot, question_input],
                )
                question_input.submit(
                    fn=ask_question,
                    inputs=[question_input, chatbot],
                    outputs=[chatbot, question_input],
                )
                gr.Markdown("""
**Example questions:**
- "What are the most activated brain regions for this input?"
- "Is the face-selective network responding? What might that mean?"
- "How confident is the model in these predictions?"
- "How does the visual input differ from the text input in brain response?"
- "What does high PPA activation suggest about this image?"
""")
            # Tab 3: Transparency Log
            with gr.Tab("📋 Transparency Log"):
                gr.Markdown("### Full inference traceability log")
                gr.Markdown("Every inference is logged with inputs, features, predictions, and Q&A answers.")
                log_output = gr.Code(language="json", label="Recent Logs")
                refresh_log_btn = gr.Button("🔄 Refresh Log")
                refresh_log_btn.click(fn=get_transparency_log, outputs=log_output)
            # Tab 4: Model Info
            with gr.Tab("ℹ️ Model Information"):
                gr.Markdown(f"""
### Architecture Details
| Component | Details |
|-----------|---------|
| Feature Extractor | CLIP ViT-L/14 (openai/clip-vit-large-patch14) |
| Feature Layers | Hidden states 6, 12, 18, 23 (CLS tokens concatenated = 4096-dim) |
| Brain Encoder | 4096 → 2048 → 2048 → 1024 → N_voxels |
| Activations | GELU + BatchNorm + Dropout(0.3) |
| ROI Heads | 5 functional network heads with learned attention |
| Ridge Baseline | sklearn RidgeCV with 17 alphas (1e-2 to 1e6) |
| Training Data | NSD subj01 (~8,859 train, ~300 val images) |
| fMRI Resolution | 7T, ~15,724 voxels (NSD general cortical mask) |
| Uncertainty | MC Dropout (10 forward passes) |
### Brain Regions (24 ROIs from NSD)
| Network | Regions | Function |
|---------|---------|----------|
| Early Visual | V1v, V1d, V2v, V2d, V3v, V3d, hV4 | Basic visual processing |
| Body Selective | EBA, FBA-1, FBA-2, mTL-bodies | Body/person perception |
| Face Selective | OFA, FFA-1, FFA-2, mTL-faces, aTL-faces | Face recognition |
| Place Selective | OPA, PPA, RSC | Scene/navigation |
| Word Selective | OWFA, VWFA-1, VWFA-2, mfs-words, mTL-words | Reading/text |
### References
- Natural Scenes Dataset: Allen et al. 2022, Nature Neuroscience
- Algonauts 2023: Gifford et al. 2023
- CLIP: Radford et al. 2021
- Model repo: [{MODEL_REPO}](https://huggingface.co/{MODEL_REPO})
""")
    return demo

if __name__ == "__main__":
    demo = build_gradio_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)