medimaging commited on
Commit
be0ae56
Β·
verified Β·
1 Parent(s): b9d9a7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -134
app.py CHANGED
@@ -1,164 +1,183 @@
1
  import os
2
- import gradio as gr
3
- import torch
4
  import pickle
5
  import numpy as np
6
- import nibabel as nib
7
-
8
- from siren import Siren
9
-
10
-
11
- # =========================
12
- # DEVICE
13
- # =========================
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
-
17
- # =========================
18
- # LOAD MODEL
19
- # =========================
20
- checkpoint = torch.load(
21
- "sirenMRI_full_model_final.pt",
22
- map_location=device
23
  )
24
 
25
- config = checkpoint["config"]
26
-
27
- model = Siren(
28
- dim_in=2,
29
- dim_hidden=config["layer_size"],
30
- dim_out=config["vols"],
31
- num_layers=config["num_layers"],
32
- w0=config["w0"],
33
- w0_initial=config["w0_initial"]
34
- ).to(device)
35
-
36
- model.load_state_dict(checkpoint["model_state_dict"])
37
- model.eval()
38
-
39
-
40
- # =========================
41
- # LOAD SCALERS
42
- # =========================
43
- with open("scalers.pkl", "rb") as f:
44
- scalers = pickle.load(f)
45
-
46
-
47
- # =========================
48
- # CREATE COORDINATES
49
- # =========================
50
- def create_coordinates(h, w):
51
-
52
- y_coords = np.linspace(-1, 1, h)
53
- x_coords = np.linspace(-1, 1, w)
54
-
55
- yy, xx = np.meshgrid(y_coords, x_coords, indexing='ij')
56
-
57
- coords = np.stack([yy, xx], axis=-1)
58
-
59
- coords = coords.reshape(-1, 2)
60
-
61
- return torch.tensor(coords, dtype=torch.float32)
62
-
63
-
64
- # =========================
65
- # RECONSTRUCTION FUNCTION
66
- # =========================
67
- def reconstruct(nifti_file):
68
-
69
- try:
70
-
71
- # =========================
72
- # LOAD MRI
73
- # =========================
74
- nii = nib.load(nifti_file.name)
75
- img = nii.get_fdata()
76
-
77
- original_shape = img.shape
78
-
79
- # =========================
80
- # HANDLE 4D MRI
81
- # =========================
82
- if len(img.shape) != 4:
83
- return "❌ Expected a 4D MRI (.nii or .nii.gz)", None
84
-
85
- h, w, slices, vols = img.shape
86
 
87
- reconstructed = np.zeros_like(img)
 
88
 
89
- # =========================
90
- # GENERATE COORDINATES
91
- # =========================
92
- coords = create_coordinates(h, w).to(device)
93
 
94
- # =========================
95
- # RECONSTRUCT EACH SLICE
96
- # =========================
97
- with torch.no_grad():
98
 
99
- for s in range(slices):
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- pred = model(coords)
 
102
 
103
- pred = pred.cpu().numpy()
104
 
105
- # inverse normalization
106
- pred = scalers[s].inverse_transform(pred)
 
107
 
108
- pred = pred.reshape(h, w, vols)
 
 
 
109
 
110
- reconstructed[:, :, s, :] = pred
 
 
 
 
 
111
 
112
- # =========================
113
- # SAVE OUTPUT MRI
114
- # =========================
115
- output_path = "reconstructed_output.nii.gz"
116
 
117
- recon_nii = nib.Nifti1Image(
118
- reconstructed,
119
- affine=nii.affine,
120
- header=nii.header
121
- )
122
 
123
- nib.save(recon_nii, output_path)
 
124
 
125
- return (
126
- f"""βœ… Reconstruction complete
127
 
128
- Original Shape: {original_shape}
129
 
130
- Reconstructed MRI saved successfully.
131
- """,
132
- output_path
133
- )
134
 
135
- except Exception as e:
136
- return f"❌ Error: {str(e)}", None
137
 
 
 
 
138
 
139
- # =========================
140
- # GRADIO UI
141
- # =========================
142
- interface = gr.Interface(
143
- fn=reconstruct,
 
 
144
 
145
- inputs=gr.File(
146
- label="Upload MRI (.nii or .nii.gz)"
147
- ),
148
 
149
- outputs=[
150
- gr.Textbox(label="Status"),
151
- gr.File(label="Download Reconstructed MRI")
152
- ],
153
 
154
- title="Physics-Informed SIREN MRI Reconstruction",
155
 
156
- description="""
157
- Upload a 4D MRI scan (.nii or .nii.gz).
 
158
 
159
- The model reconstructs the MRI using a trained
160
- Physics-Informed SIREN neural representation.
161
- """
 
162
  )
163
 
164
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
2
  import pickle
3
  import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from math import sqrt
7
+ import streamlit as st
8
+
9
+ # ── Streamlit page config ────────────────────────────────────────────────────
10
+ st.set_page_config(
11
+ page_title="Physics-Informed SIREN MRI Compression",
12
+ page_icon="🧠",
13
+ layout="wide",
 
 
 
 
 
 
 
14
  )
15
 
16
+ # ══════════════════════════════════════════════════════════════════════════════
17
+ # 1. Model definitions β€” must match exactly what was used during training
18
+ # ══════════════════════════════════════════════════════════════════════════════
19
+
20
+ class Sine(nn.Module):
21
+ def __init__(self, w0=1.0):
22
+ super().__init__()
23
+ self.w0 = w0
24
+
25
+ def forward(self, x):
26
+ return torch.sin(self.w0 * x)
27
+
28
+
29
+ class SirenLayer(nn.Module):
30
+ def __init__(self, dim_in, dim_out, w0=30.0, c=6.0,
31
+ is_first=False, use_bias=True, activation=None):
32
+ super().__init__()
33
+ self.dim_in = dim_in
34
+ self.is_first = is_first
35
+ self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)
36
+ w_std = (1 / dim_in) if is_first else (sqrt(c / dim_in) / w0)
37
+ nn.init.uniform_(self.linear.weight, -w_std, w_std)
38
+ if use_bias:
39
+ nn.init.uniform_(self.linear.bias, -w_std, w_std)
40
+ self.activation = Sine(w0) if activation is None else activation
41
+
42
+ def forward(self, x):
43
+ return self.activation(self.linear(x))
44
+
45
+
46
+ class Siren(nn.Module):
47
+ def __init__(self, dim_in, dim_hidden, dim_out, num_layers,
48
+ w0=30.0, w0_initial=30.0, use_bias=True, final_activation=None):
49
+ super().__init__()
50
+ layers = []
51
+ for i in range(num_layers):
52
+ is_first = i == 0
53
+ layer_w0 = w0_initial if is_first else w0
54
+ layer_dim_in = dim_in if is_first else dim_hidden
55
+ layers.append(SirenLayer(
56
+ dim_in=layer_dim_in, dim_out=dim_hidden,
57
+ w0=layer_w0, use_bias=use_bias, is_first=is_first,
58
+ ))
59
+ self.net = nn.Sequential(*layers)
60
+ final_activation = nn.Identity() if final_activation is None else final_activation
61
+ self.last_layer = SirenLayer(
62
+ dim_in=dim_hidden, dim_out=dim_out,
63
+ w0=w0, use_bias=use_bias, activation=final_activation,
64
+ )
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ def forward(self, x):
67
+ return self.last_layer(self.net(x))
68
 
 
 
 
 
69
 
70
+ class SirenMRIModel(nn.Module):
71
+ """Multi-slice wrapper β€” one Siren per axial slice."""
 
 
72
 
73
+ def __init__(self, config: dict):
74
+ super().__init__()
75
+ self.config = config
76
+ self.models = nn.ModuleList()
77
+ for _ in range(config["sz"]):
78
+ self.models.append(Siren(
79
+ dim_in=2,
80
+ dim_hidden=config["layer_size"],
81
+ dim_out=config["vols"],
82
+ num_layers=config["num_layers"],
83
+ w0=config["w0"],
84
+ w0_initial=config["w0_initial"],
85
+ ))
86
 
87
+ def forward(self, coords: torch.Tensor, slice_idx: int) -> torch.Tensor:
88
+ return self.models[slice_idx](coords)
89
 
 
90
 
91
+ # ══════════════════════════════════════════════════════════════════════════════
92
+ # 2. Load model + scalers (cached so they only load once)
93
+ # ══════════════════════════════════════════════════════════════════════════════
94
 
95
+ @st.cache_resource
96
+ def load_model():
97
+ model_path = "sirenMRI_full_model_final.pt"
98
+ scalers_path = "scalers.pkl"
99
 
100
+ if not os.path.exists(model_path):
101
+ st.error(f"Model file not found: {model_path}")
102
+ st.stop()
103
+ if not os.path.exists(scalers_path):
104
+ st.error(f"Scalers file not found: {scalers_path}")
105
+ st.stop()
106
 
107
+ checkpoint = torch.load(model_path, map_location="cpu")
108
+ config = checkpoint["config"]
 
 
109
 
110
+ # ── THE FIX: build SirenMRIModel, not bare Siren ──────────────────────
111
+ model = SirenMRIModel(config)
112
+ model.load_state_dict(checkpoint["model_state_dict"])
113
+ model.eval()
 
114
 
115
+ with open(scalers_path, "rb") as f:
116
+ scalers = pickle.load(f)
117
 
118
+ return model, scalers, config, checkpoint["input_shape"]
 
119
 
 
120
 
121
+ model, scalers, config, input_shape = load_model()
122
+ sx, sy, sz, vols = input_shape
 
 
123
 
 
 
124
 
125
+ # ══════════════════════════════════════════════════════════════════════════════
126
+ # 3. Inference helper
127
+ # ══════════════════════════════════════════════════════════════════════════════
128
 
129
+ def reconstruct_slice(slice_idx: int) -> np.ndarray:
130
+ """Return reconstructed slice as (sx, sy, vols) float32 array."""
131
+ # Build normalised (x, y) coordinate grid in [-1, 1]
132
+ xs = torch.linspace(-1, 1, sx)
133
+ ys = torch.linspace(-1, 1, sy)
134
+ grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
135
+ coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1) # (sx*sy, 2)
136
 
137
+ with torch.no_grad():
138
+ pred = model(coords, slice_idx).numpy() # (sx*sy, vols)
 
139
 
140
+ # Inverse-transform back to original intensity range
141
+ pred = scalers[slice_idx].inverse_transform(pred)
142
+ return pred.reshape(sx, sy, vols)
 
143
 
 
144
 
145
+ # ══════════════════════════════════════════════════════════════════════════════
146
+ # 4. Streamlit UI
147
+ # ══════════════════════════════════════════════════════════════════════════════
148
 
149
+ st.title("🧠 Physics-Informed SIREN MRI Compression")
150
+ st.markdown(
151
+ "Reconstruct axial slices from the compressed SIREN representation. "
152
+ "Use the sliders to choose a slice and a diffusion volume."
153
  )
154
 
155
+ col1, col2 = st.columns(2)
156
+ with col1:
157
+ slice_idx = st.slider("Axial slice", 0, sz - 1, sz // 2)
158
+ with col2:
159
+ vol_idx = st.slider("Diffusion volume (b-value index)", 0, vols - 1, 0)
160
+
161
+ if st.button("Reconstruct slice"):
162
+ with st.spinner("Running SIREN inference..."):
163
+ recon = reconstruct_slice(slice_idx) # (sx, sy, vols)
164
+ img = recon[:, :, vol_idx]
165
+
166
+ # Normalise to [0, 1] for display
167
+ img_norm = (img - img.min()) / (img.max() - img.min() + 1e-8)
168
+
169
+ st.image(
170
+ img_norm,
171
+ caption=f"Reconstructed slice {slice_idx}, volume {vol_idx}",
172
+ use_column_width=True,
173
+ clamp=True,
174
+ )
175
+
176
+ st.success(f"Slice shape: {img.shape} | Value range: [{img.min():.3f}, {img.max():.3f}]")
177
+
178
+ st.divider()
179
+ st.caption(
180
+ f"Model config β€” layers: {config['num_layers']}, "
181
+ f"hidden size: {config['layer_size']}, "
182
+ f"w0: {config['w0']}, slices: {sz}, volumes: {vols}"
183
+ )