| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from io import BytesIO |
| from pathlib import Path |
| import tempfile |
| import base64 |
|
|
| import matplotlib.pyplot as plt |
| from matplotlib.patches import ConnectionPatch |
| import nibabel as nib |
| import numpy as np |
| import pydicom |
| from pydicom.errors import InvalidDicomError |
| from PIL import Image |
|
|
|
|
@dataclass(frozen=True)
class BuiltinSample:
    """Immutable catalogue entry describing one sample bundled with the app.

    ``rotate_k`` is the number of quarter turns handed to ``np.rot90`` when the
    sample is loaded; the default of 0 leaves the pixels untouched.
    """

    # Display name of the sample.
    label: str
    # Location of the sample file on disk.
    path: Path
    # Provenance shown alongside the sample: dataset name, link and summary.
    dataset_name: str
    dataset_url: str
    dataset_summary: str
    # Description of the input format shown for this sample.
    input_type_label: str
    # Extra explanatory text that is prepended to the loader's own note.
    display_note: str
    # Number of 90-degree rotations applied on load (0 disables rotation).
    rotate_k: int = 0
|
|
|
|
@dataclass(frozen=True)
class LoadedImage:
    """Immutable result of loading one image or volume from any source.

    The first six fields are always required; the remaining ones default to
    ``None`` and are only filled in when the source provides them.
    """

    # Pixel array of the slice/image that is displayed.
    image: np.ndarray
    # File name the data came from.
    source_name: str
    # Short tag for the loader that produced this result (e.g. "dicom").
    source_kind: str
    # Description of the input format.
    input_type_label: str
    # Free-text explanation of how the data was loaded or converted.
    note: str
    # Bit depth reported for (or inferred from) the source data.
    original_bit_depth: int
    # Full volume when the source is volumetric, otherwise None.
    volume: np.ndarray | None = None
    # Index of the slice taken from ``volume`` for ``image``, if any.
    default_slice_index: int | None = None
    # Dataset provenance; only set for built-in samples.
    dataset_name: str | None = None
    dataset_url: str | None = None
    dataset_summary: str | None = None
    # (low, high) intensity window proposed by the source, if any.
    suggested_window: tuple[float, float] | None = None
|
|
|
|
# Directory containing this module; the sample folders are resolved from it.
ROOT = Path(__file__).resolve().parents[0]
# Top-level folder holding sample images.
SAMPLE_IMAGES_ROOT = ROOT.joinpath("sample_images")
# Sub-folder with the samples bundled with the app.
BUILTIN_SAMPLES_ROOT = SAMPLE_IMAGES_ROOT.joinpath("builtin_samples")
|
|
|
|
def _ensure_non_empty_2d_image(image: np.ndarray, source_name: str) -> np.ndarray:
    """Validate that *image* is a non-empty 2D array and return it as float32.

    Raises:
        ValueError: if the array is not two-dimensional, or has no elements.
            The dimensionality check takes precedence over the emptiness check.
    """
    candidate = np.asarray(image)
    if candidate.ndim == 2 and candidate.size != 0:
        return candidate.astype(np.float32)
    if candidate.ndim != 2:
        raise ValueError(f"{source_name} did not produce a 2D grayscale image.")
    raise ValueError(f"{source_name} is empty.")
|
|
|
|
def _ensure_non_empty_volume(volume: np.ndarray, source_name: str) -> np.ndarray:
    """Validate that *volume* has at least three axes and data, as float32.

    Raises:
        ValueError: if the array has fewer than three dimensions, or if any
            axis has length zero.
    """
    candidate = np.asarray(volume)
    if candidate.ndim < 3:
        raise ValueError(f"{source_name} is not a valid 3D volume.")
    has_empty_axis = any(extent == 0 for extent in candidate.shape)
    if has_empty_axis or candidate.size == 0:
        raise ValueError(f"{source_name} contains no readable volume data.")
    return candidate.astype(np.float32)
|
|
|
|
def list_builtin_samples() -> dict[str, BuiltinSample]:
    """Return the catalogue of bundled samples, keyed by each sample's label.

    The mapping preserves the order in which the samples are declared below.
    """
    catalogue = (
        BuiltinSample(
            label="CT-RATE chest CT volume",
            path=BUILTIN_SAMPLES_ROOT / "ct_rate_valid_100_a_1.npz",
            dataset_name="CT-RATE",
            dataset_url="https://huggingface.co/datasets/ibrahimhamamci/CT-RATE",
            dataset_summary="Chest CT volumes paired with radiology reports, abnormality labels, and metadata.",
            input_type_label="Built-in NPZ volume export, displayed as an axial slice",
            display_note="CT-RATE volumes are rotated 90 degrees clockwise in this app so the slice orientation matches the usual viewing convention.",
            # k=-1 is one quarter turn clockwise for np.rot90.
            rotate_k=-1,
        ),
        BuiltinSample(
            label="LDCT-and-Projection-data full-dose chest CT",
            path=BUILTIN_SAMPLES_ROOT / "ldct_full_dose_1-135.npz",
            dataset_name="LDCT-and-Projection-data",
            dataset_url="https://www.cancerimagingarchive.net/collection/ldct-and-projection-data/",
            dataset_summary="Low-dose CT collection with DICOM image data, DICOM-CT-PD projection data, and supporting clinical reports.",
            input_type_label="Built-in NPZ export of a DICOM CT image slice",
            display_note="This built-in sample uses the full-dose reconstructed chest CT image series from the TCIA collection.",
        ),
        BuiltinSample(
            label="RSNA PE Detection CTPA slice",
            path=BUILTIN_SAMPLES_ROOT / "rsna_pe_7bf959bb5c7e.npz",
            dataset_name="RSNA Pulmonary Embolism Detection",
            dataset_url="https://registry.opendata.aws/rsna-pulmonary-embolism-detection/",
            dataset_summary="Annotated CT pulmonary angiography studies with DICOM images and CSV annotations for pulmonary embolism detection.",
            input_type_label="Built-in NPZ export of a DICOM CTPA slice",
            display_note="This built-in sample uses one chest CT pulmonary angiography slice from the RSNA PE example series test/0b2c88c2f00a/bec1b1d73f48.",
        ),
    )
    return {entry.label: entry for entry in catalogue}
|
|
|
|
def load_builtin_sample(label: str) -> LoadedImage:
    """Load the bundled sample registered under *label*.

    The sample file is read with ``load_local_path``; the result is then
    rotated by the sample's ``rotate_k`` quarter turns (when non-zero) and
    annotated with the sample's dataset metadata and display note.

    Raises:
        KeyError: if *label* is not a key of ``list_builtin_samples()``.
    """
    sample = list_builtin_samples()[label]
    loaded = load_local_path(sample.path)

    quarter_turns = sample.rotate_k
    if quarter_turns:
        image = np.rot90(loaded.image, k=quarter_turns)
        # Rotate every slice in-plane; axis 0 indexes the slices.
        volume = (
            None
            if loaded.volume is None
            else np.rot90(loaded.volume, k=quarter_turns, axes=(1, 2))
        )
    else:
        image, volume = loaded.image, loaded.volume

    return LoadedImage(
        image=image,
        source_name=sample.path.name,
        source_kind=loaded.source_kind,
        input_type_label=sample.input_type_label,
        note=f"{sample.display_note} {loaded.note}",
        original_bit_depth=loaded.original_bit_depth,
        volume=volume,
        default_slice_index=loaded.default_slice_index,
        dataset_name=sample.dataset_name,
        dataset_url=sample.dataset_url,
        dataset_summary=sample.dataset_summary,
        suggested_window=loaded.suggested_window,
    )
|
|
|
|
def load_uploaded_file(uploaded_file) -> LoadedImage:
    """Load an uploaded file, choosing the reader from its file extension.

    *uploaded_file* must expose ``getvalue()`` (the raw bytes) and ``name``.
    ``.npz``, ``.dcm`` and ``.nii``/``.nii.gz`` go to their dedicated loaders;
    anything else is opened with Pillow as a regular image.

    Raises:
        ValueError: if the upload is empty, is not a readable DICOM file, or
            cannot be opened as an image.
    """
    payload = uploaded_file.getvalue()
    if not payload:
        raise ValueError("The uploaded file is empty.")

    original_name = uploaded_file.name
    name_key = original_name.lower()

    if name_key.endswith(".npz"):
        return load_npz_bytes(payload, original_name)

    if name_key.endswith(".dcm"):
        try:
            dicom_parts = load_dicom_bytes(payload)
        except InvalidDicomError as exc:
            raise ValueError("This file is not a readable DICOM image.") from exc
        image, note, suggested_window, volume, default_slice_index, bit_depth = dicom_parts
        return LoadedImage(
            image=image,
            source_name=original_name,
            source_kind="dicom",
            input_type_label="Uploaded DICOM image (.dcm)",
            note=note,
            original_bit_depth=bit_depth,
            volume=volume,
            default_slice_index=default_slice_index,
            suggested_window=suggested_window,
        )

    if name_key.endswith((".nii", ".nii.gz")):
        image, note, volume, default_slice_index, bit_depth = load_nifti_bytes(payload, original_name)
        return LoadedImage(
            image=image,
            source_name=original_name,
            source_kind="nifti",
            input_type_label="Uploaded NIfTI volume (.nii or .nii.gz), displayed as an axial slice",
            note=note,
            original_bit_depth=bit_depth,
            volume=volume,
            default_slice_index=default_slice_index,
        )

    # Fallback: treat the upload as an ordinary raster image.
    try:
        pil_image = Image.open(BytesIO(payload))
    except Exception as exc:
        raise ValueError("Could not read this image file. Please upload PNG, JPG, DICOM, or NIfTI.") from exc
    image, note = pil_to_grayscale_array(pil_image)
    extension = Path(original_name).suffix.lower()
    return LoadedImage(
        image=image,
        source_name=original_name,
        source_kind="image",
        input_type_label=f"Uploaded image ({extension})",
        note=note,
        original_bit_depth=infer_bit_depth_from_array(image),
    )
|
|
|
|
def load_local_path(path: Path) -> LoadedImage:
    """Load an image or volume from disk, choosing the reader by extension.

    ``.npz``, ``.dcm`` and ``.nii``/``.nii.gz`` go to their dedicated loaders;
    any other file is opened with Pillow as a regular image.

    Raises:
        ValueError: if the path does not exist, is not a readable DICOM file,
            or cannot be opened as an image.
    """
    if not path.exists():
        raise ValueError(f"Built-in sample not found: {path.name}")

    file_name = path.name
    name_key = file_name.lower()

    if name_key.endswith(".npz"):
        return load_npz_path(path)

    if name_key.endswith(".dcm"):
        try:
            dicom_parts = load_dicom_path(path)
        except InvalidDicomError as exc:
            raise ValueError(f"Built-in DICOM sample is not readable: {path.name}") from exc
        image, note, suggested_window, volume, default_slice_index, bit_depth = dicom_parts
        return LoadedImage(
            image=image,
            source_name=file_name,
            source_kind="dicom",
            input_type_label="Local DICOM image (.dcm)",
            note=note,
            original_bit_depth=bit_depth,
            volume=volume,
            default_slice_index=default_slice_index,
            suggested_window=suggested_window,
        )

    if name_key.endswith((".nii", ".nii.gz")):
        image, note, volume, default_slice_index, bit_depth = load_nifti_path(path)
        return LoadedImage(
            image=image,
            source_name=file_name,
            source_kind="nifti",
            input_type_label="Local NIfTI volume (.nii or .nii.gz), displayed as an axial slice",
            note=note,
            original_bit_depth=bit_depth,
            volume=volume,
            default_slice_index=default_slice_index,
        )

    # Fallback: treat the file as an ordinary raster image.
    try:
        pil_image = Image.open(path)
    except Exception as exc:
        raise ValueError(f"Built-in image sample could not be opened: {path.name}") from exc
    image, note = pil_to_grayscale_array(pil_image)
    extension = path.suffix.lower()
    return LoadedImage(
        image=image,
        source_name=file_name,
        source_kind="image",
        input_type_label=f"Local image ({extension})",
        note=note,
        original_bit_depth=infer_bit_depth_from_array(image),
    )
|
|
|
|
def pil_to_grayscale_array(pil_image) -> tuple[np.ndarray, str]:
    """Convert a Pillow image (or array-like) to one grayscale channel.

    Args:
        pil_image: a ``PIL.Image.Image`` or anything ``np.asarray`` accepts.

    Returns:
        ``(gray, note)`` where ``gray`` is a 2D array and ``note`` describes
        the conversion that was applied.

    Raises:
        ValueError: if the pixel data cannot be decoded, is empty, or has a
            shape that is neither grayscale, grayscale+alpha, nor RGB(A).
    """
    try:
        # Palette ("P") pixels are colour-table indices and CMYK channels are
        # not RGB, so reading either one raw would produce meaningless
        # intensities. Expand them to RGB first.
        if getattr(pil_image, "mode", None) in ("P", "CMYK"):
            pil_image = pil_image.convert("RGB")
        # Pillow decodes lazily; truncated or corrupt files fail here.
        array = np.asarray(pil_image)
    except OSError as exc:
        raise ValueError("Could not decode the pixel data of this image file.") from exc

    if array.size == 0:
        raise ValueError("The uploaded image contains no pixel data.")

    if array.ndim == 2:
        return array, f"Loaded a single-channel grayscale image with dtype {array.dtype}."

    # Grayscale + alpha (Pillow mode "LA"): keep the luminance channel only.
    if array.ndim == 3 and array.shape[2] == 2:
        return array[..., 0], "Loaded a grayscale image with an alpha channel and discarded the alpha channel."

    if array.ndim == 3 and array.shape[2] >= 3:
        # Any channel beyond the first three (e.g. alpha) is ignored.
        rgb = array[..., :3]

        # Grayscale stored as RGB: all three channels are identical, so the
        # first one can be returned without any rounding.
        if np.array_equal(rgb[..., 0], rgb[..., 1]) and np.array_equal(rgb[..., 1], rgb[..., 2]):
            gray = rgb[..., 0]
            return gray, "Detected a 3-channel grayscale image and collapsed it to one channel."

        # Weighted luminance (0.299 R + 0.587 G + 0.114 B).
        rgb_float = rgb.astype(np.float32)
        gray_float = (
            0.299 * rgb_float[..., 0]
            + 0.587 * rgb_float[..., 1]
            + 0.114 * rgb_float[..., 2]
        )

        if rgb.dtype == np.uint8:
            # Keep 8-bit inputs 8-bit: round and clamp back into uint8.
            gray = np.clip(np.rint(gray_float), 0, 255).astype(np.uint8)
            return gray, "Converted an 8-bit RGB image to one 8-bit luminance grayscale channel."

        gray = gray_float.astype(np.float32)
        return gray, f"Converted an RGB image with dtype {rgb.dtype} to one grayscale channel."

    raise ValueError("Unsupported image shape. Please upload a grayscale or RGB image.")
|
|
def infer_bit_depth_from_array(array: np.ndarray) -> int:
    """Estimate the bit depth of *array* from its dtype.

    Integer dtypes report their storage width in bits, capped at 16; every
    other dtype (floats, booleans, ...) reports 8.
    """
    element_type = np.asarray(array).dtype
    if not np.issubdtype(element_type, np.integer):
        return 8
    storage_bits = int(element_type.itemsize * 8)
    return max(1, min(storage_bits, 16))
|
|
|
|
def _parse_npz_payload(payload: dict[str, np.ndarray], source_name: str) -> LoadedImage:
    """Build a ``LoadedImage`` from the arrays stored in an NPZ archive.

    A ``volume`` entry takes precedence over ``image``. For a volume the
    displayed slice is ``default_slice_index`` (middle slice when absent),
    clamped into the valid range. Optional entries ``original_bit_depth``,
    ``suggested_window``, ``note`` and ``source_kind`` override the defaults.

    Raises:
        ValueError: if neither ``image`` nor ``volume`` is present, or the
            stored array fails validation.
    """

    def first_entry(key: str, fallback):
        # Stored scalars may come back as 0-d or 1-d arrays; take element 0.
        return np.atleast_1d(payload.get(key, [fallback]))[0]

    raw_image = payload.get("image")
    raw_volume = payload.get("volume")
    if raw_volume is None and raw_image is None:
        raise ValueError(f"{source_name} NPZ file does not contain 'image' or 'volume'.")

    if raw_volume is not None:
        volume = _ensure_non_empty_volume(raw_volume, source_name)
        slice_count = volume.shape[0]
        requested_index = int(first_entry("default_slice_index", slice_count // 2))
        default_slice_index = max(0, min(requested_index, slice_count - 1))
        image = _ensure_non_empty_2d_image(volume[default_slice_index], source_name)
    else:
        volume = None
        default_slice_index = None
        image = _ensure_non_empty_2d_image(raw_image, source_name)

    bit_depth = int(first_entry("original_bit_depth", infer_bit_depth_from_array(image)))

    suggested_window = None
    if "suggested_window" in payload:
        bounds = np.asarray(payload["suggested_window"], dtype=np.float32).ravel()
        if bounds.size >= 2:
            suggested_window = (float(bounds[0]), float(bounds[1]))

    return LoadedImage(
        image=image,
        source_name=source_name,
        source_kind=str(first_entry("source_kind", "npz")),
        input_type_label="NPZ image/volume",
        note=str(first_entry("note", "Loaded a repository-safe NPZ sample.")),
        original_bit_depth=bit_depth,
        volume=volume,
        default_slice_index=default_slice_index,
        suggested_window=suggested_window,
    )
|
|
|
|
def load_npz_bytes(data: bytes, filename: str) -> LoadedImage:
    """Load an NPZ archive held in memory and convert it to a ``LoadedImage``.

    Args:
        data: raw bytes of the ``.npz`` file.
        filename: name used in error messages and as the source name.

    Raises:
        ValueError: if the bytes cannot be read as an NPZ archive, or the
            archive content is rejected by ``_parse_npz_payload``.
    """
    # SECURITY NOTE(review): allow_pickle=True lets a crafted archive run
    # arbitrary code while object arrays are unpickled, and this function
    # receives uploaded (untrusted) bytes. Kept as-is so existing archives
    # with object arrays still load; switch to allow_pickle=False once it is
    # confirmed that no supported file depends on pickled entries.
    try:
        # The context manager closes the archive deterministically instead
        # of leaving that to garbage collection; dict() reads every array
        # into memory before the archive is closed.
        with np.load(BytesIO(data), allow_pickle=True) as archive:
            payload = dict(archive)
    except Exception as exc:
        raise ValueError(f"Could not read NPZ file: {filename}") from exc
    return _parse_npz_payload(payload, filename)
|
|
|
|
def load_npz_path(path: Path) -> LoadedImage:
    """Load an NPZ archive from disk and convert it to a ``LoadedImage``.

    Raises:
        ValueError: if the file cannot be read as an NPZ archive, or the
            archive content is rejected by ``_parse_npz_payload``.
    """
    # NOTE(review): allow_pickle=True executes pickled object arrays; only
    # point this loader at trusted files.
    try:
        # Close the underlying file handle as soon as the arrays are read
        # rather than leaving it open until the archive object is collected.
        with np.load(path, allow_pickle=True) as archive:
            payload = dict(archive)
    except Exception as exc:
        raise ValueError(f"Could not read NPZ file: {path.name}") from exc
    return _parse_npz_payload(payload, path.name)
|
|
|
|
def load_dicom_bytes(data: bytes) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]:
    """Parse in-memory DICOM bytes and extract pixel data via ``load_dicom_dataset``.

    Returns the same six-element tuple as ``load_dicom_dataset``.
    """
    return load_dicom_dataset(pydicom.dcmread(BytesIO(data)))
|
|
|
|
def load_dicom_path(path: Path) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]:
    """Read a DICOM file from disk and extract pixel data via ``load_dicom_dataset``.

    Returns the same six-element tuple as ``load_dicom_dataset``.
    """
    return load_dicom_dataset(pydicom.dcmread(str(path)))
|
|
|
|
def load_dicom_dataset(dataset) -> tuple[np.ndarray, str, tuple[float, float] | None, np.ndarray | None, int | None, int]:
    """Extract rescaled pixel data and display hints from a DICOM dataset.

    Pixel arrays with more than two dimensions are treated as a volume whose
    middle entry along axis 0 becomes the displayed image. ``RescaleSlope``
    and ``RescaleIntercept`` (defaults 1 and 0) are applied to both.

    Returns:
        ``(pixels, note, suggested_window, volume, default_slice_index,
        bit_depth)``; ``volume`` and ``default_slice_index`` are ``None`` for
        plain 2D data.

    Raises:
        ValueError: if the dataset has no ``PixelData``, the pixel data is
            empty, or the selected image is not two-dimensional.
    """
    if not hasattr(dataset, "PixelData"):
        raise ValueError("DICOM file has no readable pixel data.")

    stored = dataset.pixel_array
    if np.asarray(stored).size == 0:
        raise ValueError("DICOM file contains empty pixel data.")

    # Prefer the declared BitsStored; fall back to the array's storage width.
    bit_depth = int(getattr(dataset, "BitsStored", stored.dtype.itemsize * 8))
    frames = stored.astype(np.float32)

    if frames.ndim > 2:
        volume = _ensure_non_empty_volume(frames, "DICOM volume")
        default_slice_index = frames.shape[0] // 2
        frame = volume[default_slice_index]
    else:
        volume = None
        default_slice_index = None
        frame = frames
    frame = _ensure_non_empty_2d_image(frame, "DICOM image")

    slope = float(getattr(dataset, "RescaleSlope", 1.0))
    intercept = float(getattr(dataset, "RescaleIntercept", 0.0))
    pixels = frame * slope + intercept
    if volume is not None:
        volume = volume * slope + intercept

    suggested_window = extract_dicom_window(dataset)
    note = (
        f"Loaded DICOM pixel data with rescale slope {slope:g} and intercept {intercept:g}. "
        "Windowing sliders operate in the uploaded image's native intensity range."
    )
    return pixels, note, suggested_window, volume, default_slice_index, bit_depth
|
|
|
|
def extract_dicom_window(dataset) -> tuple[float, float] | None:
    """Return the ``(low, high)`` display window stored in a DICOM dataset.

    The bounds are ``WindowCenter -/+ WindowWidth / 2``. When either attribute
    holds several values only the first is used. Returns ``None`` when either
    attribute is missing.
    """
    center = getattr(dataset, "WindowCenter", None)
    width = getattr(dataset, "WindowWidth", None)
    if center is None or width is None:
        return None

    def first_as_float(entry) -> float:
        # Multi-valued elements contribute their first value only.
        if isinstance(entry, pydicom.multival.MultiValue):
            entry = entry[0]
        return float(entry)

    center_value = first_as_float(center)
    width_value = first_as_float(width)
    low = center_value - width_value / 2.0
    high = center_value + width_value / 2.0
    return float(low), float(high)
|
|
|
|
def load_nifti_bytes(data: bytes, filename: str) -> tuple[np.ndarray, str, np.ndarray, int, int]:
    """Load a NIfTI volume from in-memory bytes via a temporary file.

    The bytes are written to a file on disk and read with ``load_nifti_path``.

    Args:
        data: raw bytes of the ``.nii`` or ``.nii.gz`` file.
        filename: original file name; only its extension is used, to pick the
            suffix of the temporary file.

    Returns:
        The same five-element tuple as ``load_nifti_path``.
    """
    suffix = ".nii.gz" if filename.lower().endswith(".nii.gz") else ".nii"
    # A NamedTemporaryFile cannot be reopened by name while it is still open
    # on Windows, so write a regular file inside a temporary directory and
    # hand that path to the reader instead.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir) / f"upload{suffix}"
        tmp_path.write_bytes(data)
        return load_nifti_path(tmp_path)
|
|
|
|
def load_nifti_path(path: Path) -> tuple[np.ndarray, str, np.ndarray, int, int]:
    """Load a NIfTI volume from disk and pick its middle slice for display.

    The last array axis is moved to the front so that slices are indexed
    along axis 0, and the middle slice is returned as the default image.

    Returns:
        ``(slice_2d, note, volume, middle_index, bit_depth)`` where
        ``bit_depth`` is the storage width, in bits, of the on-disk data type.

    Raises:
        ValueError: if the file cannot be opened or its voxel data cannot be
            read, if the data is not a non-empty volume, or if the selected
            slice is not two-dimensional.
    """
    try:
        image = nib.load(str(path))
        data_dtype = image.get_data_dtype()
        # The voxel data is only read here, so truncated or corrupt payloads
        # fail on this line rather than in nib.load(); keep it inside the
        # same error wrapper so callers always see a ValueError.
        volume = np.asanyarray(image.dataobj, dtype=np.float32)
    except Exception as exc:
        raise ValueError(f"Could not read NIfTI volume: {path.name}") from exc

    bit_depth = int(getattr(data_dtype, "itemsize", 1) * 8)
    volume = _ensure_non_empty_volume(volume, path.name)
    # Put the last axis first so volume[i] is the i-th slice.
    volume = np.moveaxis(volume, -1, 0)
    middle_index = volume.shape[0] // 2
    slice_2d = _ensure_non_empty_2d_image(volume[middle_index], path.name)
    note = f"Loaded a 3D NIfTI volume and selected the default middle axial slice (index {middle_index})."
    return slice_2d, note, volume, middle_index, bit_depth
|
|
|
|
def create_histogram_figure(
    image_top: np.ndarray,
    image_bottom: np.ndarray,
    top_label: str,
    bottom_label: str,
    top_bounds: tuple[float, float],
    bottom_bounds: tuple[float, float],
) -> plt.Figure:
    """Plot two stacked intensity histograms that share the same bins.

    Both images are binned with 128 equal-width bins spanning their combined
    value range, and each panel marks its ``(low, high)`` bounds with dashed
    vertical lines.
    """
    flat_top = np.asarray(image_top).ravel()
    flat_bottom = np.asarray(image_bottom).ravel()
    combined = np.concatenate([flat_top, flat_bottom])
    # 129 edges -> 128 bins shared by both panels.
    bin_edges = np.linspace(float(combined.min()), float(combined.max()), 129)

    fig, axes = plt.subplots(2, 1, figsize=(6, 5.0), sharex=True)
    rows = (
        (flat_top, top_label, top_bounds),
        (flat_bottom, bottom_label, bottom_bounds),
    )
    for ax, (values, label, (low, high)) in zip(axes, rows):
        ax.hist(values, bins=bin_edges, color="#4C6A92", alpha=0.9)
        ax.axvline(low, color="#B85450", linestyle="--", linewidth=1.4, label="Lower bound")
        ax.axvline(high, color="#B85450", linestyle="--", linewidth=1.4, label="Upper bound")
        ax.set_ylabel("Count")
        ax.set_title(label, fontsize=10, loc="left")
        ax.legend(loc="upper right", fontsize=8)

    axes[-1].set_xlabel("Intensity/HU Value")
    fig.suptitle("Intensity Histogram Comparison", fontsize=12)
    fig.tight_layout()
    return fig
|
|
|
|
def create_mse_figure(reference: np.ndarray, test: np.ndarray) -> plt.Figure:
    """Render the per-pixel squared error between *test* and *reference*.

    The gray colour scale runs from 0 to the 99th percentile of the error
    (never below 1e-6).
    """
    reference_f32 = np.asarray(reference, dtype=np.float32)
    test_f32 = np.asarray(test, dtype=np.float32)
    squared_error = np.square(test_f32 - reference_f32)
    # Floor the upper limit so an all-zero error map still has a valid range.
    upper_limit = max(float(np.percentile(squared_error, 99)), 1e-6)

    fig, ax = plt.subplots(figsize=(4.5, 4.5))
    error_image = ax.imshow(squared_error, cmap="gray", vmin=0.0, vmax=upper_limit)
    ax.set_title("MSE Map")
    ax.axis("off")
    fig.colorbar(error_image, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig
|
|
|
|
def create_visualization_figure(
    original_display: np.ndarray,
    windowed_display: np.ndarray,
    quantized_display: np.ndarray,
    quantized_windowed_display: np.ndarray,
    bit_depth: int,
    low: float,
    high: float,
    metric_mode: str,
) -> plt.Figure:
    """Draw the four display variants in a 2x2 grid with comparison lines.

    Each panel is shown in gray with a fixed 0..1 intensity range and is
    captioned with its quantization and window settings. A line joins the
    centres of every pair of panels; the pair whose name equals
    *metric_mode* is drawn thicker and in red.

    Args:
        original_display, windowed_display, quantized_display,
        quantized_windowed_display: images for the top-left, top-right,
            bottom-left and bottom-right panels respectively.
        bit_depth: quantization bit depth shown in the captions.
        low, high: window bounds shown in the captions.
        metric_mode: name of the comparison to highlight, e.g.
            ``"Original vs Windowed"``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    ax_tl, ax_tr = axes[0, 0], axes[0, 1]
    ax_bl, ax_br = axes[1, 0], axes[1, 1]

    panels = [
        (ax_tl, original_display, "Original Display", "Quantization: None", "Window: None"),
        (ax_tr, windowed_display, "Windowed Display", "Quantization: None", f"Window: [{low:.1f}, {high:.1f}]"),
        (ax_bl, quantized_display, "Quantized Display", f"Quantization: bit depth = {bit_depth}", "Window: None"),
        (ax_br, quantized_windowed_display, "Quantized + Windowed Display", f"Quantization: bit depth = {bit_depth}", f"Window: [{low:.1f}, {high:.1f}]"),
    ]

    for ax, image, title, line1, line2 in panels:
        ax.imshow(image, cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(title, fontsize=11, pad=8)
        ax.axis("off")
        # Captions sit just below the panel, in axes coordinates.
        ax.text(0.0, -0.08, line1, transform=ax.transAxes, ha="left", va="top", fontsize=9)
        ax.text(0.0, -0.16, line2, transform=ax.transAxes, ha="left", va="top", fontsize=9)

    # One connector per comparison pair (the four sides and two diagonals).
    line_specs = [
        ("Original vs Windowed", ax_tl, ax_tr),
        ("Original vs Quantized", ax_tl, ax_bl),
        ("Windowed vs Quantized + Windowed", ax_tr, ax_br),
        ("Quantized vs Quantized + Windowed", ax_bl, ax_br),
        ("Original vs Quantized + Windowed", ax_tl, ax_br),
        ("Windowed vs Quantized", ax_tr, ax_bl),
    ]

    for name, ax_a, ax_b in line_specs:
        is_active = name == metric_mode
        connection = ConnectionPatch(
            xyA=(0.5, 0.5),
            coordsA=ax_a.transAxes,
            xyB=(0.5, 0.5),
            coordsB=ax_b.transAxes,
            color="#C4473A" if is_active else "#AAB7C4",
            linewidth=4.0 if is_active else 1.8,
            zorder=0,
            alpha=0.95 if is_active else 0.8,
        )
        fig.add_artist(connection)

    fig.tight_layout(pad=2.0)
    return fig
|
|
|
|
def _to_base64_grayscale_png(image: np.ndarray) -> str:
    """Encode *image* as an 8-bit grayscale PNG and return it as base64 text.

    Values are multiplied by 255, clipped to 0..255 and cast to ``uint8``
    before encoding.
    """
    scaled = np.asarray(image) * 255.0
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    png_buffer = BytesIO()
    Image.fromarray(pixels, mode="L").save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("ascii")
|
|
|
|
def create_visualization_svg_html(
    original_display: np.ndarray,
    windowed_display: np.ndarray,
    quantized_display: np.ndarray,
    quantized_windowed_display: np.ndarray,
    bit_depth: int,
    low: float,
    high: float,
    metric_mode: str,
) -> str:
    """Build an HTML snippet with an inline SVG of the four display variants.

    The four images are embedded as base64 PNGs in a 2x2 layout inside a
    1000x972 viewBox, each with a title, a quantization caption and a window
    caption. Six double-headed arrows connect every pair of panels; the arrow
    whose name equals *metric_mode* is drawn in the active colour.

    Args:
        original_display, windowed_display, quantized_display,
        quantized_windowed_display: images for the top-left, top-right,
            bottom-left and bottom-right panels respectively.
        bit_depth: quantization bit depth shown in the captions.
        low, high: window bounds shown in the captions.
        metric_mode: name of the comparison to highlight, e.g.
            ``"Original vs Windowed"``.

    Returns:
        The HTML markup (a ``<div>`` wrapping the ``<svg>``) as a string.
    """
    images = {
        "original": _to_base64_grayscale_png(original_display),
        "windowed": _to_base64_grayscale_png(windowed_display),
        "quantized": _to_base64_grayscale_png(quantized_display),
        "quantized_windowed": _to_base64_grayscale_png(quantized_windowed_display),
    }

    # Top-left corner (in viewBox units) and caption text of every panel.
    panels = {
        "original": {"x": 70, "y": 40, "title": "Original Display", "q": "Quantization: None", "w": "Window: None"},
        "windowed": {"x": 560, "y": 40, "title": "Windowed Display", "q": "Quantization: None", "w": f"Window: [{low:.1f}, {high:.1f}]"},
        "quantized": {"x": 70, "y": 510, "title": "Quantized Display", "q": f"Quantization: bit depth = {bit_depth}", "w": "Window: None"},
        "quantized_windowed": {"x": 560, "y": 510, "title": "Quantized + Windowed Display", "q": f"Quantization: bit depth = {bit_depth}", "w": f"Window: [{low:.1f}, {high:.1f}]"},
    }
    panel_w = 360
    panel_h = 360
    # Gap kept between a panel's edge and the arrow end points.
    arrow_margin = 18

    # (comparison name, source panel key, destination panel key)
    comparisons = [
        ("Original vs Windowed", "original", "windowed"),
        ("Original vs Quantized", "original", "quantized"),
        ("Windowed vs Quantized + Windowed", "windowed", "quantized_windowed"),
        ("Quantized vs Quantized + Windowed", "quantized", "quantized_windowed"),
        ("Original vs Quantized + Windowed", "original", "quantized_windowed"),
        ("Windowed vs Quantized", "windowed", "quantized"),
    ]

    def center(panel_key: str) -> tuple[float, float]:
        """Return the centre point of the image area of a panel."""
        p = panels[panel_key]
        return p["x"] + panel_w / 2.0, p["y"] + panel_h / 2.0

    def edge_points(src: str, dst: str) -> tuple[float, float, float, float]:
        """Return ``(x1, y1, x2, y2)`` for the arrow from *src* to *dst*.

        Both end points lie ``arrow_margin`` outside the panels' image areas.
        """
        src_panel = panels[src]
        dst_panel = panels[dst]
        src_cx, src_cy = center(src)
        dst_cx, dst_cy = center(dst)

        # Same row: horizontal arrow between the facing vertical edges.
        if abs(src_cy - dst_cy) < 1e-6:
            if src_cx < dst_cx:
                return src_panel["x"] + panel_w + arrow_margin, src_cy, dst_panel["x"] - arrow_margin, dst_cy
            return src_panel["x"] - arrow_margin, src_cy, dst_panel["x"] + panel_w + arrow_margin, dst_cy

        # Same column: vertical arrow between the facing horizontal edges.
        if abs(src_cx - dst_cx) < 1e-6:
            if src_cy < dst_cy:
                return src_cx, src_panel["y"] + panel_h + arrow_margin, dst_cx, dst_panel["y"] - arrow_margin
            return src_cx, src_panel["y"] - arrow_margin, dst_cx, dst_panel["y"] + panel_h + arrow_margin

        # Diagonal pairs: connect the two facing corners, one case per
        # direction (down-right, down-left, up-right, and finally up-left).
        if src_cx < dst_cx and src_cy < dst_cy:
            return (
                src_panel["x"] + panel_w + arrow_margin,
                src_panel["y"] + panel_h + arrow_margin,
                dst_panel["x"] - arrow_margin,
                dst_panel["y"] - arrow_margin,
            )
        if src_cx > dst_cx and src_cy < dst_cy:
            return (
                src_panel["x"] - arrow_margin,
                src_panel["y"] + panel_h + arrow_margin,
                dst_panel["x"] + panel_w + arrow_margin,
                dst_panel["y"] - arrow_margin,
            )
        if src_cx < dst_cx and src_cy > dst_cy:
            return (
                src_panel["x"] + panel_w + arrow_margin,
                src_panel["y"] - arrow_margin,
                dst_panel["x"] - arrow_margin,
                dst_panel["y"] + panel_h + arrow_margin,
            )
        return (
            src_panel["x"] - arrow_margin,
            src_panel["y"] - arrow_margin,
            dst_panel["x"] + panel_w + arrow_margin,
            dst_panel["y"] + panel_h + arrow_margin,
        )

    def arrow_group(name: str, src: str, dst: str) -> str:
        """Return the SVG ``<g>`` markup for one comparison arrow."""
        x1, y1, x2, y2 = edge_points(src, dst)
        active = name == metric_mode
        stroke = "#C4473A" if active else "#B4BEC8"
        # The stroke width is the same for active and inactive arrows; only
        # colour, opacity and arrowhead marker differ.
        width = "4"
        opacity = "1.0" if active else "0.85"
        marker_id = "arrowhead-active" if active else "arrowhead-inactive"
        # The first <line> is 18 units wide and almost fully transparent;
        # presumably it enlarges the clickable area (TODO confirm).
        # NOTE(review): data-compare carries the comparison name, presumably
        # for a client-side handler that is not part of this function.
        return f"""
<g data-compare="{name}" style="cursor:pointer;">
<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{stroke}" stroke-width="18" stroke-opacity="0.01"/>
<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{stroke}" stroke-width="{width}" stroke-linecap="round" marker-start="url(#{marker_id})" marker-end="url(#{marker_id})" opacity="{opacity}"/>
</g>
"""

    # One card per panel: background rectangle, embedded PNG, three captions.
    panel_svg_parts = []
    for key, panel in panels.items():
        panel_svg_parts.append(
            f"""
<rect x="{panel['x']-8}" y="{panel['y']-12}" width="{panel_w+16}" height="{panel_h+114}" rx="18" fill="#FFFFFF" stroke="#D7E0EA" stroke-width="2"/>
<image x="{panel['x']}" y="{panel['y']}" width="{panel_w}" height="{panel_h}" href="data:image/png;base64,{images[key]}" preserveAspectRatio="xMidYMid meet"/>
<text x="{panel['x']}" y="{panel['y'] + panel_h + 28}" font-size="17" font-weight="600" fill="#22313F">{panel['title']}</text>
<text x="{panel['x']}" y="{panel['y'] + panel_h + 54}" font-size="17" fill="#22313F">{panel['q']}</text>
<text x="{panel['x']}" y="{panel['y'] + panel_h + 80}" font-size="17" fill="#22313F">{panel['w']}</text>
"""
        )

    arrow_svg_parts = [arrow_group(name, src, dst) for name, src, dst in comparisons]

    # Panels are emitted before the arrows, so the arrows are painted on top.
    return f"""
<div style="
width: 100%;
max-width: 1000px;
aspect-ratio: 1000 / 972;
margin: 0 auto;
overflow: hidden;
line-height: 0;
">
<svg
viewBox="0 0 1000 972"
style="
width: 100%;
height: 100%;
display: block;
overflow: hidden;
shape-rendering: geometricPrecision;
text-rendering: geometricPrecision;
"
xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink"
>
<defs>
<marker id="arrowhead-active" markerWidth="12" markerHeight="12" refX="6" refY="6" orient="auto-start-reverse">
<path d="M 0 0 L 12 6 L 0 12 z" fill="#C4473A"/>
</marker>
<marker id="arrowhead-inactive" markerWidth="12" markerHeight="12" refX="6" refY="6" orient="auto-start-reverse">
<path d="M 0 0 L 12 6 L 0 12 z" fill="#B4BEC8"/>
</marker>
</defs>
{''.join(panel_svg_parts)}
{''.join(arrow_svg_parts)}
</svg>
</div>
"""
|
|