#!/usr/bin/env python3
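"""Single-pair full inference pipeline (Step1 + Step2 + Step3).

Runs vessel segmentation (Step1), OD/fovea detection and cropping (Step2), and
registration with polynomial fitting (Step3) for one OCTA / wfCFP image pair,
then collects keypoints, overlays, and a JSON summary under the output directory.
"""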
from __future__ import annotations

import argparse
import datetime as dt
import json
import shutil
import subprocess
import sys
import threading
import traceback
from pathlib import Path

import cv2
import numpy as np
import torch

THIS_DIR = Path(__file__).resolve().parent
RELEASE_ROOT = THIS_DIR.parent
if str(RELEASE_ROOT) not in sys.path:
    sys.path.insert(0, str(RELEASE_ROOT))

from hf_weight_manager import resolve_weight_paths

STEP1_DIR = RELEASE_ROOT / "Step1_VesselSeg"
STEP2_DIR = RELEASE_ROOT / "Step2_DetectCrop"
STEP3_DIR = RELEASE_ROOT / "Step3_Reg"
DEFAULT_STEP3_CONFIG = STEP3_DIR / "Src" / "config" / "test.yaml"
DEFAULT_STEP3_MODEL = STEP3_DIR / "Src" / "save" / "crop_vseg_vessel_111.pth"

DEFAULT_RETRY_FILTERED_INLIER_THRESHOLD = 20
FEATURE_NAMES = ["1", "x", "y", "x^2", "x*y", "y^2"]
DETECTION_SIZE = 800.0
COORD_SCALE = 1000.0
INVERTED_POLARITY_ROUTES = {"fallback_invert_no_crop"}
SEG_MODEL_CARE = "CARe-VesselSeg"
SEG_MODEL_UNET_DCP = "Broad domain retinal vessel segmentation [ICASSP 2025]"
DEFAULT_DEVICE = "auto"

_STEP2_API = None
_STEP2_RUNTIME_CACHE: dict[str, dict[str, object]] = {}
_RUNTIME_LOCK = threading.Lock()


def get_weight_paths() -> dict[str, Path]:
    return resolve_weight_paths(RELEASE_ROOT)


def resolve_device(device: str) -> str:
    normalized = str(device).strip().lower()
    if normalized == "auto":
        return "cuda:0" if torch.cuda.is_available() else "cpu"
    return str(device)


def ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def run_cmd(cmd: list[str], log_path: Path) -> None:
    ensure_dir(log_path.parent)
    with log_path.open("w", encoding="utf-8") as log_file:
        process = subprocess.run(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
            check=False,
        )
    if process.returncode != 0:
        raise RuntimeError(f"Command failed ({process.returncode}), see log: {log_path}")


def infer_modality_from_path(image_path: Path) -> str:
    stem = image_path.stem.upper()
    if "OCTA" in stem or stem.endswith("_001"):
        return "OCTA"
    if "CFP" in stem or "FP" in stem or "FUNDUS" in stem:
        return "CFP"
    return "CFP"


def resolve_modalities(
    query_modality_arg: str,
    refer_modality_arg: str,
    query_image_path: Path,
    refer_image_path: Path,
) -> tuple[str, str]:
    query_modality = query_modality_arg.upper()
    refer_modality = refer_modality_arg.upper()
    if query_modality == "AUTO":
        query_modality = infer_modality_from_path(query_image_path)
    if refer_modality == "AUTO":
        refer_modality = infer_modality_from_path(refer_image_path)
    return query_modality, refer_modality


def run_step1_broad_with_explicit_modalities(
    *,
    query_image: Path,
    refer_image: Path,
    query_stem: str,
    refer_stem: str,
    query_modality: str,
    refer_modality: str,
    device: str,
    model_dir: Path | None,
    output_dir: Path,
    log_path: Path,
) -> tuple[Path, Path]:
    project_root = STEP1_DIR / "ThirdParty" / "VesselSegProject_BroadDomain"
    if not project_root.is_dir():
        raise FileNotFoundError(f"BroadDomain project directory not found: {project_root}")
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    from BroadDomainRetinalVesselSeg.UNet_DCP import inference_unet_dcp
    from Code.VesselSeg.broad_domain_vessel_seg import BroadDomainVesselSegmenter

    resolved_device = resolve_device(device)
    ensure_dir(log_path.parent)
    vessel_dir = ensure_dir(output_dir / "vessels")
    if hasattr(inference_unet_dcp, "set_runtime_device"):
        runtime_device = inference_unet_dcp.set_runtime_device(resolved_device)
    else:
        runtime_device = torch.device(resolved_device)
        if hasattr(torch.backends, "mkldnn"):
            torch.backends.mkldnn.enabled = runtime_device.type != "cpu"
        inference_unet_dcp.device = runtime_device
    segmenter = BroadDomainVesselSegmenter(
        model_dir=None if model_dir is None else str(model_dir),
        split_block=(4, 4),
    )
    query_bgr = read_as_bgr(query_image)
    refer_bgr = read_as_bgr(refer_image)
    query_vessel = segmenter.segment(query_bgr, modality=query_modality)
    refer_vessel = segmenter.segment(refer_bgr, modality=refer_modality)
    query_vessel_path = vessel_dir / f"{query_stem}_vessel.jpg"
    refer_vessel_path = vessel_dir / f"{refer_stem}_vessel.jpg"
    if not cv2.imwrite(str(query_vessel_path), query_vessel):
        raise RuntimeError(f"Failed to write output: {query_vessel_path}")
    if not cv2.imwrite(str(refer_vessel_path), refer_vessel):
        raise RuntimeError(f"Failed to write output: {refer_vessel_path}")
    summary = {
        "backend": "BroadDomainVesselSegmenter(explicit_modality_single_pair)",
        "query_image": str(query_image),
        "refer_image": str(refer_image),
        "query_modality": query_modality,
        "refer_modality": refer_modality,
        "device": str(runtime_device),
        "query_vessel_path": str(query_vessel_path),
        "refer_vessel_path": str(refer_vessel_path),
    }
    ensure_dir(output_dir)
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    with log_path.open("w", encoding="utf-8") as f:
        f.write(json.dumps(summary, indent=2, ensure_ascii=False))
        f.write("\n")
    return query_vessel_path, refer_vessel_path


def read_as_bgr(path: Path) -> np.ndarray:
    image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Unable to read image: {path}")
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image


def normalize_to_png(src: Path, dst: Path) -> Path:
    image = read_as_bgr(src)
    ensure_dir(dst.parent)
    if not cv2.imwrite(str(dst), image):
        raise RuntimeError(f"Failed to write output: {dst}")
    return dst


def resolve_vessel_path(vessel_dir: Path, stem: str) -> Path:
    candidates = [
        f"{stem}_vessel.jpg",
        f"{stem}_vessel.png",
        f"{stem}_vessel.jpeg",
        f"{stem}.jpg",
        f"{stem}.png",
        f"{stem}.jpeg",
    ]
    for name in candidates:
        path = vessel_dir / name
        if path.exists():
            return path
    raise FileNotFoundError(f"Vessel segmentation image not found: {vessel_dir} (stem={stem})")


def compute_wfcfp_only_crop_box_1000(refer_txt_path: Path) -> np.ndarray:
    points = np.asarray(np.loadtxt(refer_txt_path), dtype=np.float64)
    points = np.atleast_2d(points)
    if points.shape[0] < 2 or points.shape[1] < 2:
        raise ValueError(f"Invalid wfCFP coordinate file format: {refer_txt_path}")
    od_y_norm, od_x_norm = float(points[0, 0]), float(points[0, 1])
    fovea_y_norm, fovea_x_norm = float(points[1, 0]), float(points[1, 1])
    od_x_800 = od_x_norm * DETECTION_SIZE
    fovea_x_800 = fovea_x_norm * DETECTION_SIZE
    fovea_y_800 = fovea_y_norm * DETECTION_SIZE
    width = abs(fovea_x_800 - od_x_800)
    if width <= 1e-6:
        raise ValueError(f"Cannot build wfCFP-only crop box, |fovea_x-od_x| is too small: {refer_txt_path}")
    rate = COORD_SCALE / DETECTION_SIZE
    crop_box = np.array(
        [
            [int((fovea_y_800 - width) * rate), int((fovea_x_800 - width) * rate)],
            [int((fovea_y_800 + width) * rate), int((fovea_x_800 + width) * rate)],
        ],
        dtype=np.int32,
    )
    crop_box = np.clip(crop_box, 0, int(COORD_SCALE))
    if crop_box[1, 0] <= crop_box[0, 0] or crop_box[1, 1] <= crop_box[0, 1]:
        raise ValueError(f"Invalid wfCFP-only crop box: {crop_box.tolist()}")
    return crop_box
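

# Worked example for compute_wfcfp_only_crop_box_1000 (illustrative numbers only):
# with the OD at normalized (y=0.50, x=0.30) and the fovea at (y=0.52, x=0.55),
# the 800-grid values are od_x=240, fovea_x=440, fovea_y=416, so width=200 and
# rate=1000/800=1.25, giving a crop box of [[270, 300], [770, 800]] in the
# 1000-coordinate space (rows are [y_min, x_min] and [y_max, x_max]).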


def load_step2_api():
    global _STEP2_API
    if _STEP2_API is not None:
        return _STEP2_API
    if str(STEP2_DIR) not in sys.path:
        sys.path.insert(0, str(STEP2_DIR))
    import run_detect_crop as step2_mod

    weight_paths = get_weight_paths()
    _STEP2_API = {
        "module": step2_mod,
        "infer_device": step2_mod.infer_device,
        "build_detector": step2_mod.build_detector,
        "run_wfcfp_detection_and_crop": step2_mod.run_wfcfp_detection_and_crop,
        "run_octa_detection": step2_mod.run_octa_detection,
        "defaults": {
            "crop_output_size": step2_mod.DEFAULT_CROP_OUTPUT_SIZE,
            "nms_threshold": step2_mod.DEFAULT_NMS_THRESHOLD,
            "max_detection_attempts": step2_mod.DEFAULT_MAX_DETECTION_ATTEMPTS,
            "max_rgb_fallback_attempts": step2_mod.DEFAULT_MAX_RGB_FALLBACK_ATTEMPTS,
            "max_octa_detection_attempts": step2_mod.DEFAULT_MAX_OCTA_DETECTION_ATTEMPTS,
            "max_octa_plain_fallback_attempts": step2_mod.DEFAULT_MAX_OCTA_PLAIN_FALLBACK_ATTEMPTS,
            "octa_fovea_center_max_offset_norm": step2_mod.DEFAULT_OCTA_FOVEA_CENTER_MAX_OFFSET_NORM,
        },
        "weights": {
            "wfcfp_fused": weight_paths["step2_wfcfp_fused"].resolve(),
            "wfcfp_rgb_fallback": weight_paths["step2_wfcfp_rgb_fallback"].resolve(),
            "octa_fused": weight_paths["step2_octa_fused"].resolve(),
            "octa_plain_fallback": weight_paths["step2_octa_plain_fallback"].resolve(),
        },
    }
    return _STEP2_API


def get_step2_runtime(device: str) -> dict[str, object]:
    api = load_step2_api()
    resolved_device = resolve_device(device)
    with _RUNTIME_LOCK:
        cached = _STEP2_RUNTIME_CACHE.get(resolved_device)
        if cached is not None:
            return cached
        torch_device = api["infer_device"](resolved_device)
        weights = api["weights"]
        for weight_path in weights.values():
            if not Path(weight_path).is_file():
                raise FileNotFoundError(f"Step2 weights file does not exist: {weight_path}")
        runtime = {
            "device": str(torch_device),
            "wfcfp_fused_detector": api["build_detector"](weights["wfcfp_fused"], torch_device),
            "wfcfp_rgb_fallback_detector": api["build_detector"](weights["wfcfp_rgb_fallback"], torch_device),
            "octa_fused_detector": api["build_detector"](weights["octa_fused"], torch_device),
            "octa_plain_fallback_detector": api["build_detector"](weights["octa_plain_fallback"], torch_device),
            "weights": weights,
        }
        _STEP2_RUNTIME_CACHE[resolved_device] = runtime
        return runtime


def run_step2_inprocess(
    *,
    wfcfp_image: Path,
    wfcfp_vessel: Path,
    octa_image: Path,
    octa_vessel: Path,
    output_dir: Path,
    device: str,
    log_path: Path | None,
    save_debug_artifacts: bool = True,
) -> dict[str, object]:
    api = load_step2_api()
    runtime = get_step2_runtime(device)
    defaults = api["defaults"]
    output_dir = output_dir.resolve()
    ensure_dir(output_dir)
    wfcfp_summary = api["run_wfcfp_detection_and_crop"](
        wfcfp_image_path=wfcfp_image,
        wfcfp_vessel_path=wfcfp_vessel,
        output_dir=output_dir,
        fused_detector=runtime["wfcfp_fused_detector"],
        rgb_fallback_detector=runtime["wfcfp_rgb_fallback_detector"],
        device=api["infer_device"](runtime["device"]),
        crop_output_size=defaults["crop_output_size"],
        nms_threshold=defaults["nms_threshold"],
        max_detection_attempts=defaults["max_detection_attempts"],
        max_rgb_fallback_attempts=defaults["max_rgb_fallback_attempts"],
        weights_path=runtime["weights"]["wfcfp_fused"],
        rgb_fallback_weights_path=runtime["weights"]["wfcfp_rgb_fallback"],
        save_debug_artifacts=save_debug_artifacts,
        summary_filename="wfcfp_summary.json",
    )
    octa_summary = api["run_octa_detection"](
        octa_image_path=octa_image,
        octa_vessel_path=octa_vessel,
        wfcfp_points_norm_yx=np.asarray(wfcfp_summary["points_norm_yx"], dtype=np.float32),
        output_dir=output_dir,
        fused_detector=runtime["octa_fused_detector"],
        plain_fallback_detector=runtime["octa_plain_fallback_detector"],
        device=api["infer_device"](runtime["device"]),
        nms_threshold=defaults["nms_threshold"],
        max_detection_attempts=defaults["max_octa_detection_attempts"],
        max_plain_fallback_attempts=defaults["max_octa_plain_fallback_attempts"],
        octa_fovea_center_max_offset_norm=defaults["octa_fovea_center_max_offset_norm"],
        weights_path=runtime["weights"]["octa_fused"],
        plain_fallback_weights_path=runtime["weights"]["octa_plain_fallback"],
        save_debug_artifacts=save_debug_artifacts,
        summary_filename="octa_summary.json",
    )
    summary = dict(wfcfp_summary)
    summary["has_octa_detection"] = True
    summary["octa_summary_path"] = octa_summary["summary_path"]
    summary["octa"] = octa_summary
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    if log_path is not None:
        ensure_dir(log_path.parent)
        log_payload = {
            "status": "success",
            "mode": "inprocess_cached",
            "device": runtime["device"],
            "summary_path": str(summary_path),
            "weights": {k: str(v) for k, v in runtime["weights"].items()},
            "save_debug_artifacts": bool(save_debug_artifacts),
        }
        log_path.write_text(json.dumps(log_payload, indent=2), encoding="utf-8")
    return summary


def load_step3_api():
    if str(STEP3_DIR) not in sys.path:
        sys.path.insert(0, str(STEP3_DIR))
    from run_register_pair import fit_polynomial_from_pairs, load_pair_points_pixels, run_registration_case, warmup_predictor_cache

    return run_registration_case, fit_polynomial_from_pairs, load_pair_points_pixels, warmup_predictor_cache


def preload_runtime_models(
    device: str = DEFAULT_DEVICE,
    config_path: Path = DEFAULT_STEP3_CONFIG,
    model_path: Path | None = DEFAULT_STEP3_MODEL,
) -> dict[str, object]:
    resolved_device = resolve_device(device)
    weight_paths = get_weight_paths()
    step2_runtime = get_step2_runtime(resolved_device)
    _, _, _, warmup_predictor_cache = load_step3_api()
    selected_model_path: Path | None
    if model_path is None:
        selected_model_path = weight_paths["step3_model"]
    else:
        selected_model_path = Path(model_path)
        if not selected_model_path.is_file():
            try:
                is_default = selected_model_path.resolve() == DEFAULT_STEP3_MODEL.resolve()
            except Exception:
                is_default = False
            if is_default:
                selected_model_path = weight_paths["step3_model"]
    if selected_model_path is None or not selected_model_path.is_file():
        raise FileNotFoundError(f"Step3 model not found: {selected_model_path}")
    step3_warmup = warmup_predictor_cache(
        str(config_path.resolve()),
        device=resolved_device,
        model_path=str(selected_model_path.resolve()),
    )
    return {
        "device": resolved_device,
        "step2": {
            "device": step2_runtime["device"],
            "weights": {k: str(v) for k, v in step2_runtime["weights"].items()},
        },
        "step3": {
            **step3_warmup,
            "model_path": str(selected_model_path.resolve()),
        },
    }
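

# Note: preload_runtime_models() can be called once up front (e.g. at service
# startup) to build the Step2 detectors and warm the Step3 predictor cache;
# main() below does exactly this before running the pipeline.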


def route_template(route_name: str, output_dir: Path) -> dict[str, object]:
    return {
        "route_name": route_name,
        "output_dir": str(output_dir.resolve()),
        "status": "failed",
        "error_message": "",
        "pair_count": 0,
        "inlier_count": 0,
        "ransac_inlier_count": 0,
        "register_summary": None,
        "fit_eval": None,
        "polarity_inverted": route_name in INVERTED_POLARITY_ROUTES,
    }


def run_single_route(
    *,
    route_name: str,
    output_dir: Path,
    run_registration_case,
    fit_polynomial_from_pairs,
    query_image: Path,
    query_vessel: Path,
    refer_image: Path,
    refer_vessel: Path,
    crop_box_1000: Path | None,
    config_path: Path,
    model_path: Path | None,
    device: str,
    inlier_method: str,
    legacy_script_crop: bool,
    save_debug_artifacts: bool,
    force_rerun: bool,
) -> dict[str, object]:
    result = route_template(route_name, output_dir)
    try:
        if force_rerun and output_dir.exists():
            shutil.rmtree(output_dir)
        ensure_dir(output_dir.parent)
        summary = run_registration_case(
            query_image=str(query_image),
            query_vessel=str(query_vessel),
            refer_image=str(refer_image),
            refer_vessel=str(refer_vessel),
            crop_box_1000=None if crop_box_1000 is None else str(crop_box_1000),
            query_crop_box_1000=None,
            output_dir=str(output_dir),
            device=device,
            config_path=str(config_path),
            model_path=None if model_path is None else str(model_path),
            crop_mode="refer_box_only",
            inlier_method=inlier_method,
            legacy_script_crop=legacy_script_crop,
            save_debug_artifacts=save_debug_artifacts,
        )
        result["register_summary"] = summary
        if summary.get("status") != "success":
            raise RuntimeError(summary.get("error_message") or "Step3 failed")
        pair_txt_path = summary["pair_txt_path"]
        fit_eval = fit_polynomial_from_pairs(pair_txt_path, inlier_method=inlier_method)
        fit_ransac = fit_polynomial_from_pairs(pair_txt_path, inlier_method="RANSAC")
        result["status"] = "success"
        result["pair_count"] = int(fit_eval["pair_count"])
        result["inlier_count"] = int(fit_eval["inlier_count"])
        result["ransac_inlier_count"] = int(fit_ransac["inlier_count"])
        result["fit_eval"] = fit_eval
    except Exception as exc:  # pragma: no cover
        result["error_message"] = str(exc)
    return result


def select_best_route(route_results: list[dict[str, object]]) -> dict[str, object] | None:
    success = [item for item in route_results if item["status"] == "success"]
    if not success:
        return None
    return max(success, key=lambda item: int(item["inlier_count"]))


def copy_file(src: Path, dst: Path) -> Path:
    ensure_dir(dst.parent)
    shutil.copyfile(src, dst)
    return dst


def as_path(value: str | None) -> Path | None:
    if value is None:
        return None
    return Path(value)


def dump_polynomial_params(fit_eval: dict[str, object], save_path: Path) -> dict[str, object]:
    regr_x = fit_eval["regr_x"]
    regr_y = fit_eval["regr_y"]
    coef_x = [float(v) for v in np.asarray(regr_x.coef_).ravel()]
    coef_y = [float(v) for v in np.asarray(regr_y.coef_).ravel()]
    intercept_x = float(regr_x.intercept_)
    intercept_y = float(regr_y.intercept_)
    effective_x = [coef_x[0] + intercept_x, coef_x[1], coef_x[2], coef_x[3], coef_x[4], coef_x[5]]
    effective_y = [coef_y[0] + intercept_y, coef_y[1], coef_y[2], coef_y[3], coef_y[4], coef_y[5]]
    payload = {
        "feature_order": FEATURE_NAMES,
        "raw_linear_regression": {
            "x_model": {"intercept": intercept_x, "coef": coef_x},
            "y_model": {"intercept": intercept_y, "coef": coef_y},
            "formula": "pred = intercept + sum_i coef[i] * phi_i, phi=[1,x,y,x^2,xy,y^2]",
        },
        "effective_quadratic_form": {
            "x_pred": {
                "a0": effective_x[0],
                "a1_x": effective_x[1],
                "a2_y": effective_x[2],
                "a3_x2": effective_x[3],
                "a4_xy": effective_x[4],
                "a5_y2": effective_x[5],
            },
            "y_pred": {
                "b0": effective_y[0],
                "b1_x": effective_y[1],
                "b2_y": effective_y[2],
                "b3_x2": effective_y[3],
                "b4_xy": effective_y[4],
                "b5_y2": effective_y[5],
            },
            "formula": "x' = a0+a1*x+a2*y+a3*x^2+a4*x*y+a5*y^2; y' = b0+b1*x+b2*y+b3*x^2+b4*x*y+b5*y^2",
        },
    }
    ensure_dir(save_path.parent)
    save_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    return payload
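

# Illustrative sketch (not part of the pipeline): how a point (x, y) could be
# mapped with the "effective_quadratic_form" payload written above. The helper
# name below is hypothetical and only demonstrates the documented formula.
def _apply_effective_quadratic_example(payload: dict[str, object], x: float, y: float) -> tuple[float, float]:
    xp = payload["effective_quadratic_form"]["x_pred"]
    yp = payload["effective_quadratic_form"]["y_pred"]
    # x' = a0 + a1*x + a2*y + a3*x^2 + a4*x*y + a5*y^2 (same for y' with b-coefficients).
    x_new = xp["a0"] + xp["a1_x"] * x + xp["a2_y"] * y + xp["a3_x2"] * x * x + xp["a4_xy"] * x * y + xp["a5_y2"] * y * y
    y_new = yp["b0"] + yp["b1_x"] * x + yp["b2_y"] * y + yp["b3_x2"] * x * x + yp["b4_xy"] * x * y + yp["b5_y2"] * y * y
    return float(x_new), float(y_new)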


def relpath(path: Path, root: Path) -> str:
    try:
        return str(path.resolve().relative_to(root.resolve()))
    except Exception:
        return str(path.resolve())


def add_label(image: np.ndarray, label: str) -> np.ndarray:
    canvas = image.copy()
    cv2.rectangle(canvas, (0, 0), (canvas.shape[1], 40), (20, 20, 20), -1)
    cv2.putText(
        canvas,
        label,
        (12, 28),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    return canvas


def resize_keep(image: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)


def create_visual_summary(
    *,
    save_path: Path,
    octa_raw: Path,
    wfcfp_raw: Path,
    octa_vessel: Path,
    wfcfp_vessel: Path,
    keypoint_vis: Path,
    all_keypoint_vis: Path,
    vessel_overlay: Path,
    raw_overlay: Path,
) -> Path:
    tile_w, tile_h = 700, 700
    panels = [
        (read_as_bgr(octa_raw), "OCTA Raw"),
        (read_as_bgr(wfcfp_raw), "wfCFP Raw"),
        (read_as_bgr(octa_vessel), "OCTA Vessel"),
        (read_as_bgr(wfcfp_vessel), "wfCFP Vessel"),
        (read_as_bgr(keypoint_vis), "Filtered Keypoints"),
        (read_as_bgr(vessel_overlay), "Vessel Overlay"),
        (read_as_bgr(raw_overlay), "Raw Overlay"),
        (read_as_bgr(all_keypoint_vis), "All Keypoints"),
    ]
    prepared = []
    for image, label in panels[:8]:
        tile = resize_keep(image, (tile_w, tile_h))
        tile = add_label(tile, label) if label else tile
        prepared.append(tile)
    row1 = np.hstack(prepared[0:4])
    row2 = np.hstack(prepared[4:8])
    board = np.vstack([row1, row2])
    ensure_dir(save_path.parent)
    cv2.imwrite(str(save_path), board)
    return save_path


def to_black_bg_white_vessel(gray_or_bgr: np.ndarray) -> np.ndarray:
    if gray_or_bgr.ndim == 3 and gray_or_bgr.shape[2] == 4:
        gray = cv2.cvtColor(gray_or_bgr, cv2.COLOR_BGRA2GRAY)
    elif gray_or_bgr.ndim == 3:
        gray = cv2.cvtColor(gray_or_bgr, cv2.COLOR_BGR2GRAY)
    else:
        gray = gray_or_bgr
    # If background is bright, invert once to keep vessel as white on black.
    if float(np.mean(gray)) > 127.0:
        gray = 255 - gray
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)


def draw_keypoint_pairs_on_black_vessel(
    *,
    query_img_path: Path,
    refer_img_path: Path,
    pair_points: np.ndarray,
    save_path: Path,
    title: str,
    restore_from_inverted: bool = False,
    coord_scale: float = COORD_SCALE,
) -> Path:
    query_src = cv2.imread(str(query_img_path), cv2.IMREAD_UNCHANGED)
    if query_src is None:
        raise FileNotFoundError(f"Unable to read keypoint-visualization background image: {query_img_path}")
    refer_src = cv2.imread(str(refer_img_path), cv2.IMREAD_UNCHANGED)
    if refer_src is None:
        raise FileNotFoundError(f"Unable to read keypoint-visualization background image: {refer_img_path}")
    if restore_from_inverted:
        query_src = cv2.bitwise_not(query_src)
        refer_src = cv2.bitwise_not(refer_src)
    query_bgr = to_black_bg_white_vessel(query_src)
    refer_bgr = to_black_bg_white_vessel(refer_src)
    # Normalize both panels to an identical shape, so left/right visualization
    # is consistent before drawing pair connections.
    h_a0, w_a0 = query_bgr.shape[:2]
    h_b0, w_b0 = refer_bgr.shape[:2]
    target_h = max(h_a0, h_b0)
    target_w = max(w_a0, w_b0)
    if h_a0 != target_h or w_a0 != target_w:
        query_bgr = cv2.resize(query_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    if h_b0 != target_h or w_b0 != target_w:
        refer_bgr = cv2.resize(refer_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    h_a, w_a = query_bgr.shape[:2]
    h_b, w_b = refer_bgr.shape[:2]
    canvas = np.zeros((max(h_a, h_b), w_a + w_b, 3), dtype=np.uint8)
    canvas[0:h_a, 0:w_a] = query_bgr
    canvas[0:h_b, w_a : w_a + w_b] = refer_bgr
    for query_pt, refer_pt in zip(pair_points[:, :2], pair_points[:, 2:]):
        query_x = float(query_pt[0]) * float(w_a) / float(coord_scale)
        query_y = float(query_pt[1]) * float(h_a) / float(coord_scale)
        refer_x = float(refer_pt[0]) * float(w_b) / float(coord_scale)
        refer_y = float(refer_pt[1]) * float(h_b) / float(coord_scale)
        pt_a = (int(round(query_x)), int(round(query_y)))
        pt_b = (int(round(refer_x + w_a)), int(round(refer_y)))
        cv2.line(canvas, pt_a, pt_b, (0, 255, 0), 1, lineType=cv2.LINE_AA)
        cv2.circle(canvas, pt_a, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        cv2.circle(canvas, pt_b, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
    title_height = 70
    vis = np.zeros((canvas.shape[0] + title_height, canvas.shape[1], 3), dtype=np.uint8)
    vis[title_height:, :] = canvas
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    text_size, _ = cv2.getTextSize(title, font, font_scale, thickness)
    text_x = max(0, (vis.shape[1] - text_size[0]) // 2)
    cv2.putText(
        vis,
        title,
        (text_x, 45),
        font,
        font_scale,
        (255, 255, 255),
        thickness,
        lineType=cv2.LINE_AA,
    )
    ensure_dir(save_path.parent)
    cv2.imwrite(str(save_path), vis)
    return save_path


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Single-pair full inference pipeline (Step1+Step2+Step3)")
    parser.add_argument("--octa_image", required=True, type=Path, help="Path to OCTA image")
    parser.add_argument("--wfcfp_image", required=True, type=Path, help="Path to wfCFP image")
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=THIS_DIR / "Output" / f"single_pair_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}",
        help="Output directory",
    )
    parser.add_argument(
        "--seg_model",
        choices=[SEG_MODEL_CARE, SEG_MODEL_UNET_DCP],
        default=SEG_MODEL_CARE,
        help=f"Step1 vessel segmentation model: {SEG_MODEL_CARE} or {SEG_MODEL_UNET_DCP}",
    )
    parser.add_argument(
        "--inlier_method",
        choices=["RANSAC", "LMEDS"],
        default="LMEDS",
        help="Step3 inlier filtering/fitting method",
    )
    parser.add_argument(
        "--retry_filtered_inlier_threshold",
        type=int,
        default=DEFAULT_RETRY_FILTERED_INLIER_THRESHOLD,
        help="Trigger retry when primary filtered inlier count is below this threshold",
    )
    parser.add_argument(
        "--registration_mode",
        choices=["interface_b_retry", "direct_no_crop"],
        default="interface_b_retry",
        help=(
            "interface_b_retry: Interface-B logic (includes Step2 and retry routes); "
            "direct_no_crop: Skip Step2 and cropping; run Step3 directly after Step1 segmentation."
        ),
    )
    parser.add_argument(
        "--query_modality",
        choices=["AUTO", "CFP", "OCTA"],
        default="AUTO",
        help=f"Query modality used by {SEG_MODEL_UNET_DCP}. AUTO infers by filename (contains OCTA/_001 -> OCTA, otherwise CFP).",
    )
    parser.add_argument(
        "--refer_modality",
        choices=["AUTO", "CFP", "OCTA"],
        default="AUTO",
        help=f"Refer modality used by {SEG_MODEL_UNET_DCP}. AUTO infers by filename (default CFP).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=DEFAULT_DEVICE,
        help=f"Device: auto/cpu/cuda:0. Default auto (cuda:0 if available, else cpu). Used by Step1({SEG_MODEL_CARE}), Step2, and Step3.",
    )
    parser.add_argument("--config_path", type=Path, default=DEFAULT_STEP3_CONFIG, help="Path to Step3 config file")
    parser.add_argument("--model_path", type=Path, default=DEFAULT_STEP3_MODEL, help="Path to Step3 model weights")
    parser.add_argument("--force_rerun", action="store_true", help="Delete output directory first if it already exists")
    parser.add_argument("--continue_from_step1", action="store_true", help="Continue from current-sample Step1 outputs and run only Step2+Step3")
    parser.add_argument("--lite_outputs", action="store_true", help="Generate only core outputs and skip non-essential heavy visualizations")
    parser.add_argument("--keep_attempt_dirs", action="store_true", help="Keep intermediate Step3 route attempt directories")
    parser.add_argument(
        "--disable_nonessential_io",
        action="store_true",
        help="Disable non-essential debug artifact generation/writing in Step2/Step3 and pipeline copies.",
    )
    return parser


def main() -> None:
    args = build_parser().parse_args()
    resolved_device = resolve_device(args.device)
    weight_paths = get_weight_paths()
    octa_image = args.octa_image.resolve()
    wfcfp_image = args.wfcfp_image.resolve()
    output_dir = args.output_dir.resolve()
    config_path = args.config_path.resolve()
    model_path = args.model_path.resolve()
    if not octa_image.is_file():
        raise FileNotFoundError(f"OCTA image not found: {octa_image}")
    if not wfcfp_image.is_file():
        raise FileNotFoundError(f"wfCFP image not found: {wfcfp_image}")
    if not config_path.is_file():
        raise FileNotFoundError(f"Step3 config not found: {config_path}")
    if not model_path.is_file():
        try:
            uses_default_model = model_path == DEFAULT_STEP3_MODEL.resolve()
        except Exception:
            uses_default_model = False
        if uses_default_model:
            model_path = weight_paths["step3_model"].resolve()
        if not model_path.is_file():
            raise FileNotFoundError(
                f"Step3 model not found: {model_path}. "
                "Set CARE_WEIGHTS_REPO_ID for snapshot-based weight loading if local files are absent."
            )
    if output_dir.exists() and args.force_rerun:
        shutil.rmtree(output_dir)
    ensure_dir(output_dir)
    logs_dir = ensure_dir(output_dir / "logs")
    work_dir = ensure_dir(output_dir / "work")
    input_dir = ensure_dir(work_dir / "input_images")
    step1_out = ensure_dir(work_dir / "step1")
    step2_out = ensure_dir(work_dir / "step2")
    step3_attempts = ensure_dir(work_dir / "step3" / "attempts")
    step3_assets = ensure_dir(work_dir / "step3" / "assets")
    step3_selected_dir = work_dir / "step3" / "selected"
    results_dir = ensure_dir(output_dir / "results")
    results_vessels = ensure_dir(results_dir / "vessels")
    results_keypoints = ensure_dir(results_dir / "keypoints")
    results_overlays = ensure_dir(results_dir / "overlays")
    results_poly = ensure_dir(results_dir / "polynomial")
    results_vis = ensure_dir(results_dir / "visualizations")
    preload_runtime_models(device=resolved_device, config_path=config_path, model_path=model_path)
    run_registration_case, fit_polynomial_from_pairs, load_pair_points_pixels, _ = load_step3_api()
    query_modality, refer_modality = resolve_modalities(
        query_modality_arg=args.query_modality,
        refer_modality_arg=args.refer_modality,
        query_image_path=octa_image,
        refer_image_path=wfcfp_image,
    )
    requested_registration_mode = args.registration_mode
    effective_registration_mode = requested_registration_mode
    registration_mode_auto_forced = False
    registration_mode_auto_reason = ""
    if requested_registration_mode == "interface_b_retry" and query_modality == "CFP" and refer_modality == "CFP":
        effective_registration_mode = "direct_no_crop"
        registration_mode_auto_forced = True
        registration_mode_auto_reason = "both_images_are_cfp"
    octa_png = normalize_to_png(octa_image, input_dir / "octa.png")
    wfcfp_png = normalize_to_png(wfcfp_image, input_dir / "wfcfp.png")
    step1_reused_in_current_run = False
    if args.continue_from_step1:
        vessel_dir = step1_out / "vessels"
        octa_vessel = resolve_vessel_path(vessel_dir, "octa")
        wfcfp_vessel = resolve_vessel_path(vessel_dir, "wfcfp")
        step1_reused_in_current_run = True
        (logs_dir / "step1.log").write_text("Continue from current-run Step1 outputs\n", encoding="utf-8")
    elif args.seg_model == SEG_MODEL_CARE:
        step1_script = STEP1_DIR / "run_vessel_seg.py"
        step1_cmd = [
            sys.executable,
            str(step1_script),
            "--input_dir",
            str(input_dir),
            "--output_dir",
            str(step1_out),
            "--weights",
            str(weight_paths["step1_care"].resolve()),
            "--device",
            resolved_device,
        ]
        run_cmd(step1_cmd, logs_dir / "step1.log")
        vessel_dir = step1_out / "vessels"
        octa_vessel = resolve_vessel_path(vessel_dir, "octa")
        wfcfp_vessel = resolve_vessel_path(vessel_dir, "wfcfp")
    else:
        octa_vessel, wfcfp_vessel = run_step1_broad_with_explicit_modalities(
            query_image=octa_png,
            refer_image=wfcfp_png,
            query_stem="octa",
            refer_stem="wfcfp",
            query_modality=query_modality,
            refer_modality=refer_modality,
            device=resolved_device,
            model_dir=weight_paths["step1_broad_dir"].resolve(),
            output_dir=step1_out,
            log_path=logs_dir / "step1.log",
        )
    step2_log_path = logs_dir / "step2.log"
    wfcfp_txt: Path | None = None
    if effective_registration_mode == "interface_b_retry":
        run_step2_inprocess(
            wfcfp_image=wfcfp_png,
            wfcfp_vessel=wfcfp_vessel,
            octa_image=octa_png,
            octa_vessel=octa_vessel,
            output_dir=step2_out,
            device=resolved_device,
            log_path=None if args.disable_nonessential_io else step2_log_path,
            save_debug_artifacts=not args.disable_nonessential_io,
        )
        wfcfp_txt = step2_out / "wfcfp_od_fovea.txt"
        octa_txt = step2_out / "octa_od_fovea.txt"
        if not wfcfp_txt.is_file():
            raise FileNotFoundError(f"Missing Step2 output: {wfcfp_txt}")
        if not octa_txt.is_file():
            raise FileNotFoundError(f"Missing Step2 output: {octa_txt}")
        legacy_octa_txt = octa_vessel.with_name("octa.txt")
        legacy_wfcfp_txt = wfcfp_vessel.with_name("wfcfp.txt")
        copy_file(octa_txt, legacy_octa_txt)
        copy_file(wfcfp_txt, legacy_wfcfp_txt)
    common_kwargs = {
        "run_registration_case": run_registration_case,
        "fit_polynomial_from_pairs": fit_polynomial_from_pairs,
        "query_image": octa_png,
        "query_vessel": octa_vessel,
        "refer_image": wfcfp_png,
        "refer_vessel": wfcfp_vessel,
        "config_path": config_path,
        "model_path": model_path,
        "device": resolved_device,
        "inlier_method": args.inlier_method,
        "save_debug_artifacts": not args.disable_nonessential_io,
        "force_rerun": True,
    }
    route_results: list[dict[str, object]] = []
    retry_reasons: list[str] = []
    retry_triggered = False
    if effective_registration_mode == "direct_no_crop":
        full_box_path = step3_assets / "full_1000_box_1000.txt"
        np.savetxt(full_box_path, np.array([[0, 0], [1000, 1000]], dtype=np.int32), fmt="%d")
        direct_result = run_single_route(
            route_name="direct_no_crop",
            output_dir=step3_attempts / "direct_no_crop",
            crop_box_1000=full_box_path,
            legacy_script_crop=False,
            **common_kwargs,
        )
        route_results.append(direct_result)
    else:
        primary = run_single_route(
            route_name="primary",
            output_dir=step3_attempts / "primary",
            crop_box_1000=None,
            legacy_script_crop=True,
            **common_kwargs,
        )
        route_results.append(primary)
        if primary["status"] != "success":
            retry_reasons.append("primary_failed")
        else:
            if int(primary["inlier_count"]) < int(args.retry_filtered_inlier_threshold):
                retry_reasons.append(f"filtered_inlier_below_{int(args.retry_filtered_inlier_threshold)}")
        retry_triggered = len(retry_reasons) > 0
        if retry_triggered:
            if wfcfp_txt is None:
                raise RuntimeError("wfcfp_txt must not be None in interface_b_retry mode")
            wfcfp_only_box = compute_wfcfp_only_crop_box_1000(wfcfp_txt)
            wfcfp_only_box_path = step3_assets / "wfcfp_only_crop_box_1000.txt"
            np.savetxt(wfcfp_only_box_path, wfcfp_only_box, fmt="%d")
            full_box_path = step3_assets / "full_1000_box_1000.txt"
            np.savetxt(full_box_path, np.array([[0, 0], [1000, 1000]], dtype=np.int32), fmt="%d")
            route_results.append(
                run_single_route(
                    route_name="fallback_wfcfp_only_crop",
                    output_dir=step3_attempts / "fallback_wfcfp_only_crop",
                    crop_box_1000=wfcfp_only_box_path,
                    legacy_script_crop=False,
                    **common_kwargs,
                )
            )
            route_results.append(
                run_single_route(
                    route_name="fallback_invert_no_crop",
                    output_dir=step3_attempts / "fallback_invert_no_crop",
                    crop_box_1000=full_box_path,
                    legacy_script_crop=False,
                    **common_kwargs,
                )
            )
    selected = select_best_route(route_results)
    if selected is None:
        messages = [f"{item['route_name']}: {item.get('error_message')}" for item in route_results]
        raise RuntimeError("All routes failed: " + " | ".join(messages))
    selected_route_dir = Path(str(selected["output_dir"]))
    if not args.disable_nonessential_io:
        if step3_selected_dir.exists():
            shutil.rmtree(step3_selected_dir)
        shutil.copytree(selected_route_dir, step3_selected_dir)
    selected_register_summary = selected["register_summary"]
    selected_fit_eval = selected["fit_eval"]
    selected_route_polarity_inverted = bool(selected.get("polarity_inverted", False))
    pair_input_path = as_path(selected_register_summary.get("pair_input_space_pixels_path"))
    pair_restored_path = as_path(selected_register_summary.get("pair_restored_space_pixels_path"))
    if pair_input_path is None or pair_restored_path is None:
        raise RuntimeError("Selected route summary is missing keypoint pixel paths")
    pair_input = load_pair_points_pixels(str(pair_input_path))
    pair_restored = load_pair_points_pixels(str(pair_restored_path))
    inlier_mask = np.asarray(selected_fit_eval["inlier_mask"], dtype=bool)
    if len(pair_input) != len(inlier_mask) or len(pair_restored) != len(inlier_mask):
        raise RuntimeError("Keypoint pair count does not match inlier-mask length")
    filtered_input = pair_input[inlier_mask]
    filtered_restored = pair_restored[inlier_mask]
    filtered_norm = np.asarray(selected_fit_eval["inlier_pairs_norm"], dtype=np.float64)
    filtered_norm_path = results_keypoints / "filtered_pairs_norm.txt"
    filtered_input_path = results_keypoints / "filtered_pairs_input_space_pixels.txt"
    filtered_restored_path = results_keypoints / "filtered_pairs_restored_space_pixels.txt"
    np.savetxt(filtered_norm_path, filtered_norm, fmt="%.12f")
    np.savetxt(filtered_input_path, filtered_input, fmt="%.6f")
    np.savetxt(filtered_restored_path, filtered_restored, fmt="%.6f")
    octa_vessel_out = copy_file(octa_vessel, results_vessels / "octa_vessel.png")
    wfcfp_vessel_out = copy_file(wfcfp_vessel, results_vessels / "wfcfp_vessel.png")
    vessel_overlay_src = Path(selected_register_summary["vessel_merged_display_path"])
    raw_overlay_src = Path(selected_register_summary["actual_overlay_path"])
    vessel_overlay_out = copy_file(vessel_overlay_src, results_overlays / "vessel_overlay.png")
    raw_overlay_out = copy_file(raw_overlay_src, results_overlays / "raw_overlay.png")
    warped_octa_raw_out = None
    warped_octa_vessel_out = None
    if not args.lite_outputs:
        warped_octa_raw_out = copy_file(Path(selected_register_summary["warped_actual_path"]), results_overlays / "warped_octa_raw.png")
        warped_octa_vessel_out = copy_file(Path(selected_register_summary["warped_vessel_path"]), results_overlays / "warped_octa_vessel.png")
    debug_dir = Path(selected_register_summary["registration_debug_dir"])
    query_restored = debug_dir / "match_query_restored.png"
    refer_restored = debug_dir / "match_refer_restored.png"
    query_input = debug_dir / "match_query_input.png"
    refer_input = debug_dir / "match_refer_input.png"
    if not query_restored.is_file():
        query_restored = octa_vessel
    if not refer_restored.is_file():
        refer_restored = wfcfp_vessel
    if not query_input.is_file():
        query_input = octa_vessel
    if not refer_input.is_file():
        refer_input = wfcfp_vessel
    kp_vis_input_out = None
    if not args.lite_outputs:
        kp_vis_input_out = draw_keypoint_pairs_on_black_vessel(
            query_img_path=query_input,
            refer_img_path=refer_input,
            pair_points=filtered_input,
            save_path=results_keypoints / "filtered_pairs_input_space_vis.png",
            title=f"{args.inlier_method} Inlier Match (Input Space), #inlier: {len(filtered_input)}",
            restore_from_inverted=selected_route_polarity_inverted,
        )
    kp_vis_restored_out = draw_keypoint_pairs_on_black_vessel(
        query_img_path=query_restored,
        refer_img_path=refer_restored,
        pair_points=filtered_restored,
        save_path=results_keypoints / "filtered_pairs_restored_space_vis.png",
        title=f"{args.inlier_method} Inlier Match (Restored Space), #inlier: {len(filtered_restored)}",
        restore_from_inverted=selected_route_polarity_inverted,
    )
    kp_vis_all_out = None
    if not args.lite_outputs:
        kp_vis_all_out = draw_keypoint_pairs_on_black_vessel(
            query_img_path=query_restored,
            refer_img_path=refer_restored,
            pair_points=pair_restored,
            save_path=results_keypoints / "all_pairs_vis.png",
            title=f"All Match (Restored Space), #pair: {len(pair_restored)}",
            restore_from_inverted=selected_route_polarity_inverted,
        )
    poly_payload = dump_polynomial_params(selected_fit_eval, results_poly / "quadratic_polynomial_params.json")
    octa_raw_out = copy_file(octa_png, results_dir / "octa_raw.png")
    wfcfp_raw_out = copy_file(wfcfp_png, results_dir / "wfcfp_raw.png")
    visual_summary_path = None
    if not args.lite_outputs:
        visual_summary_path = create_visual_summary(
            save_path=results_vis / "summary_board.png",
            octa_raw=octa_raw_out,
            wfcfp_raw=wfcfp_raw_out,
            octa_vessel=octa_vessel_out,
            wfcfp_vessel=wfcfp_vessel_out,
            keypoint_vis=kp_vis_restored_out,
            all_keypoint_vis=kp_vis_all_out,
            vessel_overlay=vessel_overlay_out,
            raw_overlay=raw_overlay_out,
        )
    route_records = []
    for item in route_results:
        route_records.append(
            {
                "route_name": item["route_name"],
                "status": item["status"],
                "error_message": item.get("error_message", ""),
                "pair_count": int(item.get("pair_count", 0)),
                "inlier_count": int(item.get("inlier_count", 0)),
                "ransac_inlier_count": int(item.get("ransac_inlier_count", 0)),
                "polarity_inverted": bool(item.get("polarity_inverted", False)),
                "output_dir": relpath(Path(str(item["output_dir"])), output_dir),
            }
        )
    summary = {
        "status": "success",
        "run_time": dt.datetime.now().isoformat(),
        "input": {
            "octa_image": str(octa_image),
            "wfcfp_image": str(wfcfp_image),
            "seg_model": args.seg_model,
            "continue_from_step1": bool(args.continue_from_step1),
            "step1_reused_in_current_run": bool(step1_reused_in_current_run),
            "lite_outputs": bool(args.lite_outputs),
            "disable_nonessential_io": bool(args.disable_nonessential_io),
            "registration_mode_requested": requested_registration_mode,
            "registration_mode_effective": effective_registration_mode,
            "registration_mode_auto_forced": registration_mode_auto_forced,
            "registration_mode_auto_reason": registration_mode_auto_reason,
            "query_modality": query_modality,
            "refer_modality": refer_modality,
            "inlier_method": args.inlier_method,
            "retry_filtered_inlier_threshold": int(args.retry_filtered_inlier_threshold),
            "device": resolved_device,
            "step3_config_path": str(config_path),
            "step3_model_path": str(model_path),
        },
        "retry": {
            "retry_triggered": retry_triggered,
            "retry_reasons": retry_reasons,
            "selected_route": selected["route_name"],
            "selected_route_polarity_inverted": selected_route_polarity_inverted,
            "routes": route_records,
        },
        "metrics": {
            "pair_count": int(selected["pair_count"]),
            "filtered_inlier_count": int(selected["inlier_count"]),
            "ransac_inlier_count": int(selected["ransac_inlier_count"]),
        },
        "outputs": {
            "octa_raw": relpath(octa_raw_out, output_dir),
            "wfcfp_raw": relpath(wfcfp_raw_out, output_dir),
            "octa_vessel": relpath(octa_vessel_out, output_dir),
            "wfcfp_vessel": relpath(wfcfp_vessel_out, output_dir),
            "filtered_pairs_norm": relpath(filtered_norm_path, output_dir),
            "filtered_pairs_input_space_pixels": relpath(filtered_input_path, output_dir),
            "filtered_pairs_restored_space_pixels": relpath(filtered_restored_path, output_dir),
            "filtered_pairs_input_vis": relpath(kp_vis_input_out, output_dir) if kp_vis_input_out else None,
            "filtered_pairs_restored_vis": relpath(kp_vis_restored_out, output_dir),
            "all_pairs_vis": relpath(kp_vis_all_out, output_dir) if kp_vis_all_out else None,
            "vessel_overlay": relpath(vessel_overlay_out, output_dir),
            "raw_overlay": relpath(raw_overlay_out, output_dir),
            "warped_octa_raw": relpath(warped_octa_raw_out, output_dir) if warped_octa_raw_out else None,
            "warped_octa_vessel": relpath(warped_octa_vessel_out, output_dir) if warped_octa_vessel_out else None,
            "quadratic_polynomial_params": relpath(results_poly / "quadratic_polynomial_params.json", output_dir),
            "visual_summary_board": relpath(visual_summary_path, output_dir) if visual_summary_path else None,
        },
        "step_logs": {
            "step1": relpath(logs_dir / "step1.log", output_dir),
            "step2": relpath(step2_log_path, output_dir) if step2_log_path.exists() else None,
        },
        "poly_feature_order": poly_payload["feature_order"],
    }
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps({"status": "success", "summary": str(summary_path)}, ensure_ascii=False))
    if not args.keep_attempt_dirs:
        shutil.rmtree(step3_attempts, ignore_errors=True)


if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        message = {
            "status": "failed",
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }
        print(json.dumps(message, ensure_ascii=False, indent=2))
        raise
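

# Example invocation (illustrative; the script name and paths are placeholders):
#   python run_single_pair_pipeline.py \
#       --octa_image /path/to/sample_octa.png \
#       --wfcfp_image /path/to/sample_wfcfp.png \
#       --device auto \
#       --inlier_method LMEDS \
#       --registration_mode interface_b_retry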