from __future__ import annotations from dataclasses import dataclass from io import BytesIO from pathlib import Path import tempfile import base64 import matplotlib.pyplot as plt from matplotlib.patches import ConnectionPatch import nibabel as nib import numpy as np import pydicom from pydicom.errors import InvalidDicomError from PIL import Image @dataclass(frozen=True) class BuiltinSample: label: str path: Path dataset_name: str dataset_url: str dataset_summary: str input_type_label: str display_note: str rotate_k: int = 0 @dataclass(frozen=True) class LoadedImage: image: np.ndarray source_name: str source_kind: str input_type_label: str note: str original_bit_depth: int volume: np.ndarray | None = None default_slice_index: int | None = None dataset_name: str | None = None dataset_url: str | None = None dataset_summary: str | None = None suggested_window: tuple[float, float] | None = None ROOT = Path(__file__).resolve().parent SAMPLE_IMAGES_ROOT = ROOT / "sample_images" BUILTIN_SAMPLES_ROOT = SAMPLE_IMAGES_ROOT / "builtin_samples" def _ensure_non_empty_2d_image(image: np.ndarray, source_name: str) -> np.ndarray: array = np.asarray(image) if array.ndim != 2: raise ValueError(f"{source_name} did not produce a 2D grayscale image.") if array.size == 0: raise ValueError(f"{source_name} is empty.") return array.astype(np.float32) def _ensure_non_empty_volume(volume: np.ndarray, source_name: str) -> np.ndarray: array = np.asarray(volume) if array.ndim < 3: raise ValueError(f"{source_name} is not a valid 3D volume.") if array.size == 0 or 0 in array.shape: raise ValueError(f"{source_name} contains no readable volume data.") return array.astype(np.float32) def list_builtin_samples() -> dict[str, BuiltinSample]: return { "CT-RATE chest CT volume": BuiltinSample( label="CT-RATE chest CT volume", path=BUILTIN_SAMPLES_ROOT / "ct_rate_valid_100_a_1.npz", dataset_name="CT-RATE", dataset_url="https://huggingface.co/datasets/ibrahimhamamci/CT-RATE", dataset_summary="Chest CT volumes paired with radiology reports, abnormality labels, and metadata.", input_type_label="Built-in NPZ volume export, displayed as an axial slice", display_note="CT-RATE volumes are rotated 90 degrees clockwise in this app so the slice orientation matches the usual viewing convention.", rotate_k=-1, ), "LDCT-and-Projection-data full-dose chest CT": BuiltinSample( label="LDCT-and-Projection-data full-dose chest CT", path=BUILTIN_SAMPLES_ROOT / "ldct_full_dose_1-135.npz", dataset_name="LDCT-and-Projection-data", dataset_url="https://www.cancerimagingarchive.net/collection/ldct-and-projection-data/", dataset_summary="Low-dose CT collection with DICOM image data, DICOM-CT-PD projection data, and supporting clinical reports.", input_type_label="Built-in NPZ export of a DICOM CT image slice", display_note="This built-in sample uses the full-dose reconstructed chest CT image series from the TCIA collection.", ), "RSNA PE Detection CTPA slice": BuiltinSample( label="RSNA PE Detection CTPA slice", path=BUILTIN_SAMPLES_ROOT / "rsna_pe_7bf959bb5c7e.npz", dataset_name="RSNA Pulmonary Embolism Detection", dataset_url="https://registry.opendata.aws/rsna-pulmonary-embolism-detection/", dataset_summary="Annotated CT pulmonary angiography studies with DICOM images and CSV annotations for pulmonary embolism detection.", input_type_label="Built-in NPZ export of a DICOM CTPA slice", display_note="This built-in sample uses one chest CT pulmonary angiography slice from the RSNA PE example series test/0b2c88c2f00a/bec1b1d73f48.", ), } def load_builtin_sample(label: str) -> LoadedImage: samples = list_builtin_samples() sample = samples[label] loaded = load_local_path(sample.path) image = loaded.image volume = loaded.volume if sample.rotate_k: image = np.rot90(image, k=sample.rotate_k) if volume is not None: volume = np.rot90(volume, k=sample.rotate_k, axes=(1, 2)) note = f"{sample.display_note} {loaded.note}" return LoadedImage( image=image, source_name=sample.path.name, source_kind=loaded.source_kind, input_type_label=sample.input_type_label, note=note, original_bit_depth=loaded.original_bit_depth, volume=volume, default_slice_index=loaded.default_slice_index, dataset_name=sample.dataset_name, dataset_url=sample.dataset_url, dataset_summary=sample.dataset_summary, suggested_window=loaded.suggested_window, ) def load_uploaded_file(uploaded_file) -> LoadedImage: data = uploaded_file.getvalue() if not data: raise ValueError("The uploaded file is empty.") lower_name = uploaded_file.name.lower() if lower_name.endswith(".npz"): return load_npz_bytes(data, uploaded_file.name) if lower_name.endswith(".dcm"): try: image, note, suggested_window, volume, default_slice_index, bit_depth = load_dicom_bytes(data) except InvalidDicomError as exc: raise ValueError("This file is not a readable DICOM image.") from exc return LoadedImage( image=image, source_name=uploaded_file.name, source_kind="dicom", input_type_label="Uploaded DICOM image (.dcm)", note=note, original_bit_depth=bit_depth, volume=volume, default_slice_index=default_slice_index, suggested_window=suggested_window, ) if lower_name.endswith(".nii") or lower_name.endswith(".nii.gz"): image, note, volume, default_slice_index, bit_depth = load_nifti_bytes(data, uploaded_file.name) return LoadedImage( image=image, source_name=uploaded_file.name, source_kind="nifti", input_type_label="Uploaded NIfTI volume (.nii or .nii.gz), displayed as an axial slice", note=note, original_bit_depth=bit_depth, volume=volume, default_slice_index=default_slice_index, ) try: pil_image = Image.open(BytesIO(data)) except Exception as exc: raise ValueError("Could not read this image file. Please upload PNG, JPG, DICOM, or NIfTI.") from exc image, note = pil_to_grayscale_array(pil_image) return LoadedImage( image=image, source_name=uploaded_file.name, source_kind="image", input_type_label=f"Uploaded image ({Path(uploaded_file.name).suffix.lower()})", note=note, original_bit_depth=infer_bit_depth_from_array(image), ) def load_local_path(path: Path) -> LoadedImage: if not path.exists(): raise ValueError(f"Built-in sample not found: {path.name}") lower_name = path.name.lower() if lower_name.endswith(".npz"): return load_npz_path(path) if lower_name.endswith(".dcm"): try: image, note, suggested_window, volume, default_slice_index, bit_depth = load_dicom_path(path) except InvalidDicomError as exc: raise ValueError(f"Built-in DICOM sample is not readable: {path.name}") from exc return LoadedImage( image=image, source_name=path.name, source_kind="dicom", input_type_label="Local DICOM image (.dcm)", note=note, original_bit_depth=bit_depth, volume=volume, default_slice_index=default_slice_index, suggested_window=suggested_window, ) if lower_name.endswith(".nii") or lower_name.endswith(".nii.gz"): image, note, volume, default_slice_index, bit_depth = load_nifti_path(path) return LoadedImage( image=image, source_name=path.name, source_kind="nifti", input_type_label="Local NIfTI volume (.nii or .nii.gz), displayed as an axial slice", note=note, original_bit_depth=bit_depth, volume=volume, default_slice_index=default_slice_index, ) try: pil_image = Image.open(path) except Exception as exc: raise ValueError(f"Built-in image sample could not be opened: {path.name}") from exc image, note = pil_to_grayscale_array(pil_image) return LoadedImage( image=image, source_name=path.name, source_kind="image", input_type_label=f"Local image ({path.suffix.lower()})", note=note, original_bit_depth=infer_bit_depth_from_array(image), ) def pil_to_grayscale_array(pil_image): array = np.asarray(pil_image) if array.size == 0: raise ValueError("The uploaded image contains no pixel data.") # Already single-channel if array.ndim == 2: return array, f"Loaded a single-channel grayscale image with dtype {array.dtype}." # RGB or RGBA if array.ndim == 3 and array.shape[2] >= 3: rgb = array[..., :3] # Case 1: 3-channel grayscale image if np.array_equal(rgb[..., 0], rgb[..., 1]) and np.array_equal(rgb[..., 1], rgb[..., 2]): gray = rgb[..., 0] return gray, "Detected a 3-channel grayscale image and collapsed it to one channel." # Case 2: true RGB image rgb_float = rgb.astype(np.float32) gray_float = ( 0.299 * rgb_float[..., 0] + 0.587 * rgb_float[..., 1] + 0.114 * rgb_float[..., 2] ) # For ordinary 8-bit PNG/JPG, return 8-bit grayscale if rgb.dtype == np.uint8: gray = np.clip(np.rint(gray_float), 0, 255).astype(np.uint8) return gray, "Converted an 8-bit RGB image to one 8-bit luminance grayscale channel." # For unusual higher-bit RGB images, preserve numeric range but avoid fake precision if possible gray = gray_float.astype(np.float32) return gray, f"Converted an RGB image with dtype {rgb.dtype} to one grayscale channel." raise ValueError("Unsupported image shape. Please upload a grayscale or RGB image.") def infer_bit_depth_from_array(array: np.ndarray) -> int: dtype = np.asarray(array).dtype if np.issubdtype(dtype, np.integer): return max(1, min(int(dtype.itemsize * 8), 16)) return 8 def _parse_npz_payload(payload: dict[str, np.ndarray], source_name: str) -> LoadedImage: image = payload.get("image") volume = payload.get("volume") if volume is not None: volume = _ensure_non_empty_volume(volume, source_name) default_slice_index = int(np.atleast_1d(payload.get("default_slice_index", [volume.shape[0] // 2]))[0]) default_slice_index = min(max(default_slice_index, 0), volume.shape[0] - 1) image = _ensure_non_empty_2d_image(volume[default_slice_index], source_name) elif image is not None: image = _ensure_non_empty_2d_image(image, source_name) default_slice_index = None else: raise ValueError(f"{source_name} NPZ file does not contain 'image' or 'volume'.") bit_depth = int(np.atleast_1d(payload.get("original_bit_depth", [infer_bit_depth_from_array(image)]))[0]) suggested_window = None if "suggested_window" in payload: bounds = np.asarray(payload["suggested_window"], dtype=np.float32).ravel() if bounds.size >= 2: suggested_window = float(bounds[0]), float(bounds[1]) note = str(np.atleast_1d(payload.get("note", ["Loaded a repository-safe NPZ sample."]))[0]) source_kind = str(np.atleast_1d(payload.get("source_kind", ["npz"]))[0]) return LoadedImage( image=image, source_name=source_name, source_kind=source_kind, input_type_label="NPZ image/volume", note=note, original_bit_depth=bit_depth, volume=volume, default_slice_index=default_slice_index, suggested_window=suggested_window, ) def load_npz_bytes(data: bytes, filename: str) -> LoadedImage: try: payload = dict(np.load(BytesIO(data), allow_pickle=True)) except Exception as exc: raise ValueError(f"Could not read NPZ file: {filename}") from exc return _parse_npz_payload(payload, filename) def load_npz_path(path: Path) -> LoadedImage: try: payload = dict(np.load(path, allow_pickle=True)) except Exception as exc: raise ValueError(f"Could not read NPZ file: {path.name}") from exc return _parse_npz_payload(payload, path.name) def load_dicom_bytes(data: bytes) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]: dataset = pydicom.dcmread(BytesIO(data)) return load_dicom_dataset(dataset) def load_dicom_path(path: Path) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]: dataset = pydicom.dcmread(str(path)) return load_dicom_dataset(dataset) def load_dicom_dataset(dataset) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]: if not hasattr(dataset, "PixelData"): raise ValueError("DICOM file has no readable pixel data.") raw_pixels = dataset.pixel_array if np.asarray(raw_pixels).size == 0: raise ValueError("DICOM file contains empty pixel data.") bit_depth = int(getattr(dataset, "BitsStored", raw_pixels.dtype.itemsize * 8)) pixels = raw_pixels.astype(np.float32) volume = None default_slice_index = None if pixels.ndim > 2: volume = _ensure_non_empty_volume(pixels, "DICOM volume") default_slice_index = pixels.shape[0] // 2 pixels = volume[default_slice_index] pixels = _ensure_non_empty_2d_image(pixels, "DICOM image") slope = float(getattr(dataset, "RescaleSlope", 1.0)) intercept = float(getattr(dataset, "RescaleIntercept", 0.0)) pixels = pixels * slope + intercept if volume is not None: volume = volume * slope + intercept suggested_window = extract_dicom_window(dataset) note = ( f"Loaded DICOM pixel data with rescale slope {slope:g} and intercept {intercept:g}. " "Windowing sliders operate in the uploaded image's native intensity range." ) return pixels, note, suggested_window, volume, default_slice_index, bit_depth def extract_dicom_window(dataset) -> tuple[float, float] | None: center = getattr(dataset, "WindowCenter", None) width = getattr(dataset, "WindowWidth", None) if center is None or width is None: return None if isinstance(center, pydicom.multival.MultiValue): center = float(center[0]) else: center = float(center) if isinstance(width, pydicom.multival.MultiValue): width = float(width[0]) else: width = float(width) low = center - width / 2.0 high = center + width / 2.0 return float(low), float(high) def load_nifti_bytes(data: bytes, filename: str) -> tuple[np.ndarray, str, np.ndarray, int, int]: suffix = ".nii.gz" if filename.lower().endswith(".nii.gz") else ".nii" with tempfile.NamedTemporaryFile(suffix=suffix) as tmp_file: tmp_file.write(data) tmp_file.flush() return load_nifti_path(Path(tmp_file.name)) def load_nifti_path(path: Path) -> tuple[np.ndarray, str, np.ndarray, int, int]: try: image = nib.load(str(path)) except Exception as exc: raise ValueError(f"Could not read NIfTI volume: {path.name}") from exc volume = np.asanyarray(image.dataobj, dtype=np.float32) data_dtype = image.get_data_dtype() bit_depth = int(getattr(data_dtype, "itemsize", 1) * 8) volume = _ensure_non_empty_volume(volume, path.name) volume = np.moveaxis(volume, -1, 0) middle_index = volume.shape[0] // 2 slice_2d = _ensure_non_empty_2d_image(volume[middle_index], path.name) note = f"Loaded a 3D NIfTI volume and selected the default middle axial slice (index {middle_index})." return slice_2d, note, volume, middle_index, bit_depth def create_histogram_figure( image_top: np.ndarray, image_bottom: np.ndarray, top_label: str, bottom_label: str, top_bounds: tuple[float, float], bottom_bounds: tuple[float, float], ) -> plt.Figure: stacked = np.concatenate([np.asarray(image_top).ravel(), np.asarray(image_bottom).ravel()]) bins = np.linspace(float(stacked.min()), float(stacked.max()), 129) fig, axes = plt.subplots(2, 1, figsize=(6, 5.0), sharex=True) for ax, image, label, bounds in zip( axes, [image_top, image_bottom], [top_label, bottom_label], [top_bounds, bottom_bounds], ): low, high = bounds ax.hist(np.asarray(image).ravel(), bins=bins, color="#4C6A92", alpha=0.9) ax.axvline(low, color="#B85450", linestyle="--", linewidth=1.4, label="Lower bound") ax.axvline(high, color="#B85450", linestyle="--", linewidth=1.4, label="Upper bound") ax.set_ylabel("Count") ax.set_title(label, fontsize=10, loc="left") ax.legend(loc="upper right", fontsize=8) axes[-1].set_xlabel("Intensity/HU Value") fig.suptitle("Intensity Histogram Comparison", fontsize=12) fig.tight_layout() return fig def create_mse_figure(reference: np.ndarray, test: np.ndarray) -> plt.Figure: squared_error = (np.asarray(test, dtype=np.float32) - np.asarray(reference, dtype=np.float32)) ** 2 vmax = float(np.percentile(squared_error, 99)) vmax = max(vmax, 1e-6) fig, ax = plt.subplots(figsize=(4.5, 4.5)) im = ax.imshow(squared_error, cmap="gray", vmin=0.0, vmax=vmax) ax.set_title("MSE Map") ax.axis("off") fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) fig.tight_layout() return fig def create_visualization_figure( original_display: np.ndarray, windowed_display: np.ndarray, quantized_display: np.ndarray, quantized_windowed_display: np.ndarray, bit_depth: int, low: float, high: float, metric_mode: str, ) -> plt.Figure: fig, axes = plt.subplots(2, 2, figsize=(10, 10)) ax_tl, ax_tr = axes[0, 0], axes[0, 1] ax_bl, ax_br = axes[1, 0], axes[1, 1] panels = [ (ax_tl, original_display, "Original Display", "Quantization: None", "Window: None"), (ax_tr, windowed_display, "Windowed Display", "Quantization: None", f"Window: [{low:.1f}, {high:.1f}]"), (ax_bl, quantized_display, "Quantized Display", f"Quantization: bit depth = {bit_depth}", "Window: None"), (ax_br, quantized_windowed_display, "Quantized + Windowed Display", f"Quantization: bit depth = {bit_depth}", f"Window: [{low:.1f}, {high:.1f}]"), ] for ax, image, title, line1, line2 in panels: ax.imshow(image, cmap="gray", vmin=0.0, vmax=1.0) ax.set_title(title, fontsize=11, pad=8) ax.axis("off") ax.text(0.0, -0.08, line1, transform=ax.transAxes, ha="left", va="top", fontsize=9) ax.text(0.0, -0.16, line2, transform=ax.transAxes, ha="left", va="top", fontsize=9) pair_to_axes = { "Original vs Windowed": (ax_tl, ax_tr), "Original vs Quantized": (ax_tl, ax_bl), "Windowed vs Quantized + Windowed": (ax_tr, ax_br), "Quantized vs Quantized + Windowed": (ax_bl, ax_br), "Original vs Quantized + Windowed": (ax_tl, ax_br), "Windowed vs Quantized": (ax_tr, ax_bl), } line_specs = [ ("Original vs Windowed", ax_tl, ax_tr), ("Original vs Quantized", ax_tl, ax_bl), ("Windowed vs Quantized + Windowed", ax_tr, ax_br), ("Quantized vs Quantized + Windowed", ax_bl, ax_br), ("Original vs Quantized + Windowed", ax_tl, ax_br), ("Windowed vs Quantized", ax_tr, ax_bl), ] for name, ax_a, ax_b in line_specs: is_active = name == metric_mode color = "#C4473A" if is_active else "#AAB7C4" linewidth = 4.0 if is_active else 1.8 connection = ConnectionPatch( xyA=(0.5, 0.5), coordsA=ax_a.transAxes, xyB=(0.5, 0.5), coordsB=ax_b.transAxes, color=color, linewidth=linewidth, zorder=0, alpha=0.95 if is_active else 0.8, ) fig.add_artist(connection) fig.tight_layout(pad=2.0) return fig def _to_base64_grayscale_png(image: np.ndarray) -> str: image_uint8 = np.clip(np.asarray(image) * 255.0, 0, 255).astype(np.uint8) pil_image = Image.fromarray(image_uint8, mode="L") buffer = BytesIO() pil_image.save(buffer, format="PNG") return base64.b64encode(buffer.getvalue()).decode("ascii") def create_visualization_svg_html( original_display: np.ndarray, windowed_display: np.ndarray, quantized_display: np.ndarray, quantized_windowed_display: np.ndarray, bit_depth: int, low: float, high: float, metric_mode: str, ) -> str: images = { "original": _to_base64_grayscale_png(original_display), "windowed": _to_base64_grayscale_png(windowed_display), "quantized": _to_base64_grayscale_png(quantized_display), "quantized_windowed": _to_base64_grayscale_png(quantized_windowed_display), } panels = { "original": {"x": 70, "y": 40, "title": "Original Display", "q": "Quantization: None", "w": "Window: None"}, "windowed": {"x": 560, "y": 40, "title": "Windowed Display", "q": "Quantization: None", "w": f"Window: [{low:.1f}, {high:.1f}]"}, "quantized": {"x": 70, "y": 510, "title": "Quantized Display", "q": f"Quantization: bit depth = {bit_depth}", "w": "Window: None"}, "quantized_windowed": {"x": 560, "y": 510, "title": "Quantized + Windowed Display", "q": f"Quantization: bit depth = {bit_depth}", "w": f"Window: [{low:.1f}, {high:.1f}]"}, } panel_w = 360 panel_h = 360 arrow_margin = 18 comparisons = [ ("Original vs Windowed", "original", "windowed"), ("Original vs Quantized", "original", "quantized"), ("Windowed vs Quantized + Windowed", "windowed", "quantized_windowed"), ("Quantized vs Quantized + Windowed", "quantized", "quantized_windowed"), ("Original vs Quantized + Windowed", "original", "quantized_windowed"), ("Windowed vs Quantized", "windowed", "quantized"), ] def center(panel_key: str) -> tuple[float, float]: p = panels[panel_key] return p["x"] + panel_w / 2.0, p["y"] + panel_h / 2.0 def edge_points(src: str, dst: str) -> tuple[float, float, float, float]: src_panel = panels[src] dst_panel = panels[dst] src_cx, src_cy = center(src) dst_cx, dst_cy = center(dst) if abs(src_cy - dst_cy) < 1e-6: if src_cx < dst_cx: return src_panel["x"] + panel_w + arrow_margin, src_cy, dst_panel["x"] - arrow_margin, dst_cy return src_panel["x"] - arrow_margin, src_cy, dst_panel["x"] + panel_w + arrow_margin, dst_cy if abs(src_cx - dst_cx) < 1e-6: if src_cy < dst_cy: return src_cx, src_panel["y"] + panel_h + arrow_margin, dst_cx, dst_panel["y"] - arrow_margin return src_cx, src_panel["y"] - arrow_margin, dst_cx, dst_panel["y"] + panel_h + arrow_margin if src_cx < dst_cx and src_cy < dst_cy: return ( src_panel["x"] + panel_w + arrow_margin, src_panel["y"] + panel_h + arrow_margin, dst_panel["x"] - arrow_margin, dst_panel["y"] - arrow_margin, ) if src_cx > dst_cx and src_cy < dst_cy: return ( src_panel["x"] - arrow_margin, src_panel["y"] + panel_h + arrow_margin, dst_panel["x"] + panel_w + arrow_margin, dst_panel["y"] - arrow_margin, ) if src_cx < dst_cx and src_cy > dst_cy: return ( src_panel["x"] + panel_w + arrow_margin, src_panel["y"] - arrow_margin, dst_panel["x"] - arrow_margin, dst_panel["y"] + panel_h + arrow_margin, ) return ( src_panel["x"] - arrow_margin, src_panel["y"] - arrow_margin, dst_panel["x"] + panel_w + arrow_margin, dst_panel["y"] + panel_h + arrow_margin, ) def arrow_group(name: str, src: str, dst: str) -> str: x1, y1, x2, y2 = edge_points(src, dst) active = name == metric_mode stroke = "#C4473A" if active else "#B4BEC8" width = "4" opacity = "1.0" if active else "0.85" marker_id = "arrowhead-active" if active else "arrowhead-inactive" return f""" """ panel_svg_parts = [] for key, panel in panels.items(): panel_svg_parts.append( f""" {panel['title']} {panel['q']} {panel['w']} """ ) arrow_svg_parts = [arrow_group(name, src, dst) for name, src, dst in comparisons] return f"""
{''.join(panel_svg_parts)} {''.join(arrow_svg_parts)}
"""