import os
import re
import torch
import gradio as gr
import numpy as np
import nibabel as nib
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import seaborn as sns
from nilearn import plotting
import matplotlib.gridspec as gridspec
@dataclass
class Config:
    VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
    EMBED_DIM: int = 256
    NUM_HEADS: int = 8
    NUM_LAYERS: int = 6
    DROPOUT: float = 0.1
    TASK_DIM: int = 512
# Note: defined for experimentation; not used by SequentialBrainViT below,
# which builds on TransformerBlock instead.
class HierarchicalAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.global_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.merge = nn.Linear(dim * 2, dim)
        self.task_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

    def forward(self, x, task_embed=None):
        local_out = self.local_attn(x, x, x)[0]
        if task_embed is not None:
            x = x * self.task_gate(task_embed).unsqueeze(1)
        global_out = self.global_attn(x, x, x)[0]
        return self.merge(torch.cat([local_out, global_out], dim=-1))
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.EMBED_DIM)
        self.attn = nn.MultiheadAttention(
            config.EMBED_DIM,
            config.NUM_HEADS,
            dropout=config.DROPOUT,
            batch_first=True
        )
        self.norm2 = nn.LayerNorm(config.EMBED_DIM)
        self.mlp = nn.Sequential(
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 4),
            nn.GELU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.EMBED_DIM * 4, config.EMBED_DIM)
        )
        self.task_gate = nn.Sequential(
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM),
            nn.Sigmoid()
        )

    def forward(self, x, task):
        h = self.norm1(x)
        h = self.attn(h, h, h)[0]
        g = self.task_gate(task).unsqueeze(1)
        x = x + h * g
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h * g
        return x
class WaveletTemporal(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.EMBED_DIM
        self.spatial_proj = nn.Conv3d(1, config.EMBED_DIM, 1)
        self.temporal_proj = nn.Conv3d(
            config.EMBED_DIM,
            config.EMBED_DIM,
            (3, 1, 1),
            padding=(1, 0, 0)
        )
        self.pool = nn.AdaptiveAvgPool3d((15, 32, 32))

    def forward(self, x):
        # x: (batch, time, height, width, depth); width and depth are folded
        # into one axis so Conv3d treats time as its first spatial dimension.
        b, t, h, w, d = x.shape
        x = x.reshape(b, 1, t, h, w * d)
        x = self.spatial_proj(x)   # (b, EMBED_DIM, t, h, w*d)
        x = self.temporal_proj(x)  # temporal mixing along t
        return self.pool(x)        # (b, EMBED_DIM, 15, 32, 32)
class SequentialBrainViT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.temporal = WaveletTemporal(config)
        self.pool = nn.Sequential(
            nn.LayerNorm([config.EMBED_DIM, 15, 32, 32]),
            nn.AdaptiveAvgPool3d((5, 16, 16)),
            Rearrange('b c t h w -> b (t h w) c')
        )
        self.num_patches = 5 * 16 * 16
        self.task_embed = nn.Embedding(4, config.TASK_DIM)
        self.task_proj = nn.Sequential(
            nn.Linear(config.TASK_DIM, config.EMBED_DIM),
            nn.LayerNorm(config.EMBED_DIM),
            nn.GELU()
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.EMBED_DIM))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.EMBED_DIM))
        self.blocks = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.NUM_LAYERS)
        ])
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(config.EMBED_DIM),
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 2),
            nn.GELU(),
            nn.Linear(config.EMBED_DIM * 2, config.EMBED_DIM),
            nn.LayerNorm(config.EMBED_DIM),
            nn.Dropout(config.DROPOUT)
        )
        self.heads = nn.ModuleDict({
            'learning_stage': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 1),
                nn.Sigmoid()
            ),
            'region_activation': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 116)
            ),
            'temporal_pattern': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 30)
            )
        })
        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    def forward(self, x, task_ids):
        x = self.temporal(x)
        x = self.pool(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed[:, :x.shape[1]]
        task = self.task_proj(self.task_embed(task_ids))
        for block in self.blocks:
            x = block(x, task)
        x = self.shared_proj(x)
        return {
            'learning_stage': self.heads['learning_stage'](x[:, 0]),
            'region_activation': self.heads['region_activation'](x.mean(1)),
            'temporal_pattern': self.heads['temporal_pattern'](x[:, 0])
        }
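# Shape walkthrough for SequentialBrainViT (illustrative sketch derived from the
# default Config above; the example input size is an assumption, not a fixed requirement):
#   x: (B, T, H, W, D) fMRI tensor, e.g. (1, 30, 64, 64, 30)
#   WaveletTemporal        -> (B, 256, 15, 32, 32)
#   pool + Rearrange       -> (B, 5*16*16, 256) = (B, 1280, 256) patch tokens
#   prepend cls token      -> (B, 1281, 256), matching pos_embed
#   heads: learning_stage (B, 1), region_activation (B, 116), temporal_pattern (B, 30)
# Example call (kept as a comment so the script has no import-time side effects):
#   model = SequentialBrainViT(Config())
#   out = model(torch.zeros(1, 30, 64, 64, 30), torch.tensor([0]))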
def preprocess_volume(vol, target_size=(64, 64, 30)):
    if vol.ndim == 4:
        vol = vol[None]
    b, t, h, w, d = vol.shape
    target_h, target_w, target_d = target_size
    vol = zoom(vol, (
        1, 1,
        target_h / h,
        target_w / w,
        target_d / d
    ), order=1)
    vol = (vol - vol.mean((1, 2, 3, 4), keepdims=True)) / (vol.std((1, 2, 3, 4), keepdims=True) + 1e-8)
    return torch.from_numpy(vol).float()
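# Usage sketch (the input array below is hypothetical, only to illustrate shapes):
#   raw = np.random.rand(30, 91, 109, 91).astype(np.float32)  # (T, H, W, D) volume series
#   vol = preprocess_volume(raw)  # resampled and z-scored -> FloatTensor (1, 30, 64, 64, 30)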
def plot_brain_slices(data, learning_stage):
    fig = plt.figure(figsize=(15, 5))
    mean_activation = data.mean(axis=0)
    for i, slice_idx in enumerate([mean_activation.shape[-1] // 4,
                                   mean_activation.shape[-1] // 2,
                                   3 * mean_activation.shape[-1] // 4]):
        plt.subplot(1, 3, i + 1)
        plt.imshow(mean_activation[..., slice_idx].T, cmap='hot')
        plt.colorbar()
        plt.title(f'slice z={slice_idx}\nlearning: {learning_stage:.3f}')
        plt.axis('off')
    return fig
def interpret_learning_stage(score):
    if score < 0.2:
        return "NOVICE: minimal task familiarity, primarily exploratory behavior"
    elif score < 0.4:
        return "EARLY LEARNING: basic pattern recognition emerging"
    elif score < 0.6:
        return "INTERMEDIATE: developing systematic approach"
    elif score < 0.8:
        return "ADVANCED: robust strategy application"
    else:
        return "EXPERT: automated processing, highly optimized performance"
def plot_results(data, region_acts, temporal_pattern, learning_stage):
    fig = plt.figure(figsize=(16, 9))
    gs = gridspec.GridSpec(2, 2, height_ratios=[6, 4])
    ax1 = plt.subplot(gs[0, :])
    mean_activation = data.mean(axis=0)
    slice_idx = mean_activation.shape[-1] // 2
    brain_slice = mean_activation[..., slice_idx]
    peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
    peak_val = brain_slice[peak_coords]
    im = ax1.imshow(brain_slice.T, cmap='hot')
    plt.colorbar(im, ax=ax1)
    ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15)
    learning_desc = interpret_learning_stage(learning_stage)
    ax1.set_title(f'brain activation map (axial slice z={slice_idx})\n{learning_desc}',
                  fontsize=12, pad=20)
    ax1.axis('off')
    ax2 = plt.subplot(gs[1, 0])
    top_n = 5
    region_ranking = np.argsort(-region_acts.flatten())[:top_n]
    region_data = region_acts.reshape(1, -1)
    sns.heatmap(region_data, cmap='RdBu_r', center=0, ax=ax2)
    ax2.set_title('regional activity profile\n' +
                  'top regions: ' + ', '.join(f'{r}' for r in region_ranking))
    ax3 = plt.subplot(gs[1, 1])
    ax3.plot(temporal_pattern.squeeze(), 'k-', linewidth=2)
    ax3.set_title('temporal evolution')
    ax3.set_xlabel('time (volumes)')
    ax3.set_ylabel('activation (a.u.)')
    ax3.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def process_fmri(file_obj):
    try:
        img = nib.load(file_obj.name)
        data = img.get_fdata(dtype=np.float32)
        if data.ndim == 3:
            data = data[None, ...]
        elif data.ndim != 4:
            return f"error: volume must be 3D/4D, got {data.ndim}D", None
        t, h, w, d = data.shape
        if t < 1 or h < 16 or w < 16 or d < 8:
            return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
        if t > 1000 or h > 256 or w > 256 or d > 256:
            return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
        data = data.reshape(1, t, h, w, d)
        data = preprocess_volume(data)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        results = {}
        figs = []
        for stage in ['full', 'region', 'temporal']:
            try:
                model = SequentialBrainViT(Config())
                # checkpoints are expected as best_<stage>.pt with the state dict under 'model'
                ckpt = torch.load(f'best_{stage}.pt', map_location=device)
                missing = model.load_state_dict(ckpt['model'], strict=False)
                if missing.missing_keys:
                    print(f"warning - {stage} missing keys:", missing.missing_keys)
                model.to(device)
                model.eval()
                with torch.no_grad():
                    outputs = model(data.to(device), torch.tensor([0]).to(device))
                results[stage] = {
                    'learning_stage': float(outputs['learning_stage'].cpu().mean()),
                    'region_activation': outputs['region_activation'].cpu().numpy(),
                    'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
                }
                fig = plot_results(
                    data[0].cpu().numpy(),
                    results[stage]['region_activation'],
                    results[stage]['temporal_pattern'],
                    results[stage]['learning_stage']
                )
                figs.append(fig)
                plt.close()
            except Exception as e:
                return f"error in {stage} model: {str(e)}", None
        stage_results = "fMRI ANALYSIS SUMMARY\n" + "=" * 50 + "\n\n"
        for stage, res in results.items():
            stage_results += f"MODEL: {stage}\n"
            stage_results += f"learning assessment: {interpret_learning_stage(res['learning_stage'])}\n"
            stage_results += f"confidence score: {res['learning_stage']:.3f}\n"
            stage_results += f"dominant regions: {', '.join(str(r) for r in np.argsort(-res['region_activation'].flatten())[:3])}\n"
            stage_results += "-" * 50 + "\n\n"
        return stage_results, figs[0]
    except Exception as e:
        return f"error processing file: {str(e)}", None
iface = gr.Interface(
    fn=process_fmri,
    inputs=gr.File(label="Supports standard NIfTI format (.nii/.nii.gz)"),
    outputs=[
        gr.Textbox(
            label="Analysis Results",
            placeholder="upload fMRI scan to begin...",
            lines=10
        ),
        gr.Plot(label="Neural Activity Analysis")
    ],
    title="🧠 Learned Spectrum",
    description="""
    ### fMRI Learning Stage Classification with Vision Transformers
    """,
    article="""
    ### interpretation guide
    - learning stage: ranges from novice (0.0) to expert (1.0)
    - brain map: warmer colors = higher activation
    - regional profile: shows activity across 116 brain regions
    - temporal pattern: activation changes over time
    """,
    theme="default",
    examples=[],
    cache_examples=False
)
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )

# Optional alternative for serving behind an existing FastAPI app; mount_gradio_app
# takes the FastAPI instance, the Interface/Blocks, and a mount path:
#   from fastapi import FastAPI
#   app = gr.mount_gradio_app(FastAPI(), iface, path="/")