# learnedSpectrum / app.py
# author: twarner — commit d623d88 ("3d->4d proj")
import os
import re
import torch
import gradio as gr
import numpy as np
import nibabel as nib
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import seaborn as sns
from nilearn import plotting
import matplotlib.gridspec as gridspec
@dataclass
class Config:
    """Hyperparameters shared by the model variants loaded in process_fmri."""
    # target spatial size (H, W, D) that preprocess_volume resizes inputs to
    VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
    # transformer width
    EMBED_DIM: int = 256
    # attention heads per transformer block
    NUM_HEADS: int = 8
    # number of TransformerBlock layers
    NUM_LAYERS: int = 6
    # dropout used in attention and MLP sublayers
    DROPOUT: float = 0.1
    # width of the raw task embedding before projection to EMBED_DIM
    TASK_DIM: int = 512
class HierarchicalAttention(nn.Module):
    """Two parallel self-attention passes — one on the raw sequence, one on a
    task-gated copy — merged back to the input width by a linear layer."""

    def __init__(self, dim, heads=8):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.global_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.merge = nn.Linear(dim * 2, dim)
        # sigmoid gate derived from the task embedding, applied per-channel
        self.task_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

    def forward(self, x, task_embed=None):
        # "local" pass sees the ungated sequence
        near, _ = self.local_attn(x, x, x)
        # "global" pass sees the task-modulated sequence when a task is given
        if task_embed is not None:
            gate = self.task_gate(task_embed).unsqueeze(1)
            x = x * gate
        far, _ = self.global_attn(x, x, x)
        return self.merge(torch.cat((near, far), dim=-1))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block whose residual branches are both scaled by a
    sigmoid gate computed from the task embedding."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.EMBED_DIM)
        self.attn = nn.MultiheadAttention(
            config.EMBED_DIM,
            config.NUM_HEADS,
            dropout=config.DROPOUT,
            batch_first=True
        )
        self.norm2 = nn.LayerNorm(config.EMBED_DIM)
        # standard 4x-expansion MLP
        self.mlp = nn.Sequential(
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 4),
            nn.GELU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.EMBED_DIM * 4, config.EMBED_DIM)
        )
        self.task_gate = nn.Sequential(
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM),
            nn.Sigmoid()
        )

    def forward(self, x, task):
        # one gate per batch element, broadcast over the sequence dimension
        gate = self.task_gate(task).unsqueeze(1)
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out * gate
        x = x + self.mlp(self.norm2(x)) * gate
        return x
class WaveletTemporal(nn.Module):
    """Conv front end: embeds each voxel with a 1x1x1 conv, mixes along the
    time axis with a depth-3 conv, then pools to a fixed (15, 32, 32) grid."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.EMBED_DIM
        # per-voxel channel embedding (kernel size 1)
        self.spatial_proj = nn.Conv3d(1, config.EMBED_DIM, 1)
        # temporal mixing only: kernel spans 3 steps on the first conv axis
        self.temporal_proj = nn.Conv3d(
            config.EMBED_DIM,
            config.EMBED_DIM,
            (3, 1, 1),
            padding=(1, 0, 0)
        )
        self.pool = nn.AdaptiveAvgPool3d((15, 32, 32))

    def forward(self, x):
        # x: [B, T, H, W, D]; collapse the two trailing spatial axes into one
        # and treat time as the conv "depth" axis
        batch, frames, height = x.shape[0], x.shape[1], x.shape[2]
        x = x.reshape(batch, 1, frames, height, -1)
        x = self.temporal_proj(self.spatial_proj(x))
        return self.pool(x)
class SequentialBrainViT(nn.Module):
    """Task-conditioned ViT-style model over fMRI volumes.

    Pipeline: WaveletTemporal conv embedding -> pool to a fixed 5x16x16 token
    grid -> task-gated transformer blocks -> shared projection -> three heads
    (learning stage, region activation, temporal pattern).

    NOTE(review): attribute names and declaration order must stay as-is — the
    checkpoint files loaded in process_fmri key their state_dicts on them.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # conv front end; emits [B, EMBED_DIM, 15, 32, 32]
        self.temporal = WaveletTemporal(config)
        # normalize, downsample to 5x16x16, flatten to a token sequence
        self.pool = nn.Sequential(
            nn.LayerNorm([config.EMBED_DIM, 15, 32, 32]),
            nn.AdaptiveAvgPool3d((5, 16, 16)),
            Rearrange('b c t h w -> b (t h w) c')
        )
        self.num_patches = 5 * 16 * 16
        # one learned embedding per task id; only id 0 is used in process_fmri
        self.task_embed = nn.Embedding(4, config.TASK_DIM)
        self.task_proj = nn.Sequential(
            nn.Linear(config.TASK_DIM, config.EMBED_DIM),
            nn.LayerNorm(config.EMBED_DIM),
            nn.GELU()
        )
        # ViT-style class token and learned positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.EMBED_DIM))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.EMBED_DIM))
        self.blocks = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.NUM_LAYERS)
        ])
        # shared representation consumed by every head
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(config.EMBED_DIM),
            nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 2),
            nn.GELU(),
            nn.Linear(config.EMBED_DIM * 2, config.EMBED_DIM),
            nn.LayerNorm(config.EMBED_DIM),
            nn.Dropout(config.DROPOUT)
        )
        self.heads = nn.ModuleDict({
            # scalar learning-stage score squashed to (0, 1)
            'learning_stage': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 1),
                nn.Sigmoid()
            ),
            # 116 outputs — presumably one per AAL-atlas region; TODO confirm
            'region_activation': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 116)
            ),
            # 30 outputs — presumably one per time point; TODO confirm
            'temporal_pattern': nn.Sequential(
                nn.LayerNorm(config.EMBED_DIM),
                nn.Linear(config.EMBED_DIM, 30)
            )
        })
        self._init_weights()

    def _init_weights(self):
        """(Re)initialize all parameters: normal(0, 0.02) for linears and the
        cls/pos embeddings, unit scale / zero bias for LayerNorms."""
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, task_ids):
        """x: [B, T, H, W, D] volumes (see process_fmri's reshape); task_ids:
        [B] integer task indices. Returns a dict of head outputs."""
        x = self.temporal(x)
        x = self.pool(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed[:,:x.shape[1]]
        task = self.task_proj(self.task_embed(task_ids))
        for block in self.blocks:
            x = block(x, task)
        x = self.shared_proj(x)
        return {
            # learning stage and temporal pattern read the CLS token;
            # region activation averages over all tokens
            'learning_stage': self.heads['learning_stage'](x[:,0]),
            'region_activation': self.heads['region_activation'](x.mean(1)),
            'temporal_pattern': self.heads['temporal_pattern'](x[:,0])
        }
def preprocess_volume(vol, target_size=(64, 64, 30)):
    """Resize and normalize an fMRI volume batch.

    Args:
        vol: numpy array, either [T, H, W, D] (a batch axis is prepended) or
            [B, T, H, W, D].
        target_size: (H, W, D) spatial size to resample to via linear
            interpolation; batch and time axes are untouched.

    Returns:
        Float32 torch tensor of shape [B, T, *target_size], z-scored per
        batch element (epsilon guards against zero variance).

    Raises:
        ValueError: if vol is not 4D or 5D.
    """
    if vol.ndim == 4:
        vol = vol[None]  # [T,H,W,D] -> [1,T,H,W,D]
    if vol.ndim != 5:
        # fail loudly instead of letting zoom raise a cryptic factor mismatch
        raise ValueError(f"expected 4D or 5D volume, got {vol.ndim}D")
    b, t, h, w, d = vol.shape
    target_h, target_w, target_d = target_size
    vol = zoom(vol, (
        1, 1,
        target_h / h,
        target_w / w,
        target_d / d
    ), order=1)
    # per-sample z-score over all non-batch axes
    vol = (vol - vol.mean((1, 2, 3, 4), keepdims=True)) / (vol.std((1, 2, 3, 4), keepdims=True) + 1e-8)
    return torch.from_numpy(vol).float()
def plot_brain_slices(data, learning_stage):
    """Show three axial slices (quarter, middle, three-quarter depth) of the
    time-averaged volume, annotated with the learning-stage score."""
    fig = plt.figure(figsize=(15, 5))
    avg = data.mean(axis=0)
    depth = avg.shape[-1]
    for col, z in enumerate((depth // 4, depth // 2, 3 * depth // 4), start=1):
        plt.subplot(1, 3, col)
        plt.imshow(avg[..., z].T, cmap='hot')
        plt.colorbar()
        plt.title(f'slice z={z}\nlearning: {learning_stage:.3f}')
        plt.axis('off')
    return fig
def plot_results(data, region_acts, temporal_pattern, learning_stage):
    """Build the summary figure: mid-depth axial slice with peak marker (top),
    region-activation heatmap (bottom-left), temporal trace (bottom-right)."""
    fig = plt.figure(figsize=(15,10))
    grid = gridspec.GridSpec(2, 2)

    # top row: middle axial slice of the time-averaged volume
    ax_brain = plt.subplot(grid[0, :])
    avg = data.mean(axis=0)
    z = avg.shape[-1] // 2
    axial = avg[..., z]
    # mark the hottest voxel in this slice
    peak = np.unravel_index(np.argmax(axial), axial.shape)
    img = ax_brain.imshow(axial.T, cmap='hot')
    plt.colorbar(img, ax=ax_brain)
    ax_brain.plot(peak[0], peak[1], 'r*', markersize=15,
                  label=f'peak ({peak[0]}, {peak[1]})')
    ax_brain.legend()
    ax_brain.set_title(f'brain activation (z={z})\nlearning stage: {learning_stage:.3f}')

    # bottom-left: per-region activation strip
    ax_regions = plt.subplot(grid[1, 0])
    hottest = np.argmax(region_acts)
    sns.heatmap(region_acts.reshape(1, -1), cmap='RdBu_r', center=0, ax=ax_regions)
    ax_regions.set_title(f'region activations\nmost active: {hottest}')
    ax_regions.set_xlabel('brain region')

    # bottom-right: predicted temporal dynamics
    ax_time = plt.subplot(grid[1, 1])
    ax_time.plot(temporal_pattern.squeeze())
    ax_time.set_title('temporal dynamics')
    ax_time.set_xlabel('time')

    plt.tight_layout()
    return fig
def process_fmri(file_obj):
    """Classify an uploaded NIfTI file with the three trained model variants.

    Args:
        file_obj: gradio upload — either a tempfile-like object with a
            ``.name`` path attribute or (newer gradio) a plain path string.

    Returns:
        (results_text, matplotlib figure) on success, or
        (error_message, None) on any failure.
    """
    try:
        # newer gradio versions pass the path directly as a string
        path = file_obj if isinstance(file_obj, str) else file_obj.name
        img = nib.load(path)
        data = img.get_fdata(dtype=np.float32)
        # shape validation + expansion
        if data.ndim == 3:
            data = data[None, ...]  # [H,W,D] -> [T,H,W,D]
        elif data.ndim != 4:
            return f"error: volume must be 3D/4D, got {data.ndim}D", None
        # validate dims
        t, h, w, d = data.shape
        if t < 1 or h < 16 or w < 16 or d < 8:
            return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
        if t > 1000 or h > 256 or w > 256 or d > 256:
            return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
        # reshape for batch
        data = data.reshape(1, t, h, w, d)  # [B,T,H,W,D]
        # normalize + preprocess
        data = preprocess_volume(data)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        results = {}
        figs = []
        for stage in ['full', 'region', 'temporal']:
            try:
                model = SequentialBrainViT(Config())
                model._init_weights()  # deterministic values for any keys the checkpoint lacks
                ckpt = torch.load(f'best_{stage}.pt', map_location=device)
                # load_state_dict returns a (missing_keys, unexpected_keys)
                # namedtuple that is ALWAYS truthy — inspect the fields instead
                load_info = model.load_state_dict(ckpt['model'], strict=False)
                if load_info.missing_keys or load_info.unexpected_keys:
                    print(f"warning - {stage} missing keys:", load_info)
                # move the model to the same device as the input tensor,
                # otherwise inference fails on CUDA machines
                model = model.to(device)
                model.eval()
                with torch.no_grad():
                    outputs = model(data.to(device), torch.tensor([0]).to(device))
                results[stage] = {
                    'learning_stage': float(outputs['learning_stage'].cpu().mean()),
                    'region_activation': outputs['region_activation'].cpu().numpy(),
                    'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
                }
                fig = plot_results(
                    data[0].cpu().numpy(),  # drop batch
                    results[stage]['region_activation'],
                    results[stage]['temporal_pattern'],
                    results[stage]['learning_stage']
                )
                figs.append(fig)
                plt.close()
            except Exception as e:
                return f"error in {stage} model: {str(e)}", None
        # enhanced results text w/ peak info
        stage_results = "\n".join([
            f"{stage.upper()} MODEL:"
            f"\nlearning stage: {res['learning_stage']:.3f}"
            f"\npeak region: {np.argmax(res['region_activation'])}"
            f"\npeak activation: {np.max(res['region_activation']):.3f}"
            f"\n"
            for stage, res in results.items()
        ])
        return stage_results, figs[0]
    except Exception as e:
        return f"error processing file: {str(e)}", None
# Gradio UI: single file input, text summary + summary figure outputs.
iface = gr.Interface(
    fn=process_fmri,
    inputs=gr.File(label="upload 4D fMRI nifti (.nii/.nii.gz)"),
    outputs=[
        gr.Textbox(label="classification results"),
        gr.Plot(label="brain activation + analysis")
    ],
    title="fmri learning stage classifier",
    description="upload a 4D fMRI nifti file to classify learning stages and visualize brain patterns",
    examples=[],
    cache_examples=False
)

if __name__ == "__main__":
    # launch the local web server (blocking call)
    iface.launch()