import base64
import html
import io
import os
import pickle
import tempfile
import uuid
import warnings
from math import sqrt

import gradio as gr
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn.preprocessing import MinMaxScaler


class Sine(nn.Module):
    """Sine activation ``sin(w0 * x)`` used by the SIREN layers."""

    def __init__(self, w0: float = 1.0) -> None:
        super().__init__()
        # Frequency multiplier applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sin(w0 * x)`` element-wise."""
        return torch.sin(self.w0 * x)


class SirenLayer(nn.Module):
    """A linear layer followed by an activation, with SIREN-style uniform init.

    Args:
        dim_in: Number of input features.
        dim_out: Number of output features.
        w0: Frequency passed to the default ``Sine`` activation and used to
            scale the init range of non-first layers.
        c: Constant in the init range ``sqrt(c / dim_in) / w0``.
        is_first: When true, the init range is ``1 / dim_in`` instead.
        use_bias: Whether the linear layer has a bias term.
        activation: Module applied after the linear layer; ``Sine(w0)`` when
            ``None``.
    """

    def __init__(self, dim_in: int, dim_out: int, w0: float = 30.0, c: float = 6.0,
                 is_first: bool = False, use_bias: bool = True, activation=None) -> None:
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)

        # Weights and bias are both drawn from U(-bound, bound).
        if is_first:
            bound = 1.0 / dim_in
        else:
            bound = sqrt(c / dim_in) / w0
        nn.init.uniform_(self.linear.weight, -bound, bound)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -bound, bound)

        self.activation = Sine(w0) if activation is None else activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map, then the activation."""
        return self.activation(self.linear(x))


class Siren(nn.Module):
    """Stack of ``SirenLayer`` modules followed by a final output layer.

    Args:
        dim_in: Number of input features of the first layer.
        dim_hidden: Width of every hidden layer.
        dim_out: Number of output features of the final layer.
        num_layers: Number of hidden ``SirenLayer`` modules in ``self.net``.
        w0: Frequency for every layer except the first.
        w0_initial: Frequency for the first layer.
        use_bias: Whether the linear layers carry a bias.
        final_activation: Activation of the last layer; ``nn.Identity()`` when
            ``None``.

    The attribute names ``net`` and ``last_layer`` determine the keys of the
    module's ``state_dict`` and must stay stable for saved checkpoints.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers,
                 w0=30.0, w0_initial=30.0, use_bias=True, final_activation=None):
        super().__init__()

        # Only layer 0 takes the raw input width, the initial frequency and
        # the "first layer" init range; all later layers are hidden-to-hidden.
        hidden_layers = [
            SirenLayer(
                dim_in if idx == 0 else dim_hidden,
                dim_hidden,
                w0=w0_initial if idx == 0 else w0,
                use_bias=use_bias,
                is_first=(idx == 0),
            )
            for idx in range(num_layers)
        ]
        self.net = nn.Sequential(*hidden_layers)

        out_activation = nn.Identity() if final_activation is None else final_activation
        self.last_layer = SirenLayer(dim_hidden, dim_out, w0=w0,
                                     use_bias=use_bias, activation=out_activation)

    def forward(self, x):
        """Run ``x`` through the hidden stack and then the output layer."""
        return self.last_layer(self.net(x))


class SirenMRIModel(nn.Module):
    """One ``Siren`` network per slice, selected by index at call time.

    Args:
        config: Mapping that provides ``"layer_size"``, ``"vols"``,
            ``"num_layers"``, ``"w0"``, ``"w0_initial"`` and ``"sz"``.
            ``"sz"`` is the number of per-slice networks and ``"vols"`` the
            output width of each network.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # ``models`` is part of the state_dict key layout; keep the name.
        self.models = nn.ModuleList(
            Siren(
                dim_in=2,
                dim_hidden=config["layer_size"],
                dim_out=config["vols"],
                num_layers=config["num_layers"],
                w0=config["w0"],
                w0_initial=config["w0_initial"],
            )
            for _ in range(config["sz"])
        )

    def forward(self, coords, slice_idx):
        """Evaluate the network stored at ``slice_idx`` on ``coords``."""
        return self.models[slice_idx](coords)


def load_assets():
    """Load the pretrained model, the pickled scalers, the config and the input shape.

    Both files are looked up relative to the current working directory.

    Returns:
        A tuple ``(model, scalers, config, input_shape)`` where ``model`` is a
        ``SirenMRIModel`` in eval mode, ``scalers`` is whatever object
        ``scalers.pkl`` contains, and ``config`` / ``input_shape`` are the
        ``"config"`` / ``"input_shape"`` entries of the checkpoint.

    Raises:
        FileNotFoundError: If either asset file does not exist.
    """
    model_path = "sirenMRI_full_model_final.pt"
    scalers_path = "scalers.pkl"
    for required in (model_path, scalers_path):
        if not os.path.exists(required):
            raise FileNotFoundError(f"Missing: {required}")

    # SECURITY: weights_only=False lets torch unpickle arbitrary Python
    # objects, which can execute code. Only ship checkpoints you produced
    # yourself alongside this app.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    cfg = ckpt["config"]

    mdl = SirenMRIModel(cfg)
    mdl.load_state_dict(ckpt["model_state_dict"])
    mdl.eval()

    # SECURITY: pickle.load has the same arbitrary-code-execution risk as
    # above; the file must come from a trusted source.
    # Every warning raised while unpickling is silenced here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with open(scalers_path, "rb") as fh:
            scl = pickle.load(fh)

    return mdl, scl, cfg, ckpt["input_shape"]


# The assets are loaded once, at import time, and shared by every callback
# below through these module-level names.
print("⏳ Loading model…")
model, scalers, config, input_shape = load_assets()
sx, sy, sz, vols = input_shape
print(f"✅ Model ready — {sx}×{sy}×{sz}, {vols} volumes")


def to_uint8(arr):
    """Min-max normalise ``arr`` to the 0..255 range and return it as ``uint8``.

    Args:
        arr: Array-like of real numbers.

    Returns:
        A ``uint8`` array with the shape of ``arr``. The minimum maps to 0 and
        the maximum to 255. Constant, empty, or entirely non-finite input
        yields all zeros.

    Values are rounded to the nearest integer rather than truncated, so the
    maximum really reaches 255. NaN and ``-inf`` are displayed as the minimum
    and ``+inf`` as the maximum, so they cannot poison the value range.
    """
    data = np.asarray(arr, dtype=np.float64)
    finite = np.isfinite(data)
    if data.size == 0 or not finite.any():
        return np.zeros(data.shape, dtype=np.uint8)

    lo = data[finite].min()
    hi = data[finite].max()
    cleaned = np.nan_to_num(data, nan=lo, posinf=hi, neginf=lo)

    # The small epsilon keeps a constant image from dividing by zero.
    scaled = (cleaned - lo) / (hi - lo + 1e-8) * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def to_coords(h, w):
    """Build a flattened ``h`` x ``w`` grid of 2-D coordinates in ``[-1, 1]``.

    Args:
        h: Number of samples along the first axis.
        w: Number of samples along the second axis.

    Returns:
        A tensor of shape ``(h * w, 2)``. Rows are ordered with the first axis
        varying slowest ("ij" indexing); column 0 is the first-axis coordinate
        and column 1 the second-axis coordinate.
    """
    first_axis = torch.linspace(-1, 1, h)
    second_axis = torch.linspace(-1, 1, w)
    grid = torch.stack(
        torch.meshgrid(first_axis, second_axis, indexing="ij"), dim=-1
    )
    return grid.reshape(-1, 2)


def make_zoom_html(arr_uint8, title=""):
    """Convert a uint8 numpy array to a self-contained zoomable HTML viewer.

    Args:
        arr_uint8: Image data accepted by ``PIL.Image.fromarray``.
        title: Plain-text caption shown above the image; HTML-escaped before
            being inserted into the markup.

    Returns:
        An HTML string containing the image as an inline base64 PNG plus an
        inline script that implements wheel zoom, drag pan, double-click
        reset and basic touch gestures.
    """
    pil_img = Image.fromarray(arr_uint8)

    # Upscale small images by an integer factor (nearest neighbour keeps the
    # voxels crisp) so the longest side approaches 400 px.
    w, h = pil_img.size
    scale = max(1, 400 // max(w, h, 1))
    pil_img = pil_img.resize((w * scale, h * scale), Image.NEAREST)

    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()

    # A random id per rendering keeps the element ids unique even when two
    # viewers on the same page show identical pixels.
    viewer_id = uuid.uuid4().hex
    safe_title = html.escape(str(title))

    # NOTE(review): browsers do not execute <script> elements that are
    # inserted through innerHTML. Confirm that the deployed Gradio version
    # actually runs this inline script; otherwise the viewer stays static.
    return f"""
<div style="background:#f8f9ff;border:1.5px solid #ddd6fe;border-radius:14px;
            padding:12px;user-select:none;">
  <div style="font-weight:800;color:#4c1d95;margin-bottom:8px;font-size:.95rem;">
    🔍 {safe_title} <span style="font-weight:500;color:#6b7280;font-size:.8rem;">
    Scroll to zoom · Drag to pan · Double-click to reset</span>
  </div>
  <div id="zoom-wrap-{viewer_id}"
       style="overflow:hidden;border-radius:10px;background:#000;
              width:100%;height:420px;cursor:grab;position:relative;">
    <img id="zoom-img-{viewer_id}"
         src="data:image/png;base64,{b64}"
         style="transform-origin:0 0;transform:scale(1) translate(0px,0px);
                image-rendering:pixelated;max-width:none;
                width:100%;height:100%;object-fit:contain;display:block;"
         draggable="false"/>
  </div>
</div>
<script>
(function() {{
  const wid = '{viewer_id}';
  const wrap = document.getElementById('zoom-wrap-' + wid);
  const img = document.getElementById('zoom-img-' + wid);
  if (!wrap || !img) return;

  let scale = 1, ox = 0, oy = 0;
  let dragging = false, startX, startY, lastOx, lastOy;
  const MIN = 0.5, MAX = 12;

  function apply() {{
    img.style.transform = `scale(${{scale}}) translate(${{ox}}px,${{oy}}px)`;
  }}

  // Scroll to zoom
  wrap.addEventListener('wheel', e => {{
    e.preventDefault();
    const rect = wrap.getBoundingClientRect();
    const mx = e.clientX - rect.left;
    const my = e.clientY - rect.top;
    const factor = e.deltaY < 0 ? 1.12 : 0.89;
    const newScale = Math.min(MAX, Math.max(MIN, scale * factor));
    ox = mx / newScale - mx / scale + ox;
    oy = my / newScale - my / scale + oy;
    scale = newScale;
    apply();
  }}, {{ passive: false }});

  // Drag to pan
  wrap.addEventListener('mousedown', e => {{
    dragging = true; wrap.style.cursor = 'grabbing';
    startX = e.clientX; startY = e.clientY;
    lastOx = ox; lastOy = oy;
  }});
  window.addEventListener('mousemove', e => {{
    if (!dragging) return;
    ox = lastOx + (e.clientX - startX) / scale;
    oy = lastOy + (e.clientY - startY) / scale;
    apply();
  }});
  window.addEventListener('mouseup', () => {{
    dragging = false; wrap.style.cursor = 'grab';
  }});

  // Double-click to reset
  wrap.addEventListener('dblclick', () => {{
    scale = 1; ox = 0; oy = 0; apply();
  }});

  // Touch support
  let lastDist = null;
  wrap.addEventListener('touchstart', e => {{
    if (e.touches.length === 1) {{
      dragging = true;
      startX = e.touches[0].clientX; startY = e.touches[0].clientY;
      lastOx = ox; lastOy = oy;
    }}
  }}, {{ passive: true }});
  wrap.addEventListener('touchmove', e => {{
    if (e.touches.length === 2) {{
      const d = Math.hypot(
        e.touches[0].clientX - e.touches[1].clientX,
        e.touches[0].clientY - e.touches[1].clientY);
      if (lastDist) {{ scale = Math.min(MAX, Math.max(MIN, scale * d / lastDist)); apply(); }}
      lastDist = d;
    }} else if (e.touches.length === 1 && dragging) {{
      ox = lastOx + (e.touches[0].clientX - startX) / scale;
      oy = lastOy + (e.touches[0].clientY - startY) / scale;
      apply();
    }}
  }}, {{ passive: true }});
  wrap.addEventListener('touchend', () => {{ dragging = false; lastDist = null; }});
}})();
</script>
"""


def reconstruct_pretrained(slice_idx, vol_idx):
    """Reconstruct one slice/volume from the pretrained model for display.

    Args:
        slice_idx: Slice index; converted to ``int`` and clamped to
            ``[0, sz - 1]``.
        vol_idx: Volume index; converted to ``int`` and clamped to
            ``[0, vols - 1]``.

    Returns:
        ``(html, stats)``: the zoomable viewer markup for the reconstructed
        slice and a one-line summary string.
    """
    # Clamp instead of trusting the caller: a negative index would otherwise
    # silently wrap around to the last slice/volume.
    slice_idx = min(max(int(slice_idx), 0), sz - 1)
    vol_idx = min(max(int(vol_idx), 0), vols - 1)

    coords = to_coords(sx, sy)
    with torch.no_grad():
        pred = model(coords, slice_idx).numpy()

    # Undo the per-slice min-max scaling using the fitted scaler's range.
    scaler = scalers[slice_idx]
    data_min = np.asarray(scaler.data_min_, dtype=np.float32)
    data_max = np.asarray(scaler.data_max_, dtype=np.float32)
    pred = pred * (data_max - data_min) + data_min

    recon = pred.reshape(sx, sy, vols)[:, :, vol_idx]
    img_min, img_max = recon.min(), recon.max()
    stats = (
        f"📐 Shape: {recon.shape} | "
        f"📊 Intensity: [{img_min:.3f}, {img_max:.3f}] | "
        f"🧠 Slice {slice_idx} | 📡 Volume {vol_idx}"
    )
    viewer = make_zoom_html(
        to_uint8(recon),
        f"Reconstructed — Slice {slice_idx}, Volume {vol_idx}",
    )
    return viewer, stats


def compress_and_compare(nifti_file, slice_idx, vol_idx, num_iters, lr):
    """Fit a fresh SIREN to one slice of an uploaded NIfTI file and compare.

    Args:
        nifti_file: The uploaded file: either an object exposing the path as
            ``.name`` or the path itself. ``None`` means nothing was uploaded.
        slice_idx: Requested slice; clamped to the file's slice count.
        vol_idx: Requested volume; clamped to the file's volume count.
        num_iters: Number of full-batch Adam steps.
        lr: Adam learning rate.

    Returns:
        ``(orig_html, recon_html, stats)``. On any input problem both HTML
        entries are ``None`` and ``stats`` carries the error message.
    """
    if nifti_file is None:
        return None, None, "⚠️ Please upload a NIfTI file first."

    slice_idx = int(slice_idx)
    vol_idx = int(vol_idx)
    num_iters = int(num_iters)

    # Accept both a temp-file wrapper (``.name``) and a plain path string.
    nifti_path = getattr(nifti_file, "name", nifti_file)
    try:
        nii = nib.load(nifti_path)
        # Ask for float32 directly instead of loading float64 and copying.
        img_data = nii.get_fdata(dtype=np.float32)
    except Exception as e:
        # UI boundary: report the failure to the user rather than crashing.
        return None, None, f"❌ Failed to load NIfTI: {e}"

    # A 3-D image is treated as a 4-D image with a single volume.
    if img_data.ndim == 3:
        img_data = img_data[..., np.newaxis]
    if img_data.ndim != 4:
        return None, None, "❌ Expected a 3D or 4D NIfTI file."
    if img_data.size == 0:
        return None, None, "❌ NIfTI file contains no voxels."

    ux, uy, uz, uvols = img_data.shape
    # The sliders have fixed ranges, so clamp to what the file really has.
    slice_idx = max(0, min(slice_idx, uz - 1))
    vol_idx = max(0, min(vol_idx, uvols - 1))

    orig_slice = img_data[:, :, slice_idx, vol_idx]
    orig_img = to_uint8(orig_slice)

    # One row per in-plane voxel (first axis varying slowest, matching
    # to_coords), one column per volume.
    features = img_data[:, :, slice_idx, :].reshape(ux * uy, uvols)

    scaler_u = MinMaxScaler(feature_range=(0, 1))
    features_scaled = scaler_u.fit_transform(features).astype(np.float32)

    siren_u = Siren(dim_in=2, dim_hidden=config["layer_size"],
                    dim_out=uvols, num_layers=config["num_layers"],
                    w0=config["w0"], w0_initial=config["w0_initial"])
    opt = torch.optim.Adam(siren_u.parameters(), lr=float(lr))
    loss_fn = nn.MSELoss()

    coords_u = to_coords(ux, uy)
    feat_t = torch.from_numpy(features_scaled)

    siren_u.train()
    last_loss = None
    for _ in range(num_iters):
        opt.zero_grad()
        last_loss = loss_fn(siren_u(coords_u), feat_t)
        last_loss.backward()
        opt.step()
    # Only the final loss is reported; NaN when no iteration ran.
    final_loss = last_loss.item() if last_loss is not None else float("nan")

    siren_u.eval()
    with torch.no_grad():
        pred_np = siren_u(coords_u).numpy()

    pred_inv = scaler_u.inverse_transform(pred_np)
    recon_slice = pred_inv.reshape(ux, uy, uvols)[:, :, vol_idx]
    recon_img = to_uint8(recon_slice)

    # PSNR relative to the original slice's peak intensity; undefined (NaN)
    # when that peak is not positive.
    mse = np.mean((orig_slice - recon_slice) ** 2)
    o_max = orig_slice.max()
    psnr = 20 * np.log10(o_max / (np.sqrt(mse) + 1e-8)) if o_max > 0 else float("nan")

    stats = (
        f"📐 Image: {ux}×{uy}×{uz}, {uvols} volumes | "
        f"🎯 Slice {slice_idx}, Volume {vol_idx}\n"
        f"📉 Final loss: {final_loss:.6f} | "
        f"📡 PSNR: {psnr:.2f} dB | "
        f"🔁 Iterations: {num_iters}"
    )
    orig_html = make_zoom_html(orig_img, f"Original — Slice {slice_idx}, Volume {vol_idx}")
    recon_html = make_zoom_html(
        recon_img, f"SIREN Reconstruction — Slice {slice_idx}, Volume {vol_idx}"
    )
    return orig_html, recon_html, stats


# Custom stylesheet handed to ``gr.Blocks(css=...)``: a light theme with
# purple accents.
# NOTE(review): the ``.gr-*`` / ``.tab-nav`` selectors presumably target class
# names of a specific Gradio release — confirm they still match the DOM of the
# installed version.
CSS = """
/* ── Base ── */
body, .gradio-container {
    background: #ffffff !important;
    color: #111827 !important;
    font-family: 'Inter', 'Segoe UI', sans-serif !important;
}

/* ── Primary button ── */
.gr-button-primary, button.primary {
    background: linear-gradient(135deg, #6366f1, #8b5cf6) !important;
    border: none !important;
    border-radius: 10px !important;
    font-weight: 800 !important;
    font-size: 1rem !important;
    color: #ffffff !important;
    letter-spacing: 0.4px;
    transition: transform .15s, box-shadow .15s;
}
.gr-button-primary:hover, button.primary:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 25px rgba(99,102,241,0.35) !important;
}

/* ── Cards / panels ── */
.gr-panel, .gr-box, .gradio-group {
    background: #f8f9ff !important;
    border: 1.5px solid #ddd6fe !important;
    border-radius: 14px !important;
}

/* ── Inputs ── */
.gr-input, input, textarea, .gr-slider {
    background: #ffffff !important;
    border: 1.5px solid #c4b5fd !important;
    color: #111827 !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
}

/* ── Labels ── */
label, .gr-label, span.label {
    color: #4c1d95 !important;
    font-weight: 700 !important;
    font-size: 0.95rem !important;
}

/* ── Markdown headings ── */
.gr-markdown h1 {
    background: linear-gradient(135deg, #6366f1, #7c3aed);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.4rem !important;
    font-weight: 900 !important;
    margin-bottom: 4px !important;
}
.gr-markdown h2 {
    color: #5b21b6 !important;
    font-size: 1.15rem !important;
    font-weight: 700 !important;
}
.gr-markdown h3 {
    color: #4c1d95 !important;
    font-weight: 800 !important;
    font-size: 1.05rem !important;
    border-bottom: 2px solid #ede9fe;
    padding-bottom: 4px;
}
.gr-markdown p, .gr-markdown li {
    color: #1f2937 !important;
    font-weight: 600 !important;
    font-size: 0.97rem !important;
    line-height: 1.7 !important;
}
.gr-markdown strong {
    color: #3730a3 !important;
    font-weight: 800 !important;
}
.gr-markdown code {
    background: #ede9fe !important;
    color: #5b21b6 !important;
    font-weight: 700 !important;
    border-radius: 4px;
    padding: 1px 5px;
}
.gr-markdown table {
    border-collapse: collapse;
    width: 100%;
    margin-top: 8px;
}
.gr-markdown th {
    background: #ede9fe !important;
    color: #3730a3 !important;
    font-weight: 800 !important;
    padding: 8px 12px;
    border: 1px solid #c4b5fd;
}
.gr-markdown td {
    color: #1f2937 !important;
    font-weight: 600 !important;
    padding: 7px 12px;
    border: 1px solid #e5e7eb;
}

/* ── Tabs ── */
.tab-nav button {
    color: #6366f1 !important;
    font-weight: 700 !important;
    border-radius: 8px 8px 0 0 !important;
    font-size: 0.95rem !important;
}
.tab-nav button.selected {
    background: #ede9fe !important;
    border-bottom: 3px solid #6366f1 !important;
    color: #3730a3 !important;
}

/* ── Textbox output ── */
textarea {
    color: #111827 !important;
    font-weight: 700 !important;
    background: #fafafa !important;
}

/* ── Divider ── */
hr { border-color: #ede9fe !important; }

/* ── Hide footer ── */
footer { display: none !important; }
"""


# The markdown snippets live at module level, flush-left, so that their
# rendering never depends on the indentation of the layout code below.
_HEADER_MD = """
# 🧠 Physics-Informed SIREN MRI Compression
## Neural implicit representation for diffusion MRI
---
"""

_EXPLORE_MD = """
### Explore the model trained on the MGH-1010 diffusion dataset
Adjust the sliders and click **Reconstruct** to visualise any slice and volume.
"""

_MODEL_CONFIG_MD = f"""
---
**Model config**
| Parameter | Value |
|---|---|
| Type | `{config['model'].upper()}` |
| Layers | `{config['num_layers']}` |
| Hidden size | `{config['layer_size']}` |
| w₀ | `{config['w0']}` |
| Slices | `{sz}` |
| Volumes | `{vols}` |
"""

_UPLOAD_MD = """
### Upload your own diffusion MRI in NIfTI format
The app will fit a SIREN network to the selected slice on-the-fly and show you
**original vs reconstructed** side by side.
> ⚠️ For speed, only the selected slice is fitted. Use more iterations for better quality.
"""

_ABOUT_MD = f"""
## About this App

**Physics-Informed SIREN MRI Compression** uses sinusoidal representation networks
(SIRENs) to learn a compact neural implicit representation of diffusion MRI data.

### How it works
1. Each axial slice is represented by a small MLP with **sine activations** (SIREN)
2. The network maps 2D spatial coordinates **(x, y) → signal intensities** across all diffusion volumes
3. A **physics-informed loss** (Stejskal-Tanner constraint) regularises the network
4. At inference time, coordinates are queried to reconstruct the full slice

### Key advantages
- 🗜️ **High compression ratio** - one small network per slice replaces raw voxel data
- ⚡ **Resolution-agnostic** - can reconstruct at any spatial resolution
- 🔬 **Physics-aware** - diffusion signal constraints improve anatomical fidelity
- 🧩 **No codec artefacts** - continuous representation, no JPEG/JPEG2000 blocking

### Model trained on
[MGH-1010 Connectome Diffusion Microstructure Dataset](https://www.kaggle.com/datasets)

| Property | Value |
|---|---|
| Architecture | `{config['model'].upper()}` |
| Layers | `{config['num_layers']}` |
| Hidden units | `{config['layer_size']}` |
| w₀ | `{config['w0']}` |
| Spatial slices | `{sz}` |
| Diffusion volumes | `{vols}` |
| Training data shape | `{sx} × {sy} × {sz} × {vols}` |

### References
- Sitzmann et al. (2020) - *Implicit Neural Representations with Periodic Activation Functions*
- Stejskal & Tanner (1965) - *Spin diffusion measurements: spin echoes in the presence of a time-dependent field gradient*
"""


with gr.Blocks(css=CSS, title="SIREN MRI Compression") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Tabs():
        # Tab 1: browse reconstructions produced by the pretrained model.
        with gr.Tab("🔬 Explore Pretrained Model"):
            gr.Markdown(_EXPLORE_MD)
            with gr.Row():
                with gr.Column(scale=1):
                    sl1 = gr.Slider(0, sz - 1, value=sz // 2, step=1,
                                    label=f"Axial Slice (0 – {sz - 1})")
                    vl1 = gr.Slider(0, vols - 1, value=0, step=1,
                                    label=f"Diffusion Volume (0 – {vols - 1})")
                    btn1 = gr.Button("▶ Reconstruct", variant="primary")
                    stats1 = gr.Textbox(label="Statistics", lines=2, interactive=False)
                    gr.Markdown(_MODEL_CONFIG_MD)

                with gr.Column(scale=2):
                    out1 = gr.HTML(label="Reconstructed Slice", elem_id="recon_img")

            btn1.click(reconstruct_pretrained,
                       inputs=[sl1, vl1],
                       outputs=[out1, stats1])

        # Tab 2: fit a new SIREN to a slice of a user-supplied NIfTI file.
        with gr.Tab("📂 Upload & Compress Your Own MRI"):
            gr.Markdown(_UPLOAD_MD)
            with gr.Row():
                with gr.Column(scale=1):
                    nifti_upload = gr.File(
                        label="Upload NIfTI file (.nii or .nii.gz)",
                        file_types=[".nii", ".gz"],
                    )
                    # The upper bounds are fixed because the file is unknown
                    # when the UI is built; compress_and_compare clamps the
                    # values to the uploaded image.
                    sl2 = gr.Slider(0, 200, value=50, step=1, label="Axial Slice")
                    vl2 = gr.Slider(0, 551, value=0, step=1, label="Diffusion Volume")
                    with gr.Row():
                        n_iters = gr.Slider(100, 2000, value=500, step=100,
                                            label="Training Iterations")
                        lr_inp = gr.Slider(1e-4, 1e-2, value=3e-4, step=1e-4,
                                           label="Learning Rate")
                    btn2 = gr.Button("🚀 Compress & Compare", variant="primary")
                    stats2 = gr.Textbox(label="Results", lines=3, interactive=False)

                with gr.Column(scale=2):
                    with gr.Row():
                        orig_img = gr.HTML(label="📷 Original Slice")
                        recon_img = gr.HTML(label="🤖 SIREN Reconstruction")

            btn2.click(compress_and_compare,
                       inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],
                       outputs=[orig_img, recon_img, stats2])

        # Tab 3: static background information.
        with gr.Tab("ℹ️ About"):
            gr.Markdown(_ABOUT_MD)


# Start the server only when executed as a script, so that importing this
# module (e.g. to reuse ``demo``) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()