twarner commited on
Commit
0ffa42f
·
1 Parent(s): 2d41217
Files changed (5) hide show
  1. app.py +283 -4
  2. best_full.pt +3 -0
  3. best_region.pt +3 -0
  4. best_temporal.pt +3 -0
  5. requirements.txt +8 -0
app.py CHANGED
@@ -1,7 +1,286 @@
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
  import gradio as gr
5
+ import numpy as np
6
+ import nibabel as nib
7
+ from pathlib import Path
8
+ from dataclasses import dataclass
9
+ from typing import Dict, List, Tuple, Optional
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from einops.layers.torch import Rearrange
14
+ from scipy.ndimage import zoom
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
 
18
+ # core config
19
+ @dataclass
20
+ class Config:
21
+ VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
22
+ EMBED_DIM: int = 256
23
+ NUM_HEADS: int = 8
24
+ NUM_LAYERS: int = 6
25
+ DROPOUT: float = 0.1
26
+ TASK_DIM: int = 512
27
 
28
+ # model components
29
+ class HierarchicalAttention(nn.Module):
30
+ def __init__(self, dim, heads=8):
31
+ super().__init__()
32
+ self.local_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
33
+ self.global_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
34
+ self.merge = nn.Linear(dim * 2, dim)
35
+ self.task_gate = nn.Sequential(
36
+ nn.Linear(dim, dim),
37
+ nn.Sigmoid()
38
+ )
39
+
40
+ def forward(self, x, task_embed=None):
41
+ local_out = self.local_attn(x, x, x)[0]
42
+ if task_embed is not None:
43
+ x = x * self.task_gate(task_embed).unsqueeze(1)
44
+ global_out = self.global_attn(x, x, x)[0]
45
+ return self.merge(torch.cat([local_out, global_out], dim=-1))
46
+
47
+ class TransformerBlock(nn.Module):
48
+ def __init__(self, config):
49
+ super().__init__()
50
+ self.norm1 = nn.LayerNorm(config.EMBED_DIM)
51
+ self.attn = nn.MultiheadAttention(
52
+ config.EMBED_DIM,
53
+ config.NUM_HEADS,
54
+ dropout=config.DROPOUT,
55
+ batch_first=True
56
+ )
57
+
58
+ self.norm2 = nn.LayerNorm(config.EMBED_DIM)
59
+ self.mlp = nn.Sequential(
60
+ nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 4),
61
+ nn.GELU(),
62
+ nn.Dropout(config.DROPOUT),
63
+ nn.Linear(config.EMBED_DIM * 4, config.EMBED_DIM)
64
+ )
65
+
66
+ self.task_gate = nn.Sequential(
67
+ nn.Linear(config.EMBED_DIM, config.EMBED_DIM),
68
+ nn.Sigmoid()
69
+ )
70
+
71
+ def forward(self, x, task):
72
+ h = self.norm1(x)
73
+ h = self.attn(h, h, h)[0]
74
+ g = self.task_gate(task).unsqueeze(1)
75
+ x = x + h * g
76
+
77
+ h = self.norm2(x)
78
+ h = self.mlp(h)
79
+ x = x + h * g
80
+ return x
81
+
82
+ class WaveletTemporal(nn.Module):
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ self.embed_dim = config.EMBED_DIM
86
+ self.spatial_proj = nn.Conv3d(1, config.EMBED_DIM, 1)
87
+ self.temporal_proj = nn.Conv3d(
88
+ config.EMBED_DIM,
89
+ config.EMBED_DIM,
90
+ (3,1,1),
91
+ padding=(1,0,0)
92
+ )
93
+ self.pool = nn.AdaptiveAvgPool3d((15, 32, 32))
94
+
95
+ def forward(self, x):
96
+ b, t, h, d, w = x.shape
97
+ x = x.reshape(b, 1, t, h, w*d)
98
+ x = self.spatial_proj(x)
99
+ x = self.temporal_proj(x)
100
+ return self.pool(x)
101
+
102
+ class SequentialBrainViT(nn.Module):
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.config = config
106
+
107
+ self.temporal = WaveletTemporal(config)
108
+
109
+ self.pool = nn.Sequential(
110
+ nn.LayerNorm([config.EMBED_DIM, 15, 32, 32]),
111
+ nn.AdaptiveAvgPool3d((5, 16, 16)),
112
+ Rearrange('b c t h w -> b (t h w) c')
113
+ )
114
+
115
+ self.num_patches = 5 * 16 * 16
116
+
117
+ self.task_embed = nn.Embedding(4, config.TASK_DIM)
118
+ self.task_proj = nn.Sequential(
119
+ nn.Linear(config.TASK_DIM, config.EMBED_DIM),
120
+ nn.LayerNorm(config.EMBED_DIM),
121
+ nn.GELU()
122
+ )
123
+
124
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.EMBED_DIM))
125
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.EMBED_DIM))
126
+
127
+ self.blocks = nn.ModuleList([
128
+ TransformerBlock(config)
129
+ for _ in range(config.NUM_LAYERS)
130
+ ])
131
+
132
+ self.shared_proj = nn.Sequential(
133
+ nn.LayerNorm(config.EMBED_DIM),
134
+ nn.Linear(config.EMBED_DIM, config.EMBED_DIM * 2),
135
+ nn.GELU(),
136
+ nn.Linear(config.EMBED_DIM * 2, config.EMBED_DIM),
137
+ nn.LayerNorm(config.EMBED_DIM),
138
+ nn.Dropout(config.DROPOUT)
139
+ )
140
+
141
+ self.heads = nn.ModuleDict({
142
+ 'learning_stage': nn.Sequential(
143
+ nn.LayerNorm(config.EMBED_DIM),
144
+ nn.Linear(config.EMBED_DIM, 1),
145
+ nn.Sigmoid()
146
+ ),
147
+ 'region_activation': nn.Sequential(
148
+ nn.LayerNorm(config.EMBED_DIM),
149
+ nn.Linear(config.EMBED_DIM, 116)
150
+ ),
151
+ 'temporal_pattern': nn.Sequential(
152
+ nn.LayerNorm(config.EMBED_DIM),
153
+ nn.Linear(config.EMBED_DIM, 30)
154
+ )
155
+ })
156
+
157
+ self._init_weights()
158
+
159
+ def _init_weights(self):
160
+ nn.init.normal_(self.cls_token, std=0.02)
161
+ nn.init.normal_(self.pos_embed, std=0.02)
162
+
163
+ for n, m in self.named_modules():
164
+ if isinstance(m, nn.Linear):
165
+ nn.init.normal_(m.weight, std=0.02)
166
+ if m.bias is not None:
167
+ nn.init.zeros_(m.bias)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.ones_(m.weight)
170
+ nn.init.zeros_(m.bias)
171
+
172
+ def forward(self, x, task_ids):
173
+ x = self.temporal(x)
174
+ x = self.pool(x)
175
+
176
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
177
+ x = torch.cat([cls_tokens, x], dim=1)
178
+ x = x + self.pos_embed[:,:x.shape[1]]
179
+
180
+ task = self.task_proj(self.task_embed(task_ids))
181
+
182
+ for block in self.blocks:
183
+ x = block(x, task)
184
+ x = self.shared_proj(x)
185
+
186
+ return {
187
+ 'learning_stage': self.heads['learning_stage'](x[:,0]),
188
+ 'region_activation': self.heads['region_activation'](x.mean(1)),
189
+ 'temporal_pattern': self.heads['temporal_pattern'](x[:,0])
190
+ }
191
+
192
+ def preprocess_volume(vol, target_size=(64, 64, 30)):
193
+ if vol.ndim == 4:
194
+ vol = vol[None]
195
+
196
+ b,t,h,w,d = vol.shape
197
+ target_h, target_w, target_d = target_size
198
+
199
+ vol = zoom(vol, (
200
+ 1, 1,
201
+ target_h/h,
202
+ target_w/w,
203
+ target_d/d
204
+ ), order=1)
205
+
206
+ vol = (vol - vol.mean((1,2,3,4), keepdims=True)) / (vol.std((1,2,3,4), keepdims=True) + 1e-8)
207
+ return torch.from_numpy(vol).float()
208
+
209
+ def plot_results(region_acts, temporal_pattern):
210
+ fig = plt.figure(figsize=(12,4))
211
+
212
+ plt.subplot(121)
213
+ sns.heatmap(region_acts.reshape(1,-1), cmap='RdBu_r', center=0)
214
+ plt.title('region activations')
215
+ plt.xlabel('brain region')
216
+
217
+ plt.subplot(122)
218
+ plt.plot(temporal_pattern.squeeze())
219
+ plt.title('temporal pattern')
220
+ plt.xlabel('time')
221
+
222
+ return fig
223
+
224
+ def process_fmri(file_obj):
225
+ try:
226
+ img = nib.load(file_obj.name)
227
+ data = img.get_fdata(dtype=np.float32)
228
+
229
+ if data.ndim != 4:
230
+ return f"error: expected 4D data, got {data.ndim}D", None
231
+
232
+ data = preprocess_volume(data)
233
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
234
+
235
+ results = {}
236
+ figs = []
237
+
238
+ for stage in ['full', 'region', 'temporal']:
239
+ model = SequentialBrainViT(Config())
240
+ ckpt = torch.load(f'best_{stage}.pt', map_location=device)
241
+ model.load_state_dict(ckpt['model'])
242
+ model.eval()
243
+
244
+ with torch.no_grad():
245
+ outputs = model(data.to(device), torch.tensor([0]).to(device))
246
+ results[stage] = {
247
+ 'learning_stage': float(outputs['learning_stage'].cpu().mean()),
248
+ 'region_activation': outputs['region_activation'].cpu().numpy(),
249
+ 'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
250
+ }
251
+
252
+ fig = plot_results(
253
+ results[stage]['region_activation'],
254
+ results[stage]['temporal_pattern']
255
+ )
256
+ figs.append(fig)
257
+ plt.close()
258
+
259
+ stage_results = "\n".join([
260
+ f"{stage.upper()} MODEL:"
261
+ f"\nlearning stage: {res['learning_stage']:.3f}"
262
+ f"\n"
263
+ for stage, res in results.items()
264
+ ])
265
+
266
+ return stage_results, figs[0] # return first fig for display
267
+
268
+ except Exception as e:
269
+ return f"error processing file: {str(e)}", None
270
+
271
+ # create interface
272
+ iface = gr.Interface(
273
+ fn=process_fmri,
274
+ inputs=gr.File(label="upload 4D fMRI nifti (.nii/.nii.gz)"),
275
+ outputs=[
276
+ gr.Textbox(label="classification results"),
277
+ gr.Plot(label="visualization")
278
+ ],
279
+ title="fmri learning stage classifier",
280
+ description="upload a 4D fMRI nifti file to classify learning stages and visualize brain patterns",
281
+ examples=[],
282
+ cache_examples=False
283
+ )
284
+
285
+ if __name__ == "__main__":
286
+ iface.launch()
best_full.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57930a221cb9b9bc05ddc25ff50a2733e5c52a42e727a1b7774c39e1aaeafdab
3
+ size 132052666
best_region.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6facb050f7683527d0135de413fa5a1dc5da5ac98c1062f81a9df9358329466f
3
+ size 56936432
best_temporal.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bedd978b4ec32be4b5454368400c93addad6e42598060ef5ac13ec4cd0995be
3
+ size 56224950
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ numpy>=1.21.0
4
+ nibabel>=5.0.0
5
+ matplotlib>=3.5.0
6
+ seaborn>=0.12.0
7
+ einops>=0.6.0
8
+ scipy>=1.9.0