medimaging commited on
Commit
7ca1a2b
Β·
verified Β·
1 Parent(s): 39835c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -109
app.py CHANGED
@@ -1,25 +1,26 @@
1
  import os
2
  import pickle
3
  import warnings
 
4
  import numpy as np
5
  import torch
6
  import torch.nn as nn
7
  from math import sqrt
8
  import gradio as gr
 
 
9
 
10
  # ══════════════════════════════════════════════════════════════════════════════
11
- # 1. Model definitions (must exactly match training code)
12
  # ══════════════════════════════════════════════════════════════════════════════
13
 
14
  class Sine(nn.Module):
15
- def __init__(self, w0: float = 1.0):
16
  super().__init__()
17
  self.w0 = w0
18
-
19
  def forward(self, x):
20
  return torch.sin(self.w0 * x)
21
 
22
-
23
  class SirenLayer(nn.Module):
24
  def __init__(self, dim_in, dim_out, w0=30.0, c=6.0,
25
  is_first=False, use_bias=True, activation=None):
@@ -30,11 +31,9 @@ class SirenLayer(nn.Module):
30
  if use_bias:
31
  nn.init.uniform_(self.linear.bias, -w_std, w_std)
32
  self.activation = Sine(w0) if activation is None else activation
33
-
34
  def forward(self, x):
35
  return self.activation(self.linear(x))
36
 
37
-
38
  class Siren(nn.Module):
39
  def __init__(self, dim_in, dim_hidden, dim_out, num_layers,
40
  w0=30.0, w0_initial=30.0, use_bias=True, final_activation=None):
@@ -50,149 +49,313 @@ class Siren(nn.Module):
50
  act = nn.Identity() if final_activation is None else final_activation
51
  self.last_layer = SirenLayer(dim_hidden, dim_out, w0=w0,
52
  use_bias=use_bias, activation=act)
53
-
54
  def forward(self, x):
55
  return self.last_layer(self.net(x))
56
 
57
-
58
  class SirenMRIModel(nn.Module):
59
- """One Siren per axial slice, bundled into a single nn.Module."""
60
-
61
  def __init__(self, config):
62
  super().__init__()
63
  self.config = config
64
  self.models = nn.ModuleList([
65
- Siren(
66
- dim_in=2,
67
- dim_hidden=config["layer_size"],
68
- dim_out=config["vols"],
69
- num_layers=config["num_layers"],
70
- w0=config["w0"],
71
- w0_initial=config["w0_initial"],
72
- )
73
  for _ in range(config["sz"])
74
  ])
75
-
76
  def forward(self, coords, slice_idx):
77
  return self.models[slice_idx](coords)
78
 
79
-
80
  # ══════════════════════════════════════════════════════════════════════════════
81
- # 2. Load model + scalers at startup
82
  # ══════════════════════════════════════════════════════════════════════════════
83
 
84
  def load_assets():
85
- model_path = "sirenMRI_full_model_final.pt"
86
- scalers_path = "scalers.pkl"
87
-
88
- for path in (model_path, scalers_path):
89
- if not os.path.exists(path):
90
- raise FileNotFoundError(
91
- f"Required file not found: {path}\n"
92
- "Upload sirenMRI_full_model_final.pt and scalers.pkl "
93
- "to the root of your Space."
94
- )
95
-
96
- checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
97
- config = checkpoint["config"]
98
-
99
- model = SirenMRIModel(config)
100
- model.load_state_dict(checkpoint["model_state_dict"])
101
- model.eval()
102
-
103
  with warnings.catch_warnings():
104
  warnings.simplefilter("ignore")
105
  with open(scalers_path, "rb") as f:
106
- scalers = pickle.load(f)
 
107
 
108
- return model, scalers, config, checkpoint["input_shape"]
109
-
110
-
111
- print("Loading model…")
112
  model, scalers, config, input_shape = load_assets()
113
  sx, sy, sz, vols = input_shape
114
- print(f"Model loaded β€” shape: {sx}x{sy}x{sz}, volumes: {vols}")
115
-
116
 
117
  # ══════════════���═══════════════════════════════════════════════════════════════
118
- # 3. Inference
119
  # ══════════════════════════════════════════════════════════════════════════════
120
 
121
- def reconstruct(slice_idx: int, vol_idx: int):
122
- """Reconstruct one axial slice and return it as a displayable image."""
123
- slice_idx = int(slice_idx)
124
- vol_idx = int(vol_idx)
125
 
126
- xs = torch.linspace(-1, 1, sx)
127
- ys = torch.linspace(-1, 1, sy)
128
- grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
129
- coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)
 
130
 
131
- with torch.no_grad():
132
- pred = model(coords, slice_idx).numpy() # (sx*sy, vols)
 
133
 
134
- # Inverse MinMaxScaler β€” pure numpy, no sklearn call
 
 
 
 
135
  scaler = scalers[slice_idx]
136
  data_min = np.array(scaler.data_min_, dtype=np.float32)
137
  data_max = np.array(scaler.data_max_, dtype=np.float32)
138
- pred = pred * (data_max - data_min) + data_min # (sx*sy, vols)
139
-
140
- img = pred.reshape(sx, sy, vols)[:, :, vol_idx] # (sx, sy)
141
-
142
- # Normalise to [0, 255] uint8 for display
143
- img_min, img_max = img.min(), img.max()
144
- img_norm = (img - img_min) / (img_max - img_min + 1e-8)
145
- img_uint8 = (img_norm * 255).astype(np.uint8)
146
-
147
- info = (
148
- f"Slice {slice_idx} | Volume {vol_idx} | "
149
- f"Shape: {img.shape} | "
150
- f"Intensity range: [{img_min:.3f}, {img_max:.3f}]"
151
  )
152
- return img_uint8, info
153
-
154
 
155
  # ══════════════════════════════════════════════════════════════════════════════
156
- # 4. Gradio UI
157
  # ══════════════════════════════════════════════════════════════════════════════
158
 
159
- with gr.Blocks(title="SIREN MRI Compression") as demo:
160
- gr.Markdown(
161
- """
162
- # 🧠 Physics-Informed SIREN MRI Compression
163
- Reconstruct diffusion MRI slices from the compressed SIREN neural representation.
164
- Select an axial slice and diffusion volume, then click **Reconstruct**.
165
- """
166
- )
167
 
168
- with gr.Row():
169
- slice_slider = gr.Slider(
170
- minimum=0, maximum=sz - 1, step=1, value=sz // 2,
171
- label=f"Axial Slice (0 – {sz - 1})"
172
- )
173
- vol_slider = gr.Slider(
174
- minimum=0, maximum=vols - 1, step=1, value=0,
175
- label=f"Diffusion Volume / b-value index (0 – {vols - 1})"
176
- )
177
-
178
- reconstruct_btn = gr.Button("β–Ά Reconstruct", variant="primary")
179
-
180
- with gr.Row():
181
- output_image = gr.Image(label="Reconstructed Slice", type="numpy")
182
- output_info = gr.Textbox(label="Info", lines=2, interactive=False)
183
-
184
- reconstruct_btn.click(
185
- fn=reconstruct,
186
- inputs=[slice_slider, vol_slider],
187
- outputs=[output_image, output_info],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  )
 
189
 
190
- gr.Markdown(
191
- f"**Model config** β€” Type: `{config['model'].upper()}` | "
192
- f"Layers: `{config['num_layers']}` | "
193
- f"Hidden size: `{config['layer_size']}` | "
194
- f"w0: `{config['w0']}` | "
195
- f"Slices: `{sz}` | Volumes: `{vols}`"
196
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  demo.launch()
 
1
  import os
2
  import pickle
3
  import warnings
4
+ import tempfile
5
  import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  from math import sqrt
9
  import gradio as gr
10
+ import nibabel as nib
11
+ from sklearn.preprocessing import MinMaxScaler
12
 
13
  # ══════════════════════════════════════════════════════════════════════════════
14
+ # 1. Model definitions
15
  # ══════════════════════════════════════════════════════════════════════════════
16
 
17
  class Sine(nn.Module):
18
+ def __init__(self, w0=1.0):
19
  super().__init__()
20
  self.w0 = w0
 
21
  def forward(self, x):
22
  return torch.sin(self.w0 * x)
23
 
 
24
  class SirenLayer(nn.Module):
25
  def __init__(self, dim_in, dim_out, w0=30.0, c=6.0,
26
  is_first=False, use_bias=True, activation=None):
 
31
  if use_bias:
32
  nn.init.uniform_(self.linear.bias, -w_std, w_std)
33
  self.activation = Sine(w0) if activation is None else activation
 
34
  def forward(self, x):
35
  return self.activation(self.linear(x))
36
 
 
37
  class Siren(nn.Module):
38
  def __init__(self, dim_in, dim_hidden, dim_out, num_layers,
39
  w0=30.0, w0_initial=30.0, use_bias=True, final_activation=None):
 
49
  act = nn.Identity() if final_activation is None else final_activation
50
  self.last_layer = SirenLayer(dim_hidden, dim_out, w0=w0,
51
  use_bias=use_bias, activation=act)
 
52
  def forward(self, x):
53
  return self.last_layer(self.net(x))
54
 
 
55
  class SirenMRIModel(nn.Module):
 
 
56
  def __init__(self, config):
57
  super().__init__()
58
  self.config = config
59
  self.models = nn.ModuleList([
60
+ Siren(dim_in=2, dim_hidden=config["layer_size"],
61
+ dim_out=config["vols"], num_layers=config["num_layers"],
62
+ w0=config["w0"], w0_initial=config["w0_initial"])
 
 
 
 
 
63
  for _ in range(config["sz"])
64
  ])
 
65
  def forward(self, coords, slice_idx):
66
  return self.models[slice_idx](coords)
67
 
 
68
  # ══════════════════════════════════════════════════════════════════════════════
69
+ # 2. Load pretrained model
70
  # ══════════════════════════════════════════════════════════════════════════════
71
 
72
  def load_assets():
73
+ model_path, scalers_path = "sirenMRI_full_model_final.pt", "scalers.pkl"
74
+ for p in (model_path, scalers_path):
75
+ if not os.path.exists(p):
76
+ raise FileNotFoundError(f"Missing: {p}")
77
+ ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
78
+ cfg = ckpt["config"]
79
+ mdl = SirenMRIModel(cfg)
80
+ mdl.load_state_dict(ckpt["model_state_dict"])
81
+ mdl.eval()
 
 
 
 
 
 
 
 
 
82
  with warnings.catch_warnings():
83
  warnings.simplefilter("ignore")
84
  with open(scalers_path, "rb") as f:
85
+ scl = pickle.load(f)
86
+ return mdl, scl, cfg, ckpt["input_shape"]
87
 
88
+ print("⏳ Loading model…")
 
 
 
89
  model, scalers, config, input_shape = load_assets()
90
  sx, sy, sz, vols = input_shape
91
+ print(f"βœ… Model ready β€” {sx}Γ—{sy}Γ—{sz}, {vols} volumes")
 
92
 
93
  # ══════════════���═══════════════════════════════════════════════════════════════
94
+ # 3. Helper: normalise to uint8
95
  # ══════════════════════════════════════════════════════════════════════════════
96
 
97
+ def to_uint8(arr):
98
+ a, b = arr.min(), arr.max()
99
+ return ((arr - a) / (b - a + 1e-8) * 255).astype(np.uint8)
 
100
 
101
+ def to_coords(h, w):
102
+ xs = torch.linspace(-1, 1, h)
103
+ ys = torch.linspace(-1, 1, w)
104
+ gx, gy = torch.meshgrid(xs, ys, indexing="ij")
105
+ return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)
106
 
107
+ # ══════════════════════════════════════════════════════════════════════════════
108
+ # 4a. Reconstruct from pretrained model
109
+ # ══════════════════════════════════════════════════════════════════════════════
110
 
111
+ def reconstruct_pretrained(slice_idx, vol_idx):
112
+ slice_idx, vol_idx = int(slice_idx), int(vol_idx)
113
+ coords = to_coords(sx, sy)
114
+ with torch.no_grad():
115
+ pred = model(coords, slice_idx).numpy()
116
  scaler = scalers[slice_idx]
117
  data_min = np.array(scaler.data_min_, dtype=np.float32)
118
  data_max = np.array(scaler.data_max_, dtype=np.float32)
119
+ pred = pred * (data_max - data_min) + data_min
120
+ recon = pred.reshape(sx, sy, vols)[:, :, vol_idx]
121
+ img_min, img_max = recon.min(), recon.max()
122
+ stats = (
123
+ f"πŸ“ Shape: {recon.shape} | "
124
+ f"πŸ“Š Intensity: [{img_min:.3f}, {img_max:.3f}] | "
125
+ f"🧠 Slice {slice_idx} | πŸ“‘ Volume {vol_idx}"
 
 
 
 
 
 
126
  )
127
+ return to_uint8(recon), stats
 
128
 
129
  # ══════════════════════════════════════════════════════════════════════════════
130
+ # 4b. Compress & reconstruct user-uploaded NIfTI
131
  # ══════════════════════════════════════════════════════════════════════════════
132
 
133
+ def compress_and_compare(nifti_file, slice_idx, vol_idx, num_iters, lr):
134
+ if nifti_file is None:
135
+ return None, None, "⚠️ Please upload a NIfTI file first."
 
 
 
 
 
136
 
137
+ slice_idx = int(slice_idx)
138
+ vol_idx = int(vol_idx)
139
+ num_iters = int(num_iters)
140
+
141
+ try:
142
+ nii = nib.load(nifti_file.name)
143
+ img_data = nii.get_fdata().astype(np.float32)
144
+ except Exception as e:
145
+ return None, None, f"❌ Failed to load NIfTI: {e}"
146
+
147
+ # Handle 3D (single volume) or 4D
148
+ if img_data.ndim == 3:
149
+ img_data = img_data[..., np.newaxis]
150
+ if img_data.ndim != 4:
151
+ return None, None, "❌ Expected a 3D or 4D NIfTI file."
152
+
153
+ ux, uy, uz, uvols = img_data.shape
154
+ slice_idx = min(slice_idx, uz - 1)
155
+ vol_idx = min(vol_idx, uvols - 1)
156
+
157
+ # ── Original slice ────────────────────────────────────────────────────────
158
+ orig_slice = img_data[:, :, slice_idx, vol_idx]
159
+ orig_img = to_uint8(orig_slice)
160
+
161
+ # ── Quick SIREN fit on this one slice ─────────────────────────────────────
162
+ img_slice = np.transpose(img_data[:, :, slice_idx, :], (2, 0, 1)) # (vols, h, w)
163
+ features = img_slice.reshape(uvols, -1).T # (h*w, vols)
164
+
165
+ scaler_u = MinMaxScaler(feature_range=(0, 1))
166
+ features_scaled = scaler_u.fit_transform(features).astype(np.float32)
167
+
168
+ siren_u = Siren(dim_in=2, dim_hidden=config["layer_size"],
169
+ dim_out=uvols, num_layers=config["num_layers"],
170
+ w0=config["w0"], w0_initial=config["w0_initial"])
171
+ opt = torch.optim.Adam(siren_u.parameters(), lr=float(lr))
172
+ loss_fn = nn.MSELoss()
173
+
174
+ coords_u = to_coords(ux, uy)
175
+ feat_t = torch.from_numpy(features_scaled)
176
+
177
+ siren_u.train()
178
+ losses = []
179
+ for it in range(num_iters):
180
+ opt.zero_grad()
181
+ pred = siren_u(coords_u)
182
+ loss = loss_fn(pred, feat_t)
183
+ loss.backward()
184
+ opt.step()
185
+ losses.append(loss.item())
186
+
187
+ # ── Reconstruct ───────────────────────────────────────────────────────────
188
+ siren_u.eval()
189
+ with torch.no_grad():
190
+ pred_np = siren_u(coords_u).numpy()
191
+
192
+ pred_inv = scaler_u.inverse_transform(pred_np) # (h*w, vols)
193
+ recon_slice = pred_inv.reshape(ux, uy, uvols)[:, :, vol_idx]
194
+ recon_img = to_uint8(recon_slice)
195
+
196
+ # ── PSNR ──────────────────────────────────────────────────────────────────
197
+ mse = np.mean((orig_slice - recon_slice) ** 2)
198
+ o_max = orig_slice.max()
199
+ psnr = 20 * np.log10(o_max / (np.sqrt(mse) + 1e-8)) if o_max > 0 else float("nan")
200
+ final_loss = losses[-1] if losses else float("nan")
201
+
202
+ stats = (
203
+ f"πŸ“ Image: {ux}Γ—{uy}Γ—{uz}, {uvols} volumes | "
204
+ f"🎯 Slice {slice_idx}, Volume {vol_idx}\n"
205
+ f"πŸ“‰ Final loss: {final_loss:.6f} | "
206
+ f"πŸ“‘ PSNR: {psnr:.2f} dB | "
207
+ f"πŸ” Iterations: {num_iters}"
208
  )
209
+ return orig_img, recon_img, stats
210
 
211
+ # ══════════════════════════════════════════════════════════════════════════════
212
+ # 5. Gradio UI
213
+ # ══════════════════════════════════════════════════════════════════════════════
214
+
215
+ CSS = """
216
+ :root { --primary: #6366f1; --bg: #0f0f1a; --card: #1a1a2e; --border: #2d2d4e; }
217
+ body, .gradio-container { background: var(--bg) !important; color: #e2e8f0 !important; }
218
+ .gr-button-primary { background: linear-gradient(135deg,#6366f1,#8b5cf6) !important;
219
+ border: none !important; border-radius: 10px !important; font-weight: 700 !important;
220
+ letter-spacing: .5px; transition: transform .15s, box-shadow .15s; }
221
+ .gr-button-primary:hover { transform: translateY(-2px);
222
+ box-shadow: 0 8px 25px rgba(99,102,241,.45) !important; }
223
+ .gr-panel, .gr-box { background: var(--card) !important;
224
+ border: 1px solid var(--border) !important; border-radius: 14px !important; }
225
+ .gr-input, .gr-slider { background: #12122a !important; border-color: var(--border) !important; }
226
+ label { color: #a5b4fc !important; font-weight: 600 !important; }
227
+ .gr-markdown h1 { background: linear-gradient(135deg,#6366f1,#a78bfa);
228
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;
229
+ font-size: 2.2rem !important; font-weight: 800 !important; }
230
+ .gr-markdown h2 { color: #a5b4fc !important; font-size: 1.1rem !important; }
231
+ .tab-nav button { color: #a5b4fc !important; border-radius: 8px 8px 0 0 !important; }
232
+ .tab-nav button.selected { background: var(--card) !important;
233
+ border-bottom: 2px solid #6366f1 !important; color: #fff !important; }
234
+ footer { display: none !important; }
235
+ """
236
+
237
+ with gr.Blocks(css=CSS, title="SIREN MRI Compression") as demo:
238
+
239
+ # ── Header ────────────────────────────────────────────────────────────────
240
+ gr.Markdown("""
241
+ # 🧠 Physics-Informed SIREN MRI Compression
242
+ ## Neural implicit representation for diffusion MRI β€” compress, reconstruct & compare
243
+ ---
244
+ """)
245
+
246
+ with gr.Tabs():
247
+
248
+ # ══════════════════════════════════════════════════════════════════════
249
+ # TAB 1 β€” Pretrained model explorer
250
+ # ══════════════════════════════════════════════════════════════════════
251
+ with gr.Tab("πŸ”¬ Explore Pretrained Model"):
252
+ gr.Markdown("""
253
+ ### Explore the model trained on the MGH-1010 diffusion dataset
254
+ Adjust the sliders and click **Reconstruct** to visualise any slice and volume.
255
+ """)
256
+ with gr.Row():
257
+ with gr.Column(scale=1):
258
+ sl1 = gr.Slider(0, sz-1, value=sz//2, step=1, label=f"Axial Slice (0 – {sz-1})")
259
+ vl1 = gr.Slider(0, vols-1, value=0, step=1, label=f"Diffusion Volume (0 – {vols-1})")
260
+ btn1 = gr.Button("β–Ά Reconstruct", variant="primary")
261
+ stats1 = gr.Textbox(label="Statistics", lines=2, interactive=False)
262
+
263
+ gr.Markdown(f"""
264
+ ---
265
+ **Model config**
266
+ | Parameter | Value |
267
+ |---|---|
268
+ | Type | `{config['model'].upper()}` |
269
+ | Layers | `{config['num_layers']}` |
270
+ | Hidden size | `{config['layer_size']}` |
271
+ | wβ‚€ | `{config['w0']}` |
272
+ | Slices | `{sz}` |
273
+ | Volumes | `{vols}` |
274
+ """)
275
+
276
+ with gr.Column(scale=2):
277
+ out1 = gr.Image(label="Reconstructed Slice", type="numpy",
278
+ elem_id="recon_img", height=420)
279
+
280
+ btn1.click(reconstruct_pretrained,
281
+ inputs=[sl1, vl1],
282
+ outputs=[out1, stats1])
283
+
284
+ # ══════════════════════════════════════════════════════════════════════
285
+ # TAB 2 β€” Upload your own NIfTI
286
+ # ══════════════════════════════════════════════════════════════════════
287
+ with gr.Tab("πŸ“‚ Upload & Compress Your Own MRI"):
288
+ gr.Markdown("""
289
+ ### Upload your own diffusion MRI in NIfTI format
290
+ The app will fit a SIREN network to the selected slice on-the-fly and show you
291
+ **original vs reconstructed** side by side.
292
+ > ⚠️ For speed, only the selected slice is fitted. Use more iterations for better quality.
293
+ """)
294
+ with gr.Row():
295
+ with gr.Column(scale=1):
296
+ nifti_upload = gr.File(
297
+ label="Upload NIfTI file (.nii or .nii.gz)",
298
+ file_types=[".nii", ".gz"],
299
+ )
300
+ sl2 = gr.Slider(0, 200, value=50, step=1, label="Axial Slice")
301
+ vl2 = gr.Slider(0, 551, value=0, step=1, label="Diffusion Volume")
302
+ with gr.Row():
303
+ n_iters = gr.Slider(100, 2000, value=500, step=100,
304
+ label="Training Iterations")
305
+ lr_inp = gr.Slider(1e-4, 1e-2, value=3e-4, step=1e-4,
306
+ label="Learning Rate")
307
+ btn2 = gr.Button("πŸš€ Compress & Compare", variant="primary")
308
+ stats2 = gr.Textbox(label="Results", lines=3, interactive=False)
309
+
310
+ with gr.Column(scale=2):
311
+ with gr.Row():
312
+ orig_img = gr.Image(label="πŸ“· Original Slice",
313
+ type="numpy", height=380)
314
+ recon_img = gr.Image(label="πŸ€– SIREN Reconstruction",
315
+ type="numpy", height=380)
316
+
317
+ btn2.click(compress_and_compare,
318
+ inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],
319
+ outputs=[orig_img, recon_img, stats2])
320
+
321
+ # ══════════════════════════════════════════════════════════════════════
322
+ # TAB 3 β€” About
323
+ # ══════════════════════════════════════════════════════════════════════
324
+ with gr.Tab("ℹ️ About"):
325
+ gr.Markdown(f"""
326
+ ## About this App
327
+
328
+ **Physics-Informed SIREN MRI Compression** uses sinusoidal representation networks
329
+ (SIRENs) to learn a compact neural implicit representation of diffusion MRI data.
330
+
331
+ ### How it works
332
+ 1. Each axial slice is represented by a small MLP with **sine activations** (SIREN)
333
+ 2. The network maps 2D spatial coordinates **(x, y) β†’ signal intensities** across all diffusion volumes
334
+ 3. A **physics-informed loss** (Stejskal-Tanner constraint) regularises the network
335
+ 4. At inference time, coordinates are queried to reconstruct the full slice
336
+
337
+ ### Key advantages
338
+ - πŸ—œοΈ **High compression ratio** β€” one small network per slice replaces raw voxel data
339
+ - ⚑ **Resolution-agnostic** β€” can reconstruct at any spatial resolution
340
+ - πŸ”¬ **Physics-aware** β€” diffusion signal constraints improve anatomical fidelity
341
+ - 🧩 **No codec artefacts** β€” continuous representation, no JPEG/JPEG2000 blocking
342
+
343
+ ### Model trained on
344
+ [MGH-1010 Connectome Diffusion Microstructure Dataset](https://www.kaggle.com/datasets)
345
+
346
+ | Property | Value |
347
+ |---|---|
348
+ | Architecture | `{config['model'].upper()}` |
349
+ | Layers | `{config['num_layers']}` |
350
+ | Hidden units | `{config['layer_size']}` |
351
+ | wβ‚€ | `{config['w0']}` |
352
+ | Spatial slices | `{sz}` |
353
+ | Diffusion volumes | `{vols}` |
354
+ | Training data shape | `{sx} Γ— {sy} Γ— {sz} Γ— {vols}` |
355
+
356
+ ### References
357
+ - Sitzmann et al. (2020) β€” *Implicit Neural Representations with Periodic Activation Functions*
358
+ - Stejskal & Tanner (1965) β€” *Spin diffusion measurements: spin echoes in the presence of a time-dependent field gradient*
359
+ """)
360
 
361
  demo.launch()