medimaging commited on
Commit
41da276
·
verified ·
1 Parent(s): c05f9e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -8
app.py CHANGED
@@ -8,6 +8,9 @@ import torch.nn as nn
8
  from math import sqrt
9
  import gradio as gr
10
  import nibabel as nib
 
 
 
11
  from sklearn.preprocessing import MinMaxScaler
12
 
13
  # ══════════════════════════════════════════════════════════════════════════════
@@ -104,6 +107,116 @@ def to_coords(h, w):
104
  gx, gy = torch.meshgrid(xs, ys, indexing="ij")
105
  return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # ═════════════════��════════════════════════════════════════════════════════════
108
  # 4a. Reconstruct from pretrained model
109
  # ══════════════════════════════════════════════════════════════════════════════
@@ -124,7 +237,8 @@ def reconstruct_pretrained(slice_idx, vol_idx):
124
  f"📊 Intensity: [{img_min:.3f}, {img_max:.3f}] | "
125
  f"🧠 Slice {slice_idx} | 📡 Volume {vol_idx}"
126
  )
127
- return to_uint8(recon), stats
 
128
 
129
  # ══════════════════════════════════════════════════════════════════════════════
130
  # 4b. Compress & reconstruct user-uploaded NIfTI
@@ -206,7 +320,9 @@ def compress_and_compare(nifti_file, slice_idx, vol_idx, num_iters, lr):
206
  f"📡 PSNR: {psnr:.2f} dB | "
207
  f"🔁 Iterations: {num_iters}"
208
  )
209
- return orig_img, recon_img, stats
 
 
210
 
211
  # ══════════════════════════════════════════════════════════════════════════════
212
  # 5. Gradio UI
@@ -383,8 +499,7 @@ Adjust the sliders and click **Reconstruct** to visualise any slice and volume.
383
  """)
384
 
385
  with gr.Column(scale=2):
386
- out1 = gr.Image(label="Reconstructed Slice", type="numpy",
387
- elem_id="recon_img", height=420)
388
 
389
  btn1.click(reconstruct_pretrained,
390
  inputs=[sl1, vl1],
@@ -418,10 +533,8 @@ The app will fit a SIREN network to the selected slice on-the-fly and show you
418
 
419
  with gr.Column(scale=2):
420
  with gr.Row():
421
- orig_img = gr.Image(label="📷 Original Slice",
422
- type="numpy", height=380)
423
- recon_img = gr.Image(label="🤖 SIREN Reconstruction",
424
- type="numpy", height=380)
425
 
426
  btn2.click(compress_and_compare,
427
  inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],
 
8
  from math import sqrt
9
  import gradio as gr
10
  import nibabel as nib
11
+ import base64
12
+ import io
13
+ from PIL import Image
14
  from sklearn.preprocessing import MinMaxScaler
15
 
16
  # ══════════════════════════════════════════════════════════════════════════════
 
107
  gx, gy = torch.meshgrid(xs, ys, indexing="ij")
108
  return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)
109
 
110
+ # ══════════════════════════════════════════════════════════════════════════════
111
+ # Helper: build zoomable image HTML component
112
+ # ══════════════════════════════════════════════════════════════════════════════
113
+
114
+ def make_zoom_html(arr_uint8, title=""):
115
+ """Convert a uint8 numpy array to a self-contained zoomable HTML viewer."""
116
+ pil_img = Image.fromarray(arr_uint8)
117
+ # upscale small images so they look crisp
118
+ w, h = pil_img.size
119
+ scale = max(1, 400 // max(w, h))
120
+ pil_img = pil_img.resize((w * scale, h * scale), Image.NEAREST)
121
+ buf = io.BytesIO()
122
+ pil_img.save(buf, format="PNG")
123
+ b64 = base64.b64encode(buf.getvalue()).decode()
124
+ html = f"""
125
+ <div style="background:#f8f9ff;border:1.5px solid #ddd6fe;border-radius:14px;
126
+ padding:12px;user-select:none;">
127
+ <div style="font-weight:800;color:#4c1d95;margin-bottom:8px;font-size:.95rem;">
128
+ 🔍 {title} &nbsp;<span style="font-weight:500;color:#6b7280;font-size:.8rem;">
129
+ Scroll to zoom · Drag to pan · Double-click to reset</span>
130
+ </div>
131
+ <div id="zoom-wrap-{hash(b64) & 0xffff}"
132
+ style="overflow:hidden;border-radius:10px;background:#000;
133
+ width:100%;height:420px;cursor:grab;position:relative;">
134
+ <img id="zoom-img-{hash(b64) & 0xffff}"
135
+ src="data:image/png;base64,{b64}"
136
+ style="transform-origin:0 0;transform:scale(1) translate(0px,0px);
137
+ image-rendering:pixelated;max-width:none;
138
+ width:100%;height:100%;object-fit:contain;display:block;"
139
+ draggable="false"/>
140
+ </div>
141
+ </div>
142
+ <script>
143
+ (function() {{
144
+ const wid = '{hash(b64) & 0xffff}';
145
+ const wrap = document.getElementById('zoom-wrap-' + wid);
146
+ const img = document.getElementById('zoom-img-' + wid);
147
+ if (!wrap || !img) return;
148
+
149
+ let scale = 1, ox = 0, oy = 0;
150
+ let dragging = false, startX, startY, lastOx, lastOy;
151
+ const MIN = 0.5, MAX = 12;
152
+
153
+ function apply() {{
154
+ img.style.transform = `scale(${{scale}}) translate(${{ox}}px,${{oy}}px)`;
155
+ }}
156
+
157
+ // Scroll to zoom
158
+ wrap.addEventListener('wheel', e => {{
159
+ e.preventDefault();
160
+ const rect = wrap.getBoundingClientRect();
161
+ const mx = e.clientX - rect.left;
162
+ const my = e.clientY - rect.top;
163
+ const factor = e.deltaY < 0 ? 1.12 : 0.89;
164
+ const newScale = Math.min(MAX, Math.max(MIN, scale * factor));
165
+ ox = mx / newScale - mx / scale + ox;
166
+ oy = my / newScale - my / scale + oy;
167
+ scale = newScale;
168
+ apply();
169
+ }}, {{ passive: false }});
170
+
171
+ // Drag to pan
172
+ wrap.addEventListener('mousedown', e => {{
173
+ dragging = true; wrap.style.cursor = 'grabbing';
174
+ startX = e.clientX; startY = e.clientY;
175
+ lastOx = ox; lastOy = oy;
176
+ }});
177
+ window.addEventListener('mousemove', e => {{
178
+ if (!dragging) return;
179
+ ox = lastOx + (e.clientX - startX) / scale;
180
+ oy = lastOy + (e.clientY - startY) / scale;
181
+ apply();
182
+ }});
183
+ window.addEventListener('mouseup', () => {{
184
+ dragging = false; wrap.style.cursor = 'grab';
185
+ }});
186
+
187
+ // Double-click to reset
188
+ wrap.addEventListener('dblclick', () => {{
189
+ scale = 1; ox = 0; oy = 0; apply();
190
+ }});
191
+
192
+ // Touch support
193
+ let lastDist = null;
194
+ wrap.addEventListener('touchstart', e => {{
195
+ if (e.touches.length === 1) {{
196
+ dragging = true;
197
+ startX = e.touches[0].clientX; startY = e.touches[0].clientY;
198
+ lastOx = ox; lastOy = oy;
199
+ }}
200
+ }}, {{ passive: true }});
201
+ wrap.addEventListener('touchmove', e => {{
202
+ if (e.touches.length === 2) {{
203
+ const d = Math.hypot(
204
+ e.touches[0].clientX - e.touches[1].clientX,
205
+ e.touches[0].clientY - e.touches[1].clientY);
206
+ if (lastDist) {{ scale = Math.min(MAX, Math.max(MIN, scale * d / lastDist)); apply(); }}
207
+ lastDist = d;
208
+ }} else if (e.touches.length === 1 && dragging) {{
209
+ ox = lastOx + (e.touches[0].clientX - startX) / scale;
210
+ oy = lastOy + (e.touches[0].clientY - startY) / scale;
211
+ apply();
212
+ }}
213
+ }}, {{ passive: true }});
214
+ wrap.addEventListener('touchend', () => {{ dragging = false; lastDist = null; }});
215
+ }})();
216
+ </script>
217
+ """
218
+ return html
219
+
220
  # ═════════════════��════════════════════════════════════════════════════════════
221
  # 4a. Reconstruct from pretrained model
222
  # ══════════════════════════════════════════════════════════════════════════════
 
237
  f"📊 Intensity: [{img_min:.3f}, {img_max:.3f}] | "
238
  f"🧠 Slice {slice_idx} | 📡 Volume {vol_idx}"
239
  )
240
+ html = make_zoom_html(to_uint8(recon), f"Reconstructed — Slice {slice_idx}, Volume {vol_idx}")
241
+ return html, stats
242
 
243
  # ══════════════════════════════════════════════════════════════════════════════
244
  # 4b. Compress & reconstruct user-uploaded NIfTI
 
320
  f"📡 PSNR: {psnr:.2f} dB | "
321
  f"🔁 Iterations: {num_iters}"
322
  )
323
+ orig_html = make_zoom_html(orig_img, f"Original — Slice {slice_idx}, Volume {vol_idx}")
324
+ recon_html = make_zoom_html(recon_img, f"SIREN Reconstruction — Slice {slice_idx}, Volume {vol_idx}")
325
+ return orig_html, recon_html, stats
326
 
327
  # ══════════════════════════════════════════════════════════════════════════════
328
  # 5. Gradio UI
 
499
  """)
500
 
501
  with gr.Column(scale=2):
502
+ out1 = gr.HTML(label="Reconstructed Slice", elem_id="recon_img")
 
503
 
504
  btn1.click(reconstruct_pretrained,
505
  inputs=[sl1, vl1],
 
533
 
534
  with gr.Column(scale=2):
535
  with gr.Row():
536
+ orig_img = gr.HTML(label="📷 Original Slice")
537
+ recon_img = gr.HTML(label="🤖 SIREN Reconstruction")
 
 
538
 
539
  btn2.click(compress_and_compare,
540
  inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],