# ct-windowing-quantization-demo / src/streamlit_app.py
# HuggingKatze's picture
# Restore local app and sync recovery version
# 3680186
# Raw
# History Blame Contribute Delete
# 22.7 kB
from __future__ import annotations
from io import BytesIO
from pathlib import Path
import base64
import tempfile
from typing import Any
import numpy as np
from PIL import Image, UnidentifiedImageError
import streamlit as st
import streamlit.components.v1 as components
try:
import nibabel as nib
except Exception: # pragma: no cover
nib = None
try:
import pydicom
except Exception: # pragma: no cover
pydicom = None
# Page title, used both for the browser tab and the main heading.
APP_TITLE = "Medical Quantization and CT Windowing Explorer"
# Public URL quoted in the "Deployment note" expander.
SPACE_URL_PLACEHOLDER = "https://huggingface.co/spaces/HuggingKatze/ct-windowing-quantization-demo"
# Bundled samples live under <parent of this file's directory>/sample_images/builtin_samples.
SAMPLE_DIR = Path(__file__).resolve().parents[1] / "sample_images" / "builtin_samples"
# Selectable image pairs. The first entry is the default comparison, and every
# label must also exist in metric_pair() and in the SVG connector list.
COMPARE_OPTIONS = [
"Original vs Windowed",
"Original vs Quantized",
"Windowed vs Quantized + Windowed",
"Quantized vs Quantized + Windowed",
"Original vs Quantized + Windowed",
"Windowed vs Quantized",
]
# Preset name -> (lower bound, upper bound) of the display window.
# "Custom" maps to None, meaning the bounds come from the sample or the image.
# NOTE(review): bounds are presumably Hounsfield units - confirm against the sample data.
WINDOW_PRESETS = {
"Custom": None,
"Lung": (-1024.0, 150.0),
"Mediastinum": (-160.0, 240.0),
"Bone": (300.0, 2000.0),
"Soft tissue": (-125.0, 225.0),
}
# Display label -> path of the bundled NPZ archive.
BUILTIN_SAMPLES = {
"CT-RATE chest CT volume": SAMPLE_DIR / "ct_rate_valid_100_a_1.npz",
"LDCT full-dose chest CT": SAMPLE_DIR / "ldct_full_dose_1-135.npz",
"RSNA PE Detection CTPA slice": SAMPLE_DIR / "rsna_pe_7bf959bb5c7e.npz",
}
# Internal image key -> human-readable card title.
TITLE_MAP = {
"original": "Original",
"windowed": "Windowed",
"quantized": "Quantized",
"quantized_windowed": "Quantized + Windowed",
}
# SVG layout, all in viewBox units: top-left corner of each card column / row,
# the size of the image drawn inside a card, the overall viewBox size, and the
# gap left between a card edge and the end of a connector arrow.
CARD_X = {"left": 70, "right": 560}
CARD_Y = {"top": 40, "bottom": 560}
CARD_W = 360
CARD_H = 360
SVG_W = 1120
SVG_H = 975
ARROW_MARGIN = 26
def safe_float(value: Any, default: float = 0.0) -> float:
    """Convert *value* to ``float``, returning *default* when conversion fails.

    Any exception raised by ``float(value)`` is swallowed on purpose so that
    malformed metadata never aborts loading.
    """
    try:
        converted = float(value)
    except Exception:
        converted = default
    return converted
def safe_int(value: Any, default: int = 0) -> int:
    """Convert *value* to ``int``, returning *default* when conversion fails.

    Any exception raised by ``int(value)`` is swallowed on purpose so that
    malformed metadata never aborts loading.
    """
    try:
        converted = int(value)
    except Exception:
        converted = default
    return converted
def normalize_01(image: np.ndarray) -> np.ndarray:
    """Rescale *image* linearly so its minimum maps to 0 and its maximum to 1.

    Returns a float32 array. Two degenerate inputs get a fixed answer:
    an empty array yields a 16x16 block of zeros, and a constant array
    yields zeros with the input's shape.
    """
    values = image.astype(np.float32)
    if not values.size:
        return np.zeros((16, 16), dtype=np.float32)

    smallest = float(values.min())
    largest = float(values.max())
    if largest <= smallest:
        return np.zeros_like(values, dtype=np.float32)

    span = largest - smallest
    return (values - smallest) / span
def resize_for_display(image: np.ndarray, max_side: int = 320) -> np.ndarray:
    """Downscale a 2D image so neither side exceeds *max_side* pixels.

    Images that already fit are returned unchanged (as a float32 copy).
    Larger images are shrunk with bilinear interpolation while keeping the
    aspect ratio.

    The resize is done on a 32-bit float Pillow image (mode "F"), so the
    original intensity values and dynamic range are preserved. Going through
    an 8-bit image instead would collapse the data to 256 levels before the
    quantization step ever runs.

    Args:
        image: 2D array of intensities.
        max_side: Maximum allowed height/width of the result.

    Returns:
        A float32 array in the same value range as the input.

    Raises:
        ValueError: If *image* is not two-dimensional.
    """
    arr = image.astype(np.float32)
    if arr.ndim != 2:
        raise ValueError(f"Expected a 2D image, got an array with shape {arr.shape}.")

    h, w = arr.shape
    # max(..., 1) guards against division by zero for empty axes.
    scale = min(max_side / max(h, 1), max_side / max(w, 1), 1.0)
    if scale >= 1.0:
        return arr

    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    # A contiguous float32 2D array becomes a mode "F" image.
    pil = Image.fromarray(np.ascontiguousarray(arr))
    resized = pil.resize((new_w, new_h), Image.Resampling.BILINEAR)
    # np.array (not asarray) so callers always get a writable copy.
    return np.array(resized, dtype=np.float32)
def to_png_base64(image: np.ndarray) -> str:
    """Encode *image* as an 8-bit grayscale PNG and return it as base64 text.

    The image is first min/max normalized, so the darkest value becomes 0
    and the brightest 255.
    """
    scaled = normalize_01(image) * 255
    pixels = scaled.clip(0, 255).astype(np.uint8)

    buffer = BytesIO()
    Image.fromarray(pixels, mode="L").save(buffer, format="PNG")

    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("ascii")
def quantize_image(image: np.ndarray, bit_depth: int) -> np.ndarray:
    """Quantize *image* to ``2 ** bit_depth`` evenly spaced levels.

    The levels span the image's own minimum..maximum, so the result stays in
    the input's value range. At least two levels are always used.

    Args:
        image: Array of intensities (any shape).
        bit_depth: Number of bits; the level count is ``max(2, 2 ** bit_depth)``.

    Returns:
        A float32 array with the same shape as *image*. Empty and constant
        inputs are returned unchanged (as a float32 copy).
    """
    arr = image.astype(np.float32)
    # np.min / np.max raise on empty input, so handle it like the other helpers.
    if arr.size == 0:
        return arr

    lo = float(np.min(arr))
    hi = float(np.max(arr))
    if hi <= lo:
        return arr.copy()

    steps = max(2, 2 ** int(bit_depth)) - 1
    norm = (arr - lo) / (hi - lo)
    quantized = np.round(norm * steps) / steps
    return quantized * (hi - lo) + lo
def apply_window(image: np.ndarray, low: float, high: float) -> np.ndarray:
    """Clip *image* to ``[low, high]`` and map that interval onto ``[0, 1]``.

    A degenerate window (``high <= low``) yields an all-zero float32 array
    with the input's shape.
    """
    values = image.astype(np.float32)
    if high <= low:
        return np.zeros_like(values, dtype=np.float32)

    width = high - low
    return (np.clip(values, low, high) - low) / width
def compute_entropy(image: np.ndarray, bins: int = 256) -> float:
    """Return the Shannon entropy (in bits) of the image's intensity histogram.

    The histogram uses *bins* equal-width bins between the image minimum and
    maximum. Empty or constant images have an entropy of 0.0.
    """
    values = image.astype(np.float32)
    if not values.size:
        return 0.0

    smallest, largest = float(values.min()), float(values.max())
    if largest <= smallest:
        return 0.0

    counts, _edges = np.histogram(values, bins=bins, range=(smallest, largest))
    counts = counts.astype(np.float64)
    # Drop empty bins: they contribute nothing and log2(0) is undefined.
    occupied = (counts / np.sum(counts))[counts > 0]
    return float(-(occupied * np.log2(occupied)).sum())
def gray_levels(image: np.ndarray) -> int:
    """Count the distinct gray values after mapping *image* to 8 bits.

    The image is min/max normalized and rounded to 0..255 first, so the
    result is at most 256. An empty image has 0 levels.
    """
    values = image.astype(np.float32)
    if not values.size:
        return 0

    eight_bit = np.round(normalize_01(values) * 255).astype(np.uint8)
    distinct = np.unique(eight_bit)
    return int(distinct.size)
def psnr(left: np.ndarray, right: np.ndarray, data_range: float = 1.0) -> float:
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        left: First image.
        right: Second image, broadcast-compatible with *left*.
        data_range: Peak signal value, i.e. the span of valid intensities.
            The default of 1.0 matches images normalized to [0, 1]; pass the
            real span for images in another range.

    Returns:
        ``20 * log10(data_range / sqrt(MSE))``, or ``inf`` when the images
        are (numerically) identical.

    Raises:
        ValueError: If *data_range* is not positive.
    """
    if data_range <= 0:
        raise ValueError("data_range must be positive.")

    mse = float(np.mean((left.astype(np.float32) - right.astype(np.float32)) ** 2))
    # Treat a vanishing error as identical images to avoid dividing by ~0.
    if mse <= 1e-12:
        return float("inf")
    return float(20.0 * np.log10(data_range / np.sqrt(mse)))
def approx_ssim(left: np.ndarray, right: np.ndarray, data_range: float = 1.0) -> float:
    """Return a single-window approximation of the structural similarity index.

    Unlike the usual sliding-window SSIM, the means, variances and covariance
    are taken over the whole image at once, so the result is a coarse global
    score rather than a locally averaged one.

    Args:
        left: First image.
        right: Second image, same shape as *left*.
        data_range: Span of valid intensities, used to scale the stabilising
            constants. The default of 1.0 matches images normalized to [0, 1].

    Returns:
        The global SSIM value; 1.0 for identical images.

    Raises:
        ValueError: If *data_range* is not positive.
    """
    if data_range <= 0:
        raise ValueError("data_range must be positive.")

    x = left.astype(np.float64)
    y = right.astype(np.float64)
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    mx = float(np.mean(x))
    my = float(np.mean(y))
    vx = float(np.var(x))
    vy = float(np.var(y))
    cxy = float(np.mean((x - mx) * (y - my)))

    # Both factors are bounded below by the positive constants c1 and c2,
    # so the denominator can never be zero and needs no special case.
    denom = (mx * mx + my * my + c1) * (vx + vy + c2)
    return float(((2 * mx * my + c1) * (2 * cxy + c2)) / denom)
def compare_metrics(left: np.ndarray, right: np.ndarray) -> list[tuple[str, str]]:
    """Build the ``(name, formatted value)`` rows shown in the metrics panel.

    Error metrics (MSE, MAE, RMSE, PSNR, SSIM) describe the difference between
    the two images. "Gray levels" and "Entropy" report one value per image,
    formatted as ``left / right``.
    """
    delta = left.astype(np.float32) - right.astype(np.float32)
    mean_sq = float(np.mean(delta * delta))
    mean_abs = float(np.mean(np.abs(delta)))
    root_mean_sq = float(np.sqrt(mean_sq))

    peak_ratio = psnr(left, right)
    if np.isinf(peak_ratio):
        peak_text = "inf"
    else:
        peak_text = f"{peak_ratio:.3f} dB"

    rows: list[tuple[str, str]] = []
    rows.append(("MSE", f"{mean_sq:.6f}"))
    rows.append(("MAE", f"{mean_abs:.6f}"))
    rows.append(("RMSE", f"{root_mean_sq:.6f}"))
    rows.append(("PSNR", peak_text))
    rows.append(("SSIM", f"{approx_ssim(left, right):.4f}"))
    rows.append(("Gray levels", f"{gray_levels(left)} / {gray_levels(right)}"))
    rows.append(("Entropy", f"{compute_entropy(left):.4f} / {compute_entropy(right):.4f}"))
    return rows
def parse_npz(source: Any) -> dict[str, Any]:
    """Read every array stored in an NPZ archive into a plain dict.

    Args:
        source: Anything ``numpy.load`` accepts for an ``.npz`` archive
            (path or binary file-like object).

    Returns:
        Mapping of array name to the fully loaded array.
    """
    # SECURITY(review): allow_pickle=True lets a crafted archive run arbitrary
    # code while object arrays are unpickled. This function is also used for
    # user uploads; switch to allow_pickle=False once the bundled samples are
    # confirmed to contain no object arrays.
    #
    # The context manager closes the underlying archive handle, which would
    # otherwise stay open until garbage collection. Arrays are materialized
    # inside the block because NpzFile members are read lazily.
    with np.load(source, allow_pickle=True) as archive:
        return {key: archive[key] for key in archive.files}
def parse_builtin_payload(payload: dict[str, Any], label: str) -> dict[str, Any]:
    """Turn the arrays of an NPZ archive into the app's sample record.

    Recognized keys: ``image`` (2D slice), ``volume`` (3D stack, slices along
    axis 0), ``original_bit_depth``, ``note``, ``suggested_window`` and
    ``default_slice_index``. A 3D ``volume`` takes precedence over ``image``.

    Args:
        payload: Mapping of array name to array, as returned by ``parse_npz``.
        label: Display name stored in the record.

    Returns:
        A dict with the keys ``label``, ``image``, ``volume``, ``slice_index``,
        ``bit_depth``, ``suggested_window`` and ``note``. ``volume`` and
        ``slice_index`` are ``None`` for single-image payloads.

    Raises:
        ValueError: If the payload has neither a usable 3D volume nor an
            image, or if the volume contains no slices.
    """
    image = payload.get("image")
    volume = payload.get("volume")
    bit_depth = int(np.atleast_1d(payload.get("original_bit_depth", [16]))[0])
    note = str(np.atleast_1d(payload.get("note", [""]))[0])

    suggested_window = None
    if "suggested_window" in payload:
        # Flatten first so both shape (2,) and (1, 2) work; anything with
        # fewer than two values is treated as "no suggestion".
        bounds = np.ravel(payload["suggested_window"])[:2]
        if bounds.size == 2:
            suggested_window = (float(bounds[0]), float(bounds[1]))
    elif image is not None:
        suggested_window = (-1024.0, 150.0)

    if volume is not None and getattr(volume, "ndim", 0) == 3:
        depth = volume.shape[0]
        if depth == 0:
            raise ValueError("Sample volume does not contain any slices.")
        requested = int(np.atleast_1d(payload.get("default_slice_index", [depth // 2]))[0])
        # Clamp so a stale or hostile index cannot raise IndexError or
        # silently wrap around to the other end of the stack.
        slice_index = min(max(requested, 0), depth - 1)
        return {
            "label": label,
            "image": volume[slice_index].astype(np.float32),
            "volume": volume.astype(np.float32),
            "slice_index": slice_index,
            "bit_depth": bit_depth,
            "suggested_window": suggested_window,
            "note": note,
        }

    if image is None:
        raise ValueError("Sample file does not contain a readable image array.")
    return {
        "label": label,
        "image": image.astype(np.float32),
        "volume": None,
        "slice_index": None,
        "bit_depth": bit_depth,
        "suggested_window": suggested_window,
        "note": note,
    }
def load_builtin_sample(label: str) -> dict[str, Any]:
    """Load the bundled sample registered under *label*.

    Raises:
        ValueError: If the label is unknown or its NPZ file is missing.
    """
    sample_path = BUILTIN_SAMPLES.get(label)
    if sample_path is not None and sample_path.exists():
        arrays = parse_npz(sample_path)
        return parse_builtin_payload(arrays, label)
    raise ValueError(f"Built-in sample not found: {label}")
def load_nifti(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Decode an uploaded NIfTI file into the app's sample record.

    The bytes are written to a temporary file (keeping the upload's
    extension) because the volume is loaded from a path. The file is removed
    again afterwards on a best-effort basis.

    The shortest axis of the volume is moved to the front and used as the
    slice axis; the middle slice becomes the initial image.

    Raises:
        ValueError: If nibabel is unavailable or the data is not a non-empty
            3D volume.
    """
    if nib is None:
        raise ValueError("NIfTI upload requires nibabel, which is not installed.")

    extension = "".join(Path(name).suffixes)
    with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as handle:
        handle.write(raw_bytes)
        scratch = Path(handle.name)

    try:
        loaded = nib.load(str(scratch))
        volume = np.asarray(loaded.get_fdata(dtype=np.float32))
    finally:
        # Cleanup is best-effort: a leftover temp file must not mask the
        # result or the real error.
        try:
            scratch.unlink(missing_ok=True)
        except Exception:
            pass

    if volume.ndim != 3 or 0 in volume.shape:
        raise ValueError("Expected a non-empty 3D NIfTI volume.")

    thinnest_axis = int(np.argmin(volume.shape))
    if thinnest_axis != 0:
        volume = np.moveaxis(volume, thinnest_axis, 0)

    middle = volume.shape[0] // 2
    record: dict[str, Any] = {
        "label": name,
        "image": volume[middle].astype(np.float32),
        "volume": volume.astype(np.float32),
        "slice_index": middle,
        "bit_depth": 16,
        "suggested_window": (-1024.0, 150.0),
        "note": "Uploaded NIfTI volume.",
    }
    return record
def load_dicom(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Decode an uploaded DICOM file into the app's sample record.

    Stored pixel values are converted with the file's RescaleSlope and
    RescaleIntercept (defaulting to 1 and 0). The bit depth comes from
    BitsStored, then BitsAllocated, then 16.

    Raises:
        ValueError: If pydicom is unavailable or the file has no pixel data.
    """
    if pydicom is None:
        raise ValueError("DICOM upload requires pydicom, which is not installed.")

    dataset = pydicom.dcmread(BytesIO(raw_bytes), force=True)
    if not hasattr(dataset, "PixelData"):
        raise ValueError("DICOM file has no readable pixel data.")

    stored = dataset.pixel_array.astype(np.float32)
    slope = safe_float(getattr(dataset, "RescaleSlope", 1.0), 1.0)
    intercept = safe_float(getattr(dataset, "RescaleIntercept", 0.0), 0.0)
    rescaled = stored * slope + intercept

    allocated_bits = getattr(dataset, "BitsAllocated", 16)
    stored_bits = getattr(dataset, "BitsStored", allocated_bits)

    record: dict[str, Any] = {
        "label": name,
        "image": rescaled,
        "volume": None,
        "slice_index": None,
        "bit_depth": safe_int(stored_bits, 16),
        "suggested_window": (-1024.0, 150.0),
        "note": "Uploaded DICOM slice.",
    }
    return record
def load_regular_image(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Decode an ordinary image upload as 8-bit grayscale.

    The suggested window is the image's own min..max, or (0.0, 1.0) for an
    empty image.

    Raises:
        ValueError: If the bytes are not a recognizable image.
    """
    try:
        grayscale = Image.open(BytesIO(raw_bytes)).convert("L")
    except UnidentifiedImageError as exc:
        raise ValueError("Could not read this file. Please upload PNG, JPG, DICOM, NIfTI, or NPZ.") from exc

    pixels = np.asarray(grayscale).astype(np.float32)
    if pixels.size:
        window = (float(np.min(pixels)), float(np.max(pixels)))
    else:
        window = (0.0, 1.0)

    record: dict[str, Any] = {
        "label": name,
        "image": pixels,
        "volume": None,
        "slice_index": None,
        "bit_depth": 8,
        "suggested_window": window,
        "note": "Uploaded grayscale image.",
    }
    return record
def load_uploaded_file(uploaded) -> dict[str, Any]:
    """Dispatch an uploaded file to the loader matching its extension.

    ``.npz`` -> NPZ payload, ``.nii`` / ``.nii.gz`` -> NIfTI, ``.dcm`` ->
    DICOM; everything else is tried as a regular image.

    Raises:
        ValueError: If nothing was uploaded or the file is empty (loaders may
            raise further errors).
    """
    if uploaded is None:
        raise ValueError("No file was uploaded.")

    payload = uploaded.getvalue()
    if not payload:
        raise ValueError("Uploaded file is empty.")

    filename = uploaded.name
    # Join all suffixes so compound extensions such as ".nii.gz" are seen.
    extension = "".join(Path(filename).suffixes).lower()

    if extension.endswith(".npz"):
        return parse_builtin_payload(parse_npz(BytesIO(payload)), filename)
    if extension.endswith((".nii", ".nii.gz")):
        return load_nifti(payload, filename)
    if extension.endswith(".dcm"):
        return load_dicom(payload, filename)
    return load_regular_image(payload, filename)
def metric_pair(images: dict[str, np.ndarray], compare: str) -> tuple[np.ndarray, np.ndarray, str, str]:
    """Resolve a comparison label to its two images and their display titles.

    Returns:
        ``(left image, right image, left title, right title)``.

    Raises:
        KeyError: If *compare* is not one of the known comparison labels.
    """
    pair_keys = dict(
        [
            ("Original vs Windowed", ("original", "windowed")),
            ("Original vs Quantized", ("original", "quantized")),
            ("Windowed vs Quantized + Windowed", ("windowed", "quantized_windowed")),
            ("Quantized vs Quantized + Windowed", ("quantized", "quantized_windowed")),
            ("Original vs Quantized + Windowed", ("original", "quantized_windowed")),
            ("Windowed vs Quantized", ("windowed", "quantized")),
        ]
    )
    first, second = pair_keys[compare]
    first_image = images[first]
    second_image = images[second]
    return first_image, second_image, TITLE_MAP[first], TITLE_MAP[second]
def render_visualization_svg(
    images: dict[str, np.ndarray],
    bit_depth: int,
    low: float,
    high: float,
    selected_compare: str,
) -> str:
    """Build the SVG overview: four image cards joined by clickable arrows.

    The cards form a 2x2 grid (original / windowed on top, quantized /
    quantized + windowed below). Every card embeds its image as a base64 PNG
    and shows the current bit depth and window. Each comparison in the UI has
    one double-headed connector; the one equal to *selected_compare* is drawn
    in red, the others in gray.

    Clicking a connector calls the embedded ``setCompare`` script, which
    writes the label into the ``compare`` query parameter of the top-level
    window URL and navigates there.

    Args:
        images: Arrays keyed by "original", "windowed", "quantized" and
            "quantized_windowed".
        bit_depth: Bit depth shown in the card captions.
        low: Lower window bound shown in the card captions.
        high: Upper window bound shown in the card captions.
        selected_compare: Label of the connector to highlight.

    Returns:
        The complete ``<svg>`` document as a single string.
    """
    # The markup below is kept flush-left because its exact text is emitted.
    svg_defs = """
<defs>
<marker id="arrow-gray" markerWidth="10" markerHeight="10" refX="8" refY="5" orient="auto-start-reverse">
<path d="M0,0 L10,5 L0,10 z" fill="#A4AFBA"></path>
</marker>
<marker id="arrow-red" markerWidth="10" markerHeight="10" refX="8" refY="5" orient="auto-start-reverse">
<path d="M0,0 L10,5 L0,10 z" fill="#C4473A"></path>
</marker>
<style>
.edge-hit { cursor: pointer; }
.title { font-size: 18px; font-weight: 600; fill: #22313F; font-family: sans-serif; }
.meta { font-size: 17px; fill: #22313F; font-family: sans-serif; }
</style>
<script>
function setCompare(value) {
const url = new URL(window.top.location.href);
url.searchParams.set("compare", value);
window.top.location.href = url.toString();
}
</script>
</defs>
"""
    left_x, right_x = CARD_X["left"], CARD_X["right"]
    top_y, bottom_y = CARD_Y["top"], CARD_Y["bottom"]

    # (image key, card x, card y, baseline y of the card title), in draw order.
    card_layout = [
        ("original", left_x, top_y, 432),
        ("windowed", right_x, top_y, 432),
        ("quantized", left_x, bottom_y, 952),
        ("quantized_windowed", right_x, bottom_y, 952),
    ]

    # Connector anchor coordinates: the inner edges of the gap between the
    # card columns / rows, and the centre lines of each column / row.
    gap_left = left_x + CARD_W + ARROW_MARGIN
    gap_right = right_x - ARROW_MARGIN
    below_top = top_y + CARD_H + ARROW_MARGIN
    above_bottom = bottom_y - ARROW_MARGIN
    top_mid_y = top_y + CARD_H / 2
    bottom_mid_y = bottom_y + CARD_H / 2
    left_mid_x = left_x + CARD_W / 2
    right_mid_x = right_x + CARD_W / 2

    connectors = [
        ("Original vs Windowed", (gap_left, top_mid_y), (gap_right, top_mid_y)),
        ("Original vs Quantized", (left_mid_x, below_top), (left_mid_x, above_bottom)),
        ("Windowed vs Quantized + Windowed", (right_mid_x, below_top), (right_mid_x, above_bottom)),
        ("Quantized vs Quantized + Windowed", (gap_left, bottom_mid_y), (gap_right, bottom_mid_y)),
        ("Original vs Quantized + Windowed", (gap_left, below_top), (gap_right, above_bottom)),
        ("Windowed vs Quantized", (gap_right, below_top), (gap_left, above_bottom)),
    ]

    chunks = [f'<svg xmlns="http://www.w3.org/2000/svg" width="100%" viewBox="0 0 {SVG_W} {SVG_H}">', svg_defs]

    for key, x, y, ty in card_layout:
        png = to_png_base64(images[key])
        chunks.extend(
            [
                f'<rect x="{x - 8}" y="{y - 8}" width="376" height="448" rx="18" fill="#FFFFFF" stroke="#D7E0EA" stroke-width="2"/>',
                f'<image x="{x}" y="{y}" width="{CARD_W}" height="{CARD_H}" href="data:image/png;base64,{png}" preserveAspectRatio="xMidYMid meet"/>',
                f'<text class="title" x="{x}" y="{ty}">{TITLE_MAP[key]}</text>',
                f'<text class="meta" x="{x}" y="{ty + 26}">Quantization: bit depth = {bit_depth}</text>',
                f'<text class="meta" x="{x}" y="{ty + 52}">Window: [{low:.1f}, {high:.1f}]</text>',
            ]
        )

    for label, (sx, sy), (ex, ey) in connectors:
        if label == selected_compare:
            color, marker = "#C4473A", "arrow-red"
        else:
            color, marker = "#A4AFBA", "arrow-gray"
        # Two lines per connector: a wide, almost transparent one (presumably
        # to enlarge the clickable area) and the thin visible arrow.
        chunks.append(
            f'<g class="edge-hit" onclick="setCompare(\'{label}\')">'
            f'<line x1="{sx}" y1="{sy}" x2="{ex}" y2="{ey}" stroke="{color}" stroke-width="18" stroke-opacity="0.01"/>'
            f'<line x1="{sx}" y1="{sy}" x2="{ex}" y2="{ey}" stroke="{color}" stroke-width="4" stroke-linecap="round" marker-start="url(#{marker})" marker-end="url(#{marker})" opacity="0.95"/>'
            f"</g>"
        )

    chunks.append("</svg>")
    return "".join(chunks)
# Streamlit entry point. Builds the control widgets, loads the selected or
# uploaded sample, prepares the four image variants (original, windowed,
# quantized, quantized + windowed), renders the SVG overview and prints the
# comparison metrics for the currently selected pair.
def main() -> None:
st.set_page_config(page_title=APP_TITLE, layout="wide")
st.title(APP_TITLE)
st.caption(
"Interactive demo for grayscale quantization and CT windowing. "
"Built-in samples are bundled with the repo, and you can also upload your own image, DICOM, NIfTI, or NPZ."
)
# The active comparison is carried in the `compare` query parameter (the SVG
# arrows write it); unknown values fall back to the first option.
current_compare = str(st.query_params.get("compare", COMPARE_OPTIONS[0]))
if current_compare not in COMPARE_OPTIONS:
current_compare = COMPARE_OPTIONS[0]
left, right = st.columns([1, 1.7], gap="large")
with left:
st.subheader("Controls")
st.caption("All interactive controls are annotated below. Invalid input should show a friendly error instead of crashing the app.")
input_source = st.radio(
"Input source",
["Built-in sample", "Upload file"],
help="Choose a bundled example from the repository, or upload your own grayscale image / medical scan file.",
)
data = None
# Any loader failure is shown via st.error and treated as "no data loaded".
try:
if input_source == "Built-in sample":
sample_label = st.selectbox(
"Built-in sample",
list(BUILTIN_SAMPLES.keys()),
help="Repository-safe sample assets for Hugging Face deployment.",
)
data = load_builtin_sample(sample_label)
else:
uploaded = st.file_uploader(
"Upload image / medical file",
type=["png", "jpg", "jpeg", "dcm", "nii", "gz", "npz"],
help="Supported formats: PNG, JPG, DICOM (.dcm), NIfTI (.nii/.nii.gz), and NPZ arrays.",
)
if uploaded is not None:
data = load_uploaded_file(uploaded)
except Exception as exc:
st.error(f"Failed to load input: {exc}")
data = None
# Nothing to show until a sample has been loaded successfully.
if data is None:
st.info("Choose a built-in sample or upload a file to begin.")
return
image = data["image"].astype(np.float32)
volume = data.get("volume")
# For 3D data the slider picks the slice along axis 0 and replaces `image`.
if volume is not None:
# NOTE(review): `or` treats a stored slice_index of 0 as missing and falls
# back to the middle slice - confirm this is intended.
default_slice = int(data.get("slice_index") or volume.shape[0] // 2)
slice_index = st.slider(
"Slice index",
min_value=0,
max_value=int(volume.shape[0] - 1),
value=int(np.clip(default_slice, 0, volume.shape[0] - 1)),
help="For 3D volumes, choose which axial slice is visualized.",
)
image = volume[slice_index].astype(np.float32)
if input_source == "Upload file":
rotation = st.select_slider(
"Rotate uploaded image",
options=[0, 90, 180, 270],
value=0,
help="Rotate only the displayed 2D slice, which is useful when checking uploaded files.",
)
# np.rot90 takes a count of quarter turns, hence the division by 90.
if rotation:
image = np.rot90(image, rotation // 90)
st.caption(f"Loaded source: {data['label']}")
if data.get("note"):
st.caption(str(data["note"]))
st.caption("Large images are automatically resized for responsive interaction.")
# The slider default is the sample's reported bit depth clamped to 1..16.
bit_depth = st.slider(
"Quantization bit depth",
min_value=1,
max_value=16,
value=min(max(int(data.get("bit_depth", 8)), 1), 16),
help="Lower bit depth reduces the number of available gray levels before display.",
)
# Preselect the preset whose bounds equal the sample's suggested window;
# otherwise start on "Custom".
preset_default = "Custom"
suggested_window = data.get("suggested_window")
if suggested_window is not None:
for name, bounds in WINDOW_PRESETS.items():
if bounds == tuple(map(float, suggested_window)):
preset_default = name
break
preset = st.selectbox(
"CT window preset",
list(WINDOW_PRESETS.keys()),
index=list(WINDOW_PRESETS.keys()).index(preset_default),
help="Choose a preset medical display window, or switch to Custom to edit the lower and upper bounds directly.",
)
# Default bounds: the chosen preset, else the suggested window, else the
# min/max of the current image.
if preset != "Custom":
low_default, high_default = WINDOW_PRESETS[preset]
elif suggested_window is not None:
low_default, high_default = map(float, suggested_window)
else:
low_default, high_default = float(np.min(image)), float(np.max(image))
lower = st.number_input(
"Window lower bound",
value=float(low_default),
help="Intensities below this value are clipped before the displayed normalization step.",
)
upper = st.number_input(
"Window upper bound",
value=float(high_default),
help="Intensities above this value are clipped before the displayed normalization step.",
)
# An empty or inverted window cannot be displayed; stop with a message.
if upper <= lower:
st.error("Window upper bound must be greater than lower bound.")
return
# All variants are derived from the display-sized image. Quantization and
# windowing are each applied to it, and windowing is also applied on top of
# the quantized image.
try:
original = resize_for_display(image)
quantized = quantize_image(original, bit_depth)
windowed = apply_window(original, float(lower), float(upper))
quantized_windowed = apply_window(quantized, float(lower), float(upper))
except Exception as exc:
st.error(f"Failed to prepare visualization: {exc}")
return
images = {
"original": original,
"windowed": windowed,
"quantized": quantized,
"quantized_windowed": quantized_windowed,
}
with right:
st.subheader("Visualization")
st.caption("Each double-arrow line is clickable. The selected relationship changes color and also updates the metric comparison below.")
svg = render_visualization_svg(images, bit_depth, float(lower), float(upper), current_compare)
components.html(svg, height=720, scrolling=False)
st.subheader("Metrics")
compare = st.selectbox(
"Compare",
COMPARE_OPTIONS,
index=COMPARE_OPTIONS.index(current_compare),
help="Choose which transformation pair to compare numerically. This stays synchronized with the selected arrow relation.",
)
# A change in the select box is written back to the query string, then the
# script reruns so the highlighted SVG arrow matches the selection.
if compare != current_compare:
st.query_params["compare"] = compare
st.rerun()
left_img, right_img, left_name, right_name = metric_pair(images, compare)
st.caption(f"Current comparison: {left_name} vs {right_name}")
metrics = compare_metrics(left_img, right_img)
col_a, col_b = st.columns(2)
# Split the rows over two columns; the first one gets the extra row when the
# number of metrics is odd.
split = (len(metrics) + 1) // 2
with col_a:
for name, value in metrics[:split]:
st.markdown(f"**{name}** \n{value}")
with col_b:
for name, value in metrics[split:]:
st.markdown(f"**{name}** \n{value}")
with st.expander("Deployment note"):
st.write(
"This Hugging Face Space is public at "
f"[{SPACE_URL_PLACEHOLDER}]({SPACE_URL_PLACEHOLDER}). "
"Built-in samples are bundled as NPZ assets to avoid raw binary medical files in the repo, "
"while external uploads still accept DICOM and NIfTI."
)
# Run the app when the file is executed as a script.
if __name__ == "__main__":
main()