Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,9 @@ import torch.nn as nn
|
|
| 8 |
from math import sqrt
|
| 9 |
import gradio as gr
|
| 10 |
import nibabel as nib
|
|
|
|
|
|
|
|
|
|
| 11 |
from sklearn.preprocessing import MinMaxScaler
|
| 12 |
|
| 13 |
# ══════════════════════════════════════════════════════════════════════════════
|
|
@@ -104,6 +107,116 @@ def to_coords(h, w):
|
|
| 104 |
gx, gy = torch.meshgrid(xs, ys, indexing="ij")
|
| 105 |
return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# ═════════════════��════════════════════════════════════════════════════════════
|
| 108 |
# 4a. Reconstruct from pretrained model
|
| 109 |
# ══════════════════════════════════════════════════════════════════════════════
|
|
@@ -124,7 +237,8 @@ def reconstruct_pretrained(slice_idx, vol_idx):
|
|
| 124 |
f"📊 Intensity: [{img_min:.3f}, {img_max:.3f}] | "
|
| 125 |
f"🧠 Slice {slice_idx} | 📡 Volume {vol_idx}"
|
| 126 |
)
|
| 127 |
-
|
|
|
|
| 128 |
|
| 129 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 130 |
# 4b. Compress & reconstruct user-uploaded NIfTI
|
|
@@ -206,7 +320,9 @@ def compress_and_compare(nifti_file, slice_idx, vol_idx, num_iters, lr):
|
|
| 206 |
f"📡 PSNR: {psnr:.2f} dB | "
|
| 207 |
f"🔁 Iterations: {num_iters}"
|
| 208 |
)
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
|
| 211 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 212 |
# 5. Gradio UI
|
|
@@ -383,8 +499,7 @@ Adjust the sliders and click **Reconstruct** to visualise any slice and volume.
|
|
| 383 |
""")
|
| 384 |
|
| 385 |
with gr.Column(scale=2):
|
| 386 |
-
out1 = gr.
|
| 387 |
-
elem_id="recon_img", height=420)
|
| 388 |
|
| 389 |
btn1.click(reconstruct_pretrained,
|
| 390 |
inputs=[sl1, vl1],
|
|
@@ -418,10 +533,8 @@ The app will fit a SIREN network to the selected slice on-the-fly and show you
|
|
| 418 |
|
| 419 |
with gr.Column(scale=2):
|
| 420 |
with gr.Row():
|
| 421 |
-
orig_img = gr.
|
| 422 |
-
|
| 423 |
-
recon_img = gr.Image(label="🤖 SIREN Reconstruction",
|
| 424 |
-
type="numpy", height=380)
|
| 425 |
|
| 426 |
btn2.click(compress_and_compare,
|
| 427 |
inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],
|
|
|
|
| 8 |
from math import sqrt
|
| 9 |
import gradio as gr
|
| 10 |
import nibabel as nib
|
| 11 |
+
import base64
|
| 12 |
+
import io
|
| 13 |
+
from PIL import Image
|
| 14 |
from sklearn.preprocessing import MinMaxScaler
|
| 15 |
|
| 16 |
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
| 107 |
gx, gy = torch.meshgrid(xs, ys, indexing="ij")
|
| 108 |
return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)
|
| 109 |
|
| 110 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 111 |
+
# Helper: build zoomable image HTML component
|
| 112 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 113 |
+
|
| 114 |
+
def make_zoom_html(arr_uint8, title=""):
|
| 115 |
+
"""Convert a uint8 numpy array to a self-contained zoomable HTML viewer."""
|
| 116 |
+
pil_img = Image.fromarray(arr_uint8)
|
| 117 |
+
# upscale small images so they look crisp
|
| 118 |
+
w, h = pil_img.size
|
| 119 |
+
scale = max(1, 400 // max(w, h))
|
| 120 |
+
pil_img = pil_img.resize((w * scale, h * scale), Image.NEAREST)
|
| 121 |
+
buf = io.BytesIO()
|
| 122 |
+
pil_img.save(buf, format="PNG")
|
| 123 |
+
b64 = base64.b64encode(buf.getvalue()).decode()
|
| 124 |
+
html = f"""
|
| 125 |
+
<div style="background:#f8f9ff;border:1.5px solid #ddd6fe;border-radius:14px;
|
| 126 |
+
padding:12px;user-select:none;">
|
| 127 |
+
<div style="font-weight:800;color:#4c1d95;margin-bottom:8px;font-size:.95rem;">
|
| 128 |
+
🔍 {title} <span style="font-weight:500;color:#6b7280;font-size:.8rem;">
|
| 129 |
+
Scroll to zoom · Drag to pan · Double-click to reset</span>
|
| 130 |
+
</div>
|
| 131 |
+
<div id="zoom-wrap-{hash(b64) & 0xffff}"
|
| 132 |
+
style="overflow:hidden;border-radius:10px;background:#000;
|
| 133 |
+
width:100%;height:420px;cursor:grab;position:relative;">
|
| 134 |
+
<img id="zoom-img-{hash(b64) & 0xffff}"
|
| 135 |
+
src="data:image/png;base64,{b64}"
|
| 136 |
+
style="transform-origin:0 0;transform:scale(1) translate(0px,0px);
|
| 137 |
+
image-rendering:pixelated;max-width:none;
|
| 138 |
+
width:100%;height:100%;object-fit:contain;display:block;"
|
| 139 |
+
draggable="false"/>
|
| 140 |
+
</div>
|
| 141 |
+
</div>
|
| 142 |
+
<script>
|
| 143 |
+
(function() {{
|
| 144 |
+
const wid = '{hash(b64) & 0xffff}';
|
| 145 |
+
const wrap = document.getElementById('zoom-wrap-' + wid);
|
| 146 |
+
const img = document.getElementById('zoom-img-' + wid);
|
| 147 |
+
if (!wrap || !img) return;
|
| 148 |
+
|
| 149 |
+
let scale = 1, ox = 0, oy = 0;
|
| 150 |
+
let dragging = false, startX, startY, lastOx, lastOy;
|
| 151 |
+
const MIN = 0.5, MAX = 12;
|
| 152 |
+
|
| 153 |
+
function apply() {{
|
| 154 |
+
img.style.transform = `scale(${{scale}}) translate(${{ox}}px,${{oy}}px)`;
|
| 155 |
+
}}
|
| 156 |
+
|
| 157 |
+
// Scroll to zoom
|
| 158 |
+
wrap.addEventListener('wheel', e => {{
|
| 159 |
+
e.preventDefault();
|
| 160 |
+
const rect = wrap.getBoundingClientRect();
|
| 161 |
+
const mx = e.clientX - rect.left;
|
| 162 |
+
const my = e.clientY - rect.top;
|
| 163 |
+
const factor = e.deltaY < 0 ? 1.12 : 0.89;
|
| 164 |
+
const newScale = Math.min(MAX, Math.max(MIN, scale * factor));
|
| 165 |
+
ox = mx / newScale - mx / scale + ox;
|
| 166 |
+
oy = my / newScale - my / scale + oy;
|
| 167 |
+
scale = newScale;
|
| 168 |
+
apply();
|
| 169 |
+
}}, {{ passive: false }});
|
| 170 |
+
|
| 171 |
+
// Drag to pan
|
| 172 |
+
wrap.addEventListener('mousedown', e => {{
|
| 173 |
+
dragging = true; wrap.style.cursor = 'grabbing';
|
| 174 |
+
startX = e.clientX; startY = e.clientY;
|
| 175 |
+
lastOx = ox; lastOy = oy;
|
| 176 |
+
}});
|
| 177 |
+
window.addEventListener('mousemove', e => {{
|
| 178 |
+
if (!dragging) return;
|
| 179 |
+
ox = lastOx + (e.clientX - startX) / scale;
|
| 180 |
+
oy = lastOy + (e.clientY - startY) / scale;
|
| 181 |
+
apply();
|
| 182 |
+
}});
|
| 183 |
+
window.addEventListener('mouseup', () => {{
|
| 184 |
+
dragging = false; wrap.style.cursor = 'grab';
|
| 185 |
+
}});
|
| 186 |
+
|
| 187 |
+
// Double-click to reset
|
| 188 |
+
wrap.addEventListener('dblclick', () => {{
|
| 189 |
+
scale = 1; ox = 0; oy = 0; apply();
|
| 190 |
+
}});
|
| 191 |
+
|
| 192 |
+
// Touch support
|
| 193 |
+
let lastDist = null;
|
| 194 |
+
wrap.addEventListener('touchstart', e => {{
|
| 195 |
+
if (e.touches.length === 1) {{
|
| 196 |
+
dragging = true;
|
| 197 |
+
startX = e.touches[0].clientX; startY = e.touches[0].clientY;
|
| 198 |
+
lastOx = ox; lastOy = oy;
|
| 199 |
+
}}
|
| 200 |
+
}}, {{ passive: true }});
|
| 201 |
+
wrap.addEventListener('touchmove', e => {{
|
| 202 |
+
if (e.touches.length === 2) {{
|
| 203 |
+
const d = Math.hypot(
|
| 204 |
+
e.touches[0].clientX - e.touches[1].clientX,
|
| 205 |
+
e.touches[0].clientY - e.touches[1].clientY);
|
| 206 |
+
if (lastDist) {{ scale = Math.min(MAX, Math.max(MIN, scale * d / lastDist)); apply(); }}
|
| 207 |
+
lastDist = d;
|
| 208 |
+
}} else if (e.touches.length === 1 && dragging) {{
|
| 209 |
+
ox = lastOx + (e.touches[0].clientX - startX) / scale;
|
| 210 |
+
oy = lastOy + (e.touches[0].clientY - startY) / scale;
|
| 211 |
+
apply();
|
| 212 |
+
}}
|
| 213 |
+
}}, {{ passive: true }});
|
| 214 |
+
wrap.addEventListener('touchend', () => {{ dragging = false; lastDist = null; }});
|
| 215 |
+
}})();
|
| 216 |
+
</script>
|
| 217 |
+
"""
|
| 218 |
+
return html
|
| 219 |
+
|
| 220 |
# ═════════════════��════════════════════════════════════════════════════════════
|
| 221 |
# 4a. Reconstruct from pretrained model
|
| 222 |
# ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
| 237 |
f"📊 Intensity: [{img_min:.3f}, {img_max:.3f}] | "
|
| 238 |
f"🧠 Slice {slice_idx} | 📡 Volume {vol_idx}"
|
| 239 |
)
|
| 240 |
+
html = make_zoom_html(to_uint8(recon), f"Reconstructed — Slice {slice_idx}, Volume {vol_idx}")
|
| 241 |
+
return html, stats
|
| 242 |
|
| 243 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 244 |
# 4b. Compress & reconstruct user-uploaded NIfTI
|
|
|
|
| 320 |
f"📡 PSNR: {psnr:.2f} dB | "
|
| 321 |
f"🔁 Iterations: {num_iters}"
|
| 322 |
)
|
| 323 |
+
orig_html = make_zoom_html(orig_img, f"Original — Slice {slice_idx}, Volume {vol_idx}")
|
| 324 |
+
recon_html = make_zoom_html(recon_img, f"SIREN Reconstruction — Slice {slice_idx}, Volume {vol_idx}")
|
| 325 |
+
return orig_html, recon_html, stats
|
| 326 |
|
| 327 |
# ══════════════════════════════════════════════════════════════════════════════
|
| 328 |
# 5. Gradio UI
|
|
|
|
| 499 |
""")
|
| 500 |
|
| 501 |
with gr.Column(scale=2):
|
| 502 |
+
out1 = gr.HTML(label="Reconstructed Slice", elem_id="recon_img")
|
|
|
|
| 503 |
|
| 504 |
btn1.click(reconstruct_pretrained,
|
| 505 |
inputs=[sl1, vl1],
|
|
|
|
| 533 |
|
| 534 |
with gr.Column(scale=2):
|
| 535 |
with gr.Row():
|
| 536 |
+
orig_img = gr.HTML(label="📷 Original Slice")
|
| 537 |
+
recon_img = gr.HTML(label="🤖 SIREN Reconstruction")
|
|
|
|
|
|
|
| 538 |
|
| 539 |
btn2.click(compress_and_compare,
|
| 540 |
inputs=[nifti_upload, sl2, vl2, n_iters, lr_inp],
|