"""Streamlit app for exploring grayscale quantization and CT windowing.

The app loads a 2D image (or one slice of a 3D volume), derives quantized
and windowed variants of it, renders the four variants side by side and
reports simple similarity metrics for a selected pair.
"""

from __future__ import annotations

import base64
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Any

import numpy as np
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image, UnidentifiedImageError

# Optional readers: the app still starts without them and reports a friendly
# error only when the matching upload type is actually used.
try:
    import nibabel as nib
except Exception:  # pragma: no cover
    nib = None

try:
    import pydicom
except Exception:  # pragma: no cover
    pydicom = None


APP_TITLE = "Medical Quantization and CT Windowing Explorer"
SPACE_URL_PLACEHOLDER = "https://huggingface.co/spaces/HuggingKatze/ct-windowing-quantization-demo"
SAMPLE_DIR = Path(__file__).resolve().parents[1] / "sample_images" / "builtin_samples"

# Labels of the selectable image pairs; metric_pair() maps each to two keys.
COMPARE_OPTIONS = [
    "Original vs Windowed",
    "Original vs Quantized",
    "Windowed vs Quantized + Windowed",
    "Quantized vs Quantized + Windowed",
    "Original vs Quantized + Windowed",
    "Windowed vs Quantized",
]

# Preset name -> (lower bound, upper bound); "Custom" has no fixed bounds.
WINDOW_PRESETS = {
    "Custom": None,
    "Lung": (-1024.0, 150.0),
    "Mediastinum": (-160.0, 240.0),
    "Bone": (300.0, 2000.0),
    "Soft tissue": (-125.0, 225.0),
}

BUILTIN_SAMPLES = {
    "CT-RATE chest CT volume": SAMPLE_DIR / "ct_rate_valid_100_a_1.npz",
    "LDCT full-dose chest CT": SAMPLE_DIR / "ldct_full_dose_1-135.npz",
    "RSNA PE Detection CTPA slice": SAMPLE_DIR / "rsna_pe_7bf959bb5c7e.npz",
}

# Image key -> human readable card title.
TITLE_MAP = {
    "original": "Original",
    "windowed": "Windowed",
    "quantized": "Quantized",
    "quantized_windowed": "Quantized + Windowed",
}

# Layout constants for the 2x2 card grid.
CARD_X = {"left": 70, "right": 560}
CARD_Y = {"top": 40, "bottom": 560}
CARD_W = 360
CARD_H = 360
SVG_W = 1120
SVG_H = 975
ARROW_MARGIN = 26


def safe_float(value: Any, default: float = 0.0) -> float:
    """Return ``float(value)``, or ``default`` when the conversion fails."""
    try:
        return float(value)
    except Exception:  # deliberately broad: any failure falls back to default
        return default


def safe_int(value: Any, default: int = 0) -> int:
    """Return ``int(value)``, or ``default`` when the conversion fails."""
    try:
        return int(value)
    except Exception:  # deliberately broad: any failure falls back to default
        return default


def normalize_01(image: np.ndarray) -> np.ndarray:
    """Linearly rescale ``image`` so its minimum maps to 0 and maximum to 1.

    An empty input yields a 16x16 array of zeros; a constant input yields
    zeros with the input's shape.
    """
    arr = image.astype(np.float32)
    if arr.size == 0:
        return np.zeros((16, 16), dtype=np.float32)
    lo = float(np.min(arr))
    hi = float(np.max(arr))
    if hi <= lo:
        return np.zeros_like(arr, dtype=np.float32)
    return (arr - lo) / (hi - lo)


def resize_for_display(image: np.ndarray, max_side: int = 320) -> np.ndarray:
    """Shrink a 2D image so that neither side exceeds ``max_side``.

    Images that already fit are returned unscaled (as float32). Larger images
    are resampled bilinearly through an 8-bit intermediate and then mapped
    back onto the original [min, max] intensity range.
    """
    arr = image.astype(np.float32)
    h, w = arr.shape
    scale = min(max_side / max(h, 1), max_side / max(w, 1), 1.0)
    if scale >= 1.0:
        return arr
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    norm = normalize_01(arr)
    pil = Image.fromarray((norm * 255).astype(np.uint8), mode="L")
    resized = pil.resize((new_w, new_h), Image.Resampling.BILINEAR)
    out = np.asarray(resized, dtype=np.float32) / 255.0
    lo = float(np.min(arr))
    hi = float(np.max(arr))
    return out * (hi - lo) + lo


def to_png_base64(image: np.ndarray) -> str:
    """Encode ``image`` as an 8-bit grayscale PNG and return it as base64 text."""
    norm = normalize_01(image)
    pil = Image.fromarray((norm * 255).clip(0, 255).astype(np.uint8), mode="L")
    buf = BytesIO()
    pil.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


def quantize_image(image: np.ndarray, bit_depth: int) -> np.ndarray:
    """Quantize ``image`` to ``2 ** bit_depth`` evenly spaced levels.

    The levels span the image's own [min, max] range, so the result stays in
    the original intensity units. Empty and constant images are returned as
    float32 copies.
    """
    arr = image.astype(np.float32)
    if arr.size == 0:
        # np.min / np.max raise on empty arrays; there is nothing to quantize.
        return arr.copy()
    lo = float(np.min(arr))
    hi = float(np.max(arr))
    if hi <= lo:
        return arr.copy()
    levels = max(2, 2 ** int(bit_depth))
    norm = (arr - lo) / (hi - lo)
    quantized = np.round(norm * (levels - 1)) / (levels - 1)
    return quantized * (hi - lo) + lo


def apply_window(image: np.ndarray, low: float, high: float) -> np.ndarray:
    """Clip ``image`` to [low, high] and rescale that interval to [0, 1].

    Returns zeros when ``high <= low``.
    """
    arr = np.clip(image.astype(np.float32), low, high)
    if high <= low:
        return np.zeros_like(arr, dtype=np.float32)
    return (arr - low) / (high - low)


def compute_entropy(image: np.ndarray, bins: int = 256) -> float:
    """Return the Shannon entropy (base 2) of the image's intensity histogram.

    The histogram uses ``bins`` bins over the image's [min, max] range. Only
    finite samples are counted, so NaN/inf pixels do not make the histogram
    range invalid. Empty or constant images have an entropy of 0.
    """
    arr = image.astype(np.float32)
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return 0.0
    lo = float(np.min(arr))
    hi = float(np.max(arr))
    if hi <= lo:
        return 0.0
    hist, _ = np.histogram(arr, bins=bins, range=(lo, hi))
    prob = hist.astype(np.float64)
    prob = prob / np.sum(prob)
    prob = prob[prob > 0]  # avoid log2(0)
    return float(-(prob * np.log2(prob)).sum())


def gray_levels(image: np.ndarray) -> int:
    """Count the distinct values after mapping ``image`` onto 0..255."""
    arr = image.astype(np.float32)
    if arr.size == 0:
        return 0
    norm = np.round(normalize_01(arr) * 255).astype(np.uint8)
    return int(np.unique(norm).size)


def psnr(left: np.ndarray, right: np.ndarray) -> float:
    """Return the PSNR in dB between two arrays, using a peak value of 1.0.

    Returns ``inf`` when the mean squared error is (numerically) zero.
    """
    mse = float(np.mean((left.astype(np.float32) - right.astype(np.float32)) ** 2))
    if mse <= 1e-12:
        return float("inf")
    return float(20.0 * np.log10(1.0 / np.sqrt(mse)))


def approx_ssim(left: np.ndarray, right: np.ndarray) -> float:
    """Return an SSIM-style score computed from whole-image statistics.

    Means, variances and the covariance are taken over the entire arrays
    rather than over local windows.
    """
    x = left.astype(np.float64)
    y = right.astype(np.float64)
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    mx = float(np.mean(x))
    my = float(np.mean(y))
    vx = float(np.var(x))
    vy = float(np.var(y))
    cxy = float(np.mean((x - mx) * (y - my)))
    denom = (mx * mx + my * my + c1) * (vx + vy + c2)
    if denom == 0:
        return 1.0
    return float(((2 * mx * my + c1) * (2 * cxy + c2)) / denom)


def compare_metrics(left: np.ndarray, right: np.ndarray) -> list[tuple[str, str]]:
    """Return ``(name, formatted value)`` rows comparing two images.

    Gray levels and entropy are reported for both images as ``left / right``.
    """
    diff = left.astype(np.float32) - right.astype(np.float32)
    mse = float(np.mean(diff ** 2))
    mae = float(np.mean(np.abs(diff)))
    rmse = float(np.sqrt(mse))
    p = psnr(left, right)
    return [
        ("MSE", f"{mse:.6f}"),
        ("MAE", f"{mae:.6f}"),
        ("RMSE", f"{rmse:.6f}"),
        ("PSNR", "inf" if np.isinf(p) else f"{p:.3f} dB"),
        ("SSIM", f"{approx_ssim(left, right):.4f}"),
        ("Gray levels", f"{gray_levels(left)} / {gray_levels(right)}"),
        ("Entropy", f"{compute_entropy(left):.4f} / {compute_entropy(right):.4f}"),
    ]


def parse_npz(source: Any, *, allow_pickle: bool = True) -> dict[str, Any]:
    """Load every array of an NPZ archive into a plain dict.

    Args:
        source: Path or file-like object accepted by ``numpy.load``.
        allow_pickle: Forwarded to ``numpy.load``. Unpickling can execute
            arbitrary code, so pass ``False`` for data from untrusted sources.

    Returns:
        Mapping from archive member name to its array.
    """
    # The archive is read lazily; materialize all members before it is closed.
    with np.load(source, allow_pickle=allow_pickle) as data:
        return {k: data[k] for k in data.files}


def parse_builtin_payload(payload: dict[str, Any], label: str) -> dict[str, Any]:
    """Convert a parsed NPZ payload into the app's image record.

    A 3D ``volume`` entry takes precedence over a 2D ``image`` entry. The
    stored ``default_slice_index`` is clamped to the volume's slice range.

    Raises:
        ValueError: If the volume has no slices, or if the payload holds
            neither a 3D volume nor an image.
    """
    image = payload.get("image")
    volume = payload.get("volume")
    bit_depth = int(np.atleast_1d(payload.get("original_bit_depth", [16]))[0])
    note = str(np.atleast_1d(payload.get("note", [""]))[0])

    suggested_window = None
    if "suggested_window" in payload:
        suggested_window = tuple(map(float, np.atleast_1d(payload["suggested_window"])[:2]))
    elif image is not None:
        suggested_window = (-1024.0, 150.0)

    if volume is not None and getattr(volume, "ndim", 0) == 3:
        depth = int(volume.shape[0])
        if depth == 0:
            raise ValueError("Sample volume does not contain any slices.")
        requested = int(np.atleast_1d(payload.get("default_slice_index", [depth // 2]))[0])
        # Clamp so a stale or wrong stored index cannot raise IndexError.
        slice_index = min(max(requested, 0), depth - 1)
        return {
            "label": label,
            "image": volume[slice_index].astype(np.float32),
            "volume": volume.astype(np.float32),
            "slice_index": slice_index,
            "bit_depth": bit_depth,
            "suggested_window": suggested_window,
            "note": note,
        }

    if image is None:
        raise ValueError("Sample file does not contain a readable image array.")
    return {
        "label": label,
        "image": image.astype(np.float32),
        "volume": None,
        "slice_index": None,
        "bit_depth": bit_depth,
        "suggested_window": suggested_window,
        "note": note,
    }


def load_builtin_sample(label: str) -> dict[str, Any]:
    """Load the bundled sample registered under ``label``.

    Raises:
        ValueError: If the label is unknown or its file does not exist.
    """
    path = BUILTIN_SAMPLES.get(label)
    if path is None or not path.exists():
        raise ValueError(f"Built-in sample not found: {label}")
    return parse_builtin_payload(parse_npz(path), label)


def load_nifti(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Read an uploaded NIfTI volume and return the app's image record.

    The bytes are written to a temporary file (keeping the upload's suffixes)
    for nibabel to open. The shortest axis is moved to the front and used as
    the slice axis; the middle slice is selected by default.

    Raises:
        ValueError: If nibabel is unavailable or the data is not a non-empty
            3D volume.
    """
    if nib is None:
        raise ValueError("NIfTI upload requires nibabel, which is not installed.")

    tmp = tempfile.NamedTemporaryFile(suffix="".join(Path(name).suffixes), delete=False)
    tmp_path = Path(tmp.name)
    try:
        # Writing happens inside the try so a failed write still cleans up.
        with tmp:
            tmp.write(raw_bytes)
        img = nib.load(str(tmp_path))
        vol = np.asarray(img.get_fdata(dtype=np.float32))
    finally:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not fail the load.
            pass

    if vol.ndim != 3 or 0 in vol.shape:
        raise ValueError("Expected a non-empty 3D NIfTI volume.")
    axis0 = int(np.argmin(vol.shape))
    if axis0 != 0:
        vol = np.moveaxis(vol, axis0, 0)
    slice_index = vol.shape[0] // 2
    return {
        "label": name,
        "image": vol[slice_index].astype(np.float32),
        "volume": vol.astype(np.float32),
        "slice_index": slice_index,
        "bit_depth": 16,
        "suggested_window": (-1024.0, 150.0),
        "note": "Uploaded NIfTI volume.",
    }


def load_dicom(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Read an uploaded DICOM file and return the app's image record.

    Pixel values are rescaled with ``RescaleSlope`` / ``RescaleIntercept``
    (defaulting to 1 and 0). The bit depth comes from ``BitsStored``, then
    ``BitsAllocated``, then 16.

    Raises:
        ValueError: If pydicom is unavailable or the file has no pixel data.
    """
    if pydicom is None:
        raise ValueError("DICOM upload requires pydicom, which is not installed.")
    ds = pydicom.dcmread(BytesIO(raw_bytes), force=True)
    if not hasattr(ds, "PixelData"):
        raise ValueError("DICOM file has no readable pixel data.")
    image = ds.pixel_array.astype(np.float32)
    slope = safe_float(getattr(ds, "RescaleSlope", 1.0), 1.0)
    intercept = safe_float(getattr(ds, "RescaleIntercept", 0.0), 0.0)
    image = image * slope + intercept
    bit_depth = safe_int(getattr(ds, "BitsStored", getattr(ds, "BitsAllocated", 16)), 16)
    return {
        "label": name,
        "image": image,
        "volume": None,
        "slice_index": None,
        "bit_depth": bit_depth,
        "suggested_window": (-1024.0, 150.0),
        "note": "Uploaded DICOM slice.",
    }


def load_regular_image(raw_bytes: bytes, name: str) -> dict[str, Any]:
    """Decode an ordinary image file with Pillow as 8-bit grayscale.

    The suggested window is the image's own [min, max] range.

    Raises:
        ValueError: If Pillow cannot identify the file format.
    """
    try:
        image = Image.open(BytesIO(raw_bytes)).convert("L")
    except UnidentifiedImageError as exc:
        raise ValueError("Could not read this file. Please upload PNG, JPG, DICOM, NIfTI, or NPZ.") from exc
    arr = np.asarray(image).astype(np.float32)
    return {
        "label": name,
        "image": arr,
        "volume": None,
        "slice_index": None,
        "bit_depth": 8,
        "suggested_window": (float(np.min(arr)), float(np.max(arr))) if arr.size else (0.0, 1.0),
        "note": "Uploaded grayscale image.",
    }


def load_uploaded_file(uploaded) -> dict[str, Any]:
    """Dispatch an uploaded file to the loader matching its extension.

    ``.npz``, ``.nii`` / ``.nii.gz`` and ``.dcm`` have dedicated loaders; any
    other file is handed to Pillow.

    Raises:
        ValueError: If nothing was uploaded or the upload is empty.
    """
    if uploaded is None:
        raise ValueError("No file was uploaded.")
    raw = uploaded.getvalue()
    if not raw:
        raise ValueError("Uploaded file is empty.")
    suffixes = "".join(Path(uploaded.name).suffixes).lower()
    if suffixes.endswith(".npz"):
        # Uploads are untrusted: never unpickle object arrays from them.
        return parse_builtin_payload(parse_npz(BytesIO(raw), allow_pickle=False), uploaded.name)
    if suffixes.endswith((".nii", ".nii.gz")):
        return load_nifti(raw, uploaded.name)
    if suffixes.endswith(".dcm"):
        return load_dicom(raw, uploaded.name)
    return load_regular_image(raw, uploaded.name)


def metric_pair(images: dict[str, np.ndarray], compare: str) -> tuple[np.ndarray, np.ndarray, str, str]:
    """Resolve a comparison label to ``(left image, right image, left title, right title)``.

    Raises:
        KeyError: If ``compare`` is not one of the known comparison labels.
    """
    mapping = {
        "Original vs Windowed": ("original", "windowed"),
        "Original vs Quantized": ("original", "quantized"),
        "Windowed vs Quantized + Windowed": ("windowed", "quantized_windowed"),
        "Quantized vs Quantized + Windowed": ("quantized", "quantized_windowed"),
        "Original vs Quantized + Windowed": ("original", "quantized_windowed"),
        "Windowed vs Quantized": ("windowed", "quantized"),
    }
    left_key, right_key = mapping[compare]
    return images[left_key], images[right_key], TITLE_MAP[left_key], TITLE_MAP[right_key]


def render_visualization_svg(
    images: dict[str, np.ndarray],
    bit_depth: int,
    low: float,
    high: float,
    selected_compare: str,
) -> str:
    """Assemble the markup for the four image cards and the relation arrows.

    Args:
        images: The four variants keyed like ``TITLE_MAP``.
        bit_depth: Quantization bit depth shown in the caption.
        low: Lower window bound shown in the caption.
        high: Upper window bound shown in the caption.
        selected_compare: Comparison label whose arrow is highlighted.

    Returns:
        The concatenated markup string.
    """
    # NOTE(review): several markup literals below are empty, so the geometry
    # computed here (x, y, png, ty, color, marker, sx, sy, ex, ey) never
    # reaches the output. This looks like stripped SVG tags; restore them from
    # the intended template before relying on this rendering.
    defs = """ """
    cards = {
        "original": (CARD_X["left"], CARD_Y["top"]),
        "windowed": (CARD_X["right"], CARD_Y["top"]),
        "quantized": (CARD_X["left"], CARD_Y["bottom"]),
        "quantized_windowed": (CARD_X["right"], CARD_Y["bottom"]),
    }
    title_base_y = {
        "original": 432,
        "windowed": 432,
        "quantized": 952,
        "quantized_windowed": 952,
    }
    # (comparison label, arrow start point, arrow end point)
    lines = [
        (
            "Original vs Windowed",
            (CARD_X["left"] + CARD_W + ARROW_MARGIN, CARD_Y["top"] + CARD_H / 2),
            (CARD_X["right"] - ARROW_MARGIN, CARD_Y["top"] + CARD_H / 2),
        ),
        (
            "Original vs Quantized",
            (CARD_X["left"] + CARD_W / 2, CARD_Y["top"] + CARD_H + ARROW_MARGIN),
            (CARD_X["left"] + CARD_W / 2, CARD_Y["bottom"] - ARROW_MARGIN),
        ),
        (
            "Windowed vs Quantized + Windowed",
            (CARD_X["right"] + CARD_W / 2, CARD_Y["top"] + CARD_H + ARROW_MARGIN),
            (CARD_X["right"] + CARD_W / 2, CARD_Y["bottom"] - ARROW_MARGIN),
        ),
        (
            "Quantized vs Quantized + Windowed",
            (CARD_X["left"] + CARD_W + ARROW_MARGIN, CARD_Y["bottom"] + CARD_H / 2),
            (CARD_X["right"] - ARROW_MARGIN, CARD_Y["bottom"] + CARD_H / 2),
        ),
        (
            "Original vs Quantized + Windowed",
            (CARD_X["left"] + CARD_W + ARROW_MARGIN, CARD_Y["top"] + CARD_H + ARROW_MARGIN),
            (CARD_X["right"] - ARROW_MARGIN, CARD_Y["bottom"] - ARROW_MARGIN),
        ),
        (
            "Windowed vs Quantized",
            (CARD_X["right"] - ARROW_MARGIN, CARD_Y["top"] + CARD_H + ARROW_MARGIN),
            (CARD_X["left"] + CARD_W + ARROW_MARGIN, CARD_Y["bottom"] - ARROW_MARGIN),
        ),
    ]

    parts = [f'', defs]
    for key, (x, y) in cards.items():
        png = to_png_base64(images[key])
        ty = title_base_y[key]
        parts.append(f'')
        parts.append(f'')
        parts.append(f'{TITLE_MAP[key]}')

    # NOTE(review): emitted once per figure here; confirm they are not meant
    # to be repeated for every card.
    parts.append(f'Quantization: bit depth = {bit_depth}')
    parts.append(f'Window: [{low:.1f}, {high:.1f}]')

    for label, start, end in lines:
        active = label == selected_compare
        color = "#C4473A" if active else "#A4AFBA"
        marker = "arrow-red" if active else "arrow-gray"
        sx, sy = start
        ex, ey = end
        parts.append(
            f''
            f''
            f''
            f""
        )
    parts.append("")
    return "".join(parts)


def main() -> None:
    """Build the Streamlit page: controls on the left, results on the right."""
    st.set_page_config(page_title=APP_TITLE, layout="wide")
    st.title(APP_TITLE)
    st.caption(
        "Interactive demo for grayscale quantization and CT windowing. "
        "Built-in samples are bundled with the repo, and you can also upload your own image, DICOM, NIfTI, or NPZ."
    )

    # The selected comparison is kept in the URL query string; unknown values
    # fall back to the first option.
    current_compare = str(st.query_params.get("compare", COMPARE_OPTIONS[0]))
    if current_compare not in COMPARE_OPTIONS:
        current_compare = COMPARE_OPTIONS[0]

    left, right = st.columns([1, 1.7], gap="large")

    with left:
        st.subheader("Controls")
        st.caption("All interactive controls are annotated below. Invalid input should show a friendly error instead of crashing the app.")
        input_source = st.radio(
            "Input source",
            ["Built-in sample", "Upload file"],
            help="Choose a bundled example from the repository, or upload your own grayscale image / medical scan file.",
        )

        data = None
        try:
            if input_source == "Built-in sample":
                sample_label = st.selectbox(
                    "Built-in sample",
                    list(BUILTIN_SAMPLES.keys()),
                    help="Repository-safe sample assets for Hugging Face deployment.",
                )
                data = load_builtin_sample(sample_label)
            else:
                uploaded = st.file_uploader(
                    "Upload image / medical file",
                    type=["png", "jpg", "jpeg", "dcm", "nii", "gz", "npz"],
                    help="Supported formats: PNG, JPG, DICOM (.dcm), NIfTI (.nii/.nii.gz), and NPZ arrays.",
                )
                if uploaded is not None:
                    data = load_uploaded_file(uploaded)
        except Exception as exc:
            # Top-level boundary: any loader failure becomes a visible message.
            st.error(f"Failed to load input: {exc}")
            data = None

        if data is None:
            st.info("Choose a built-in sample or upload a file to begin.")
            return

        image = data["image"].astype(np.float32)
        volume = data.get("volume")
        if volume is not None:
            depth = int(volume.shape[0])
            stored_index = data.get("slice_index")
            # Test against None explicitly: a stored index of 0 is valid and
            # must not be replaced by the middle slice.
            default_slice = depth // 2 if stored_index is None else int(stored_index)
            if depth > 1:
                slice_index = st.slider(
                    "Slice index",
                    min_value=0,
                    max_value=depth - 1,
                    value=int(np.clip(default_slice, 0, depth - 1)),
                    help="For 3D volumes, choose which axial slice is visualized.",
                )
            else:
                # A single-slice volume leaves nothing to choose.
                slice_index = 0
            image = volume[slice_index].astype(np.float32)

        if input_source == "Upload file":
            rotation = st.select_slider(
                "Rotate uploaded image",
                options=[0, 90, 180, 270],
                value=0,
                help="Rotate only the displayed 2D slice, which is useful when checking uploaded files.",
            )
            if rotation:
                image = np.rot90(image, rotation // 90)

        st.caption(f"Loaded source: {data['label']}")
        if data.get("note"):
            st.caption(str(data["note"]))
        st.caption("Large images are automatically resized for responsive interaction.")

        bit_depth = st.slider(
            "Quantization bit depth",
            min_value=1,
            max_value=16,
            value=min(max(int(data.get("bit_depth", 8)), 1), 16),
            help="Lower bit depth reduces the number of available gray levels before display.",
        )

        # Preselect the preset whose bounds equal the suggested window, if any.
        preset_default = "Custom"
        suggested_window = data.get("suggested_window")
        if suggested_window is not None:
            suggested_bounds = tuple(map(float, suggested_window))
            for name, bounds in WINDOW_PRESETS.items():
                if bounds == suggested_bounds:
                    preset_default = name
                    break

        preset = st.selectbox(
            "CT window preset",
            list(WINDOW_PRESETS.keys()),
            index=list(WINDOW_PRESETS.keys()).index(preset_default),
            help="Choose a preset medical display window, or switch to Custom to edit the lower and upper bounds directly.",
        )

        if preset != "Custom":
            low_default, high_default = WINDOW_PRESETS[preset]
        elif suggested_window is not None:
            low_default, high_default = map(float, suggested_window)
        else:
            low_default, high_default = float(np.min(image)), float(np.max(image))

        lower = st.number_input(
            "Window lower bound",
            value=float(low_default),
            help="Intensities below this value are clipped before the displayed normalization step.",
        )
        upper = st.number_input(
            "Window upper bound",
            value=float(high_default),
            help="Intensities above this value are clipped before the displayed normalization step.",
        )

        if upper <= lower:
            st.error("Window upper bound must be greater than lower bound.")
            return

        try:
            original = resize_for_display(image)
            quantized = quantize_image(original, bit_depth)
            windowed = apply_window(original, float(lower), float(upper))
            quantized_windowed = apply_window(quantized, float(lower), float(upper))
        except Exception as exc:
            st.error(f"Failed to prepare visualization: {exc}")
            return

        images = {
            "original": original,
            "windowed": windowed,
            "quantized": quantized,
            "quantized_windowed": quantized_windowed,
        }

    with right:
        st.subheader("Visualization")
        st.caption("Each double-arrow line is clickable. The selected relationship changes color and also updates the metric comparison below.")
        svg = render_visualization_svg(images, bit_depth, float(lower), float(upper), current_compare)
        components.html(svg, height=720, scrolling=False)

        st.subheader("Metrics")
        compare = st.selectbox(
            "Compare",
            COMPARE_OPTIONS,
            index=COMPARE_OPTIONS.index(current_compare),
            help="Choose which transformation pair to compare numerically. This stays synchronized with the selected arrow relation.",
        )
        if compare != current_compare:
            # Persist the choice in the URL and rerun so the figure picks it up.
            st.query_params["compare"] = compare
            st.rerun()

        left_img, right_img, left_name, right_name = metric_pair(images, compare)
        st.caption(f"Current comparison: {left_name} vs {right_name}")
        metrics = compare_metrics(left_img, right_img)

        # Split the metric rows across two columns, first column gets the extra.
        col_a, col_b = st.columns(2)
        split = (len(metrics) + 1) // 2
        with col_a:
            for name, value in metrics[:split]:
                st.markdown(f"**{name}** \n{value}")
        with col_b:
            for name, value in metrics[split:]:
                st.markdown(f"**{name}** \n{value}")

    with st.expander("Deployment note"):
        st.write(
            "This Hugging Face Space is public at "
            f"[{SPACE_URL_PLACEHOLDER}]({SPACE_URL_PLACEHOLDER}). "
            "Built-in samples are bundled as NPZ assets to avoid raw binary medical files in the repo, "
            "while external uploads still accept DICOM and NIfTI."
        )


if __name__ == "__main__":
    main()