# CARe / InferCode / run_interface_b_single_pair.py
# Provenance: Hugging Face upload by Hongyang-Li ("Upload 78 files", revision ffba4ae, verified).
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import datetime as dt
import json
import shutil
import subprocess
import sys
import threading
import traceback
from pathlib import Path
import cv2
import numpy as np
import torch
# Repository layout: this script lives one level below the release root.
THIS_DIR = Path(__file__).resolve().parent
RELEASE_ROOT = THIS_DIR.parent
# Make the release root importable before pulling in sibling packages.
if str(RELEASE_ROOT) not in sys.path:
    sys.path.insert(0, str(RELEASE_ROOT))
from hf_weight_manager import resolve_weight_paths

# Per-step sub-project directories.
STEP1_DIR = RELEASE_ROOT / "Step1_VesselSeg"
STEP2_DIR = RELEASE_ROOT / "Step2_DetectCrop"
STEP3_DIR = RELEASE_ROOT / "Step3_Reg"
# Step3 registration config / checkpoint shipped with the release.
DEFAULT_STEP3_CONFIG = STEP3_DIR / "Src" / "config" / "test.yaml"
DEFAULT_STEP3_MODEL = STEP3_DIR / "Src" / "save" / "crop_vseg_vessel_111.pth"
# Fallback routes are retried when the primary route has fewer filtered inliers than this.
DEFAULT_RETRY_FILTERED_INLIER_THRESHOLD = 20
# Feature order of the quadratic polynomial fit (see dump_polynomial_params).
FEATURE_NAMES = ["1", "x", "y", "x^2", "x*y", "y^2"]
# Step2 detectors work in an 800-pixel space; crop boxes are expressed in a 1000x1000 space.
DETECTION_SIZE = 800.0
COORD_SCALE = 1000.0
# Route names whose vessel maps were intensity-inverted before registration.
INVERTED_POLARITY_ROUTES = {"fallback_invert_no_crop"}
# User-facing names of the two supported Step1 segmentation backends.
SEG_MODEL_CARE = "CARe-VesselSeg"
SEG_MODEL_UNET_DCP = "Broad domain retinal vessel segmentation [ICASSP 2025]"
DEFAULT_DEVICE = "auto"
# Lazily-initialized Step2 module handle plus a per-device runtime cache,
# both guarded by _RUNTIME_LOCK (see get_step2_runtime).
_STEP2_API = None
_STEP2_RUNTIME_CACHE: dict[str, dict[str, object]] = {}
_RUNTIME_LOCK = threading.Lock()
def get_weight_paths() -> dict[str, Path]:
    """Resolve all model-weight paths for this release (delegates to hf_weight_manager)."""
    return resolve_weight_paths(RELEASE_ROOT)
def resolve_device(device: str) -> str:
    """Map the pseudo-device "auto" to a concrete device string.

    "auto" (case-insensitive, surrounding whitespace ignored) becomes
    "cuda:0" when CUDA is available and "cpu" otherwise; any other value is
    returned unchanged (stringified).
    """
    if str(device).strip().lower() == "auto":
        if torch.cuda.is_available():
            return "cuda:0"
        return "cpu"
    return str(device)
def ensure_dir(path: Path) -> Path:
    """Create directory *path* (with any missing parents) and return it for chaining."""
    path.mkdir(exist_ok=True, parents=True)
    return path
def run_cmd(cmd: list[str], log_path: Path) -> None:
    """Execute *cmd*, capturing stdout and stderr into *log_path*.

    Raises RuntimeError (pointing at the log file) when the command exits
    with a nonzero status.
    """
    ensure_dir(log_path.parent)
    with log_path.open("w", encoding="utf-8") as sink:
        completed = subprocess.run(
            cmd,
            stdout=sink,
            stderr=subprocess.STDOUT,
            text=True,
            check=False,
        )
    if completed.returncode != 0:
        raise RuntimeError(f"Command failed ({completed.returncode}), see log: {log_path}")
def infer_modality_from_path(image_path: Path) -> str:
    """Guess the imaging modality ("OCTA" or "CFP") from an image filename.

    Filenames whose stem (uppercased) contains "OCTA" or ends with "_001"
    are classified as OCTA; everything else defaults to CFP.
    """
    stem = image_path.stem.upper()
    if "OCTA" in stem or stem.endswith("_001"):
        return "OCTA"
    # The original also tested for CFP/FP/FUNDUS markers, but both that branch
    # and the final fallback returned "CFP", so the extra check was dead code.
    return "CFP"
def resolve_modalities(
    query_modality_arg: str,
    refer_modality_arg: str,
    query_image_path: Path,
    refer_image_path: Path,
) -> tuple[str, str]:
    """Uppercase the CLI modality arguments, expanding "AUTO" via filename heuristics.

    Returns the (query, refer) modality strings, each "CFP" or "OCTA".
    """
    resolved: list[str] = []
    for arg, image_path in (
        (query_modality_arg, query_image_path),
        (refer_modality_arg, refer_image_path),
    ):
        modality = arg.upper()
        if modality == "AUTO":
            modality = infer_modality_from_path(image_path)
        resolved.append(modality)
    return resolved[0], resolved[1]
def run_step1_broad_with_explicit_modalities(
    *,
    query_image: Path,
    refer_image: Path,
    query_stem: str,
    refer_stem: str,
    query_modality: str,
    refer_modality: str,
    device: str,
    model_dir: Path | None,
    output_dir: Path,
    log_path: Path,
) -> tuple[Path, Path]:
    """Segment both images with the BroadDomain (UNet-DCP) backend using explicit modalities.

    Writes ``<stem>_vessel.jpg`` for each image under ``output_dir/vessels``,
    plus a ``summary.json`` and a JSON log at *log_path*. Returns the
    (query, refer) vessel-map paths.

    Raises:
        FileNotFoundError: when the bundled BroadDomain project directory is missing.
        RuntimeError: when a vessel map cannot be written to disk.
    """
    project_root = STEP1_DIR / "ThirdParty" / "VesselSegProject_BroadDomain"
    if not project_root.is_dir():
        raise FileNotFoundError(f"BroadDomain project directory not found: {project_root}")
    # The third-party project is imported in-place, so it must be on sys.path.
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    from BroadDomainRetinalVesselSeg.UNet_DCP import inference_unet_dcp
    from Code.VesselSeg.broad_domain_vessel_seg import BroadDomainVesselSegmenter
    resolved_device = resolve_device(device)
    ensure_dir(log_path.parent)
    vessel_dir = ensure_dir(output_dir / "vessels")
    # Prefer the module's own device hook when present; otherwise patch its
    # module-level `device` attribute directly.
    if hasattr(inference_unet_dcp, "set_runtime_device"):
        runtime_device = inference_unet_dcp.set_runtime_device(resolved_device)
    else:
        runtime_device = torch.device(resolved_device)
        if hasattr(torch.backends, "mkldnn"):
            # NOTE(review): this enables mkldnn only for non-CPU devices, even
            # though mkldnn is a CPU backend — confirm the inversion is intended.
            torch.backends.mkldnn.enabled = runtime_device.type != "cpu"
        inference_unet_dcp.device = runtime_device
    segmenter = BroadDomainVesselSegmenter(
        model_dir=None if model_dir is None else str(model_dir),
        split_block=(4, 4),
    )
    query_bgr = read_as_bgr(query_image)
    refer_bgr = read_as_bgr(refer_image)
    query_vessel = segmenter.segment(query_bgr, modality=query_modality)
    refer_vessel = segmenter.segment(refer_bgr, modality=refer_modality)
    query_vessel_path = vessel_dir / f"{query_stem}_vessel.jpg"
    refer_vessel_path = vessel_dir / f"{refer_stem}_vessel.jpg"
    if not cv2.imwrite(str(query_vessel_path), query_vessel):
        raise RuntimeError(f"Failed to write output: {query_vessel_path}")
    if not cv2.imwrite(str(refer_vessel_path), refer_vessel):
        raise RuntimeError(f"Failed to write output: {refer_vessel_path}")
    # Persist a machine-readable record of this run next to the outputs and in the log.
    summary = {
        "backend": "BroadDomainVesselSegmenter(explicit_modality_single_pair)",
        "query_image": str(query_image),
        "refer_image": str(refer_image),
        "query_modality": query_modality,
        "refer_modality": refer_modality,
        "device": str(runtime_device),
        "query_vessel_path": str(query_vessel_path),
        "refer_vessel_path": str(refer_vessel_path),
    }
    ensure_dir(output_dir)
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    with log_path.open("w", encoding="utf-8") as f:
        f.write(json.dumps(summary, indent=2, ensure_ascii=False))
        f.write("\n")
    return query_vessel_path, refer_vessel_path
def read_as_bgr(path: Path) -> np.ndarray:
    """Load the image at *path* as 3-channel BGR.

    Grayscale inputs are replicated across channels and BGRA inputs lose
    their alpha channel. Raises FileNotFoundError when the file cannot be
    decoded.
    """
    image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Unable to read image: {path}")
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image
def normalize_to_png(src: Path, dst: Path) -> Path:
    """Re-encode *src* as a 3-channel image at *dst* (format chosen by the suffix).

    Raises RuntimeError when the encoder refuses to write the output.
    """
    frame = read_as_bgr(src)
    ensure_dir(dst.parent)
    written = cv2.imwrite(str(dst), frame)
    if not written:
        raise RuntimeError(f"Failed to write output: {dst}")
    return dst
def resolve_vessel_path(vessel_dir: Path, stem: str) -> Path:
    """Locate the vessel map for *stem* inside *vessel_dir*.

    Filenames are tried in priority order: ``<stem>_vessel`` variants first
    (jpg, png, jpeg), then plain ``<stem>`` variants. Raises
    FileNotFoundError when none exists.
    """
    extensions = (".jpg", ".png", ".jpeg")
    names = [f"{stem}_vessel{ext}" for ext in extensions]
    names += [f"{stem}{ext}" for ext in extensions]
    for name in names:
        candidate = vessel_dir / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"Vessel segmentation image not found: {vessel_dir} (stem={stem})")
def compute_wfcfp_only_crop_box_1000(refer_txt_path: Path) -> np.ndarray:
    """Build a square crop box (1000x1000 coordinate space) around the wfCFP fovea.

    *refer_txt_path* holds normalized (y, x) rows: row 0 is the optic disc,
    row 1 the fovea. The box is centred on the fovea with half-size equal to
    the horizontal OD-to-fovea distance, then clipped to [0, 1000].

    Raises:
        ValueError: for malformed coordinate files, near-zero OD-fovea
            separation, or a degenerate clipped box.
    """
    raw = np.atleast_2d(np.asarray(np.loadtxt(refer_txt_path), dtype=np.float64))
    if raw.shape[0] < 2 or raw.shape[1] < 2:
        raise ValueError(f"Invalid wfCFP coordinate file format: {refer_txt_path}")
    # Scale normalized coordinates into the detector's 800-pixel space.
    od_x_800 = float(raw[0, 1]) * DETECTION_SIZE
    fovea_y_800 = float(raw[1, 0]) * DETECTION_SIZE
    fovea_x_800 = float(raw[1, 1]) * DETECTION_SIZE
    half_size = abs(fovea_x_800 - od_x_800)
    if half_size <= 1e-6:
        raise ValueError(f"Cannot build wfCFP-only crop box, |fovea_x-od_x| is too small: {refer_txt_path}")
    rate = COORD_SCALE / DETECTION_SIZE
    crop_box = np.array(
        [
            [int((fovea_y_800 - half_size) * rate), int((fovea_x_800 - half_size) * rate)],
            [int((fovea_y_800 + half_size) * rate), int((fovea_x_800 + half_size) * rate)],
        ],
        dtype=np.int32,
    )
    crop_box = np.clip(crop_box, 0, int(COORD_SCALE))
    if crop_box[1, 0] <= crop_box[0, 0] or crop_box[1, 1] <= crop_box[0, 1]:
        raise ValueError(f"Invalid wfCFP-only crop box: {crop_box.tolist()}")
    return crop_box
def load_step2_api():
    """Import the Step2 module once and cache its callables, defaults, and weight paths.

    Returns a dict with keys "module", "infer_device", "build_detector",
    "run_wfcfp_detection_and_crop", "run_octa_detection", "defaults", and
    "weights". The result is memoized in the module-global ``_STEP2_API``.
    """
    global _STEP2_API
    if _STEP2_API is not None:
        return _STEP2_API
    # Step2 lives in its own directory; make it importable on first use.
    if str(STEP2_DIR) not in sys.path:
        sys.path.insert(0, str(STEP2_DIR))
    import run_detect_crop as step2_mod
    weight_paths = get_weight_paths()
    _STEP2_API = {
        "module": step2_mod,
        "infer_device": step2_mod.infer_device,
        "build_detector": step2_mod.build_detector,
        "run_wfcfp_detection_and_crop": step2_mod.run_wfcfp_detection_and_crop,
        "run_octa_detection": step2_mod.run_octa_detection,
        # Tunables mirrored from the Step2 module's own defaults.
        "defaults": {
            "crop_output_size": step2_mod.DEFAULT_CROP_OUTPUT_SIZE,
            "nms_threshold": step2_mod.DEFAULT_NMS_THRESHOLD,
            "max_detection_attempts": step2_mod.DEFAULT_MAX_DETECTION_ATTEMPTS,
            "max_rgb_fallback_attempts": step2_mod.DEFAULT_MAX_RGB_FALLBACK_ATTEMPTS,
            "max_octa_detection_attempts": step2_mod.DEFAULT_MAX_OCTA_DETECTION_ATTEMPTS,
            "max_octa_plain_fallback_attempts": step2_mod.DEFAULT_MAX_OCTA_PLAIN_FALLBACK_ATTEMPTS,
            "octa_fovea_center_max_offset_norm": step2_mod.DEFAULT_OCTA_FOVEA_CENTER_MAX_OFFSET_NORM,
        },
        # Absolute checkpoint paths for the four Step2 detectors.
        "weights": {
            "wfcfp_fused": weight_paths["step2_wfcfp_fused"].resolve(),
            "wfcfp_rgb_fallback": weight_paths["step2_wfcfp_rgb_fallback"].resolve(),
            "octa_fused": weight_paths["step2_octa_fused"].resolve(),
            "octa_plain_fallback": weight_paths["step2_octa_plain_fallback"].resolve(),
        },
    }
    return _STEP2_API
def get_step2_runtime(device: str) -> dict[str, object]:
    """Build (or fetch from cache) the four Step2 detectors for *device*.

    Thread-safe: the per-device cache is guarded by ``_RUNTIME_LOCK``, so
    detector construction happens at most once per resolved device string.

    Raises:
        FileNotFoundError: when any Step2 weights file is missing.
    """
    api = load_step2_api()
    resolved_device = resolve_device(device)
    with _RUNTIME_LOCK:
        cached = _STEP2_RUNTIME_CACHE.get(resolved_device)
        if cached is not None:
            return cached
        torch_device = api["infer_device"](resolved_device)
        weights = api["weights"]
        # Fail fast before building any detector if a weights file is absent.
        for weight_path in weights.values():
            if not Path(weight_path).is_file():
                raise FileNotFoundError(f"Step2 weights file does not exist: {weight_path}")
        runtime = {
            "device": str(torch_device),
            "wfcfp_fused_detector": api["build_detector"](weights["wfcfp_fused"], torch_device),
            "wfcfp_rgb_fallback_detector": api["build_detector"](weights["wfcfp_rgb_fallback"], torch_device),
            "octa_fused_detector": api["build_detector"](weights["octa_fused"], torch_device),
            "octa_plain_fallback_detector": api["build_detector"](weights["octa_plain_fallback"], torch_device),
            "weights": weights,
        }
        _STEP2_RUNTIME_CACHE[resolved_device] = runtime
        return runtime
def run_step2_inprocess(
    *,
    wfcfp_image: Path,
    wfcfp_vessel: Path,
    octa_image: Path,
    octa_vessel: Path,
    output_dir: Path,
    device: str,
    log_path: Path | None,
    save_debug_artifacts: bool = True,
) -> dict[str, object]:
    """Run Step2 OD/fovea detection and cropping in-process with cached detectors.

    Detects and crops on the wfCFP image first, then runs OCTA detection
    seeded with the wfCFP landmark points. Writes per-step summaries and a
    combined ``summary.json`` to *output_dir*; optionally writes a JSON
    status log to *log_path* (skipped when None). Returns the combined
    summary dict.
    """
    api = load_step2_api()
    runtime = get_step2_runtime(device)
    defaults = api["defaults"]
    output_dir = output_dir.resolve()
    ensure_dir(output_dir)
    wfcfp_summary = api["run_wfcfp_detection_and_crop"](
        wfcfp_image_path=wfcfp_image,
        wfcfp_vessel_path=wfcfp_vessel,
        output_dir=output_dir,
        fused_detector=runtime["wfcfp_fused_detector"],
        rgb_fallback_detector=runtime["wfcfp_rgb_fallback_detector"],
        device=api["infer_device"](runtime["device"]),
        crop_output_size=defaults["crop_output_size"],
        nms_threshold=defaults["nms_threshold"],
        max_detection_attempts=defaults["max_detection_attempts"],
        max_rgb_fallback_attempts=defaults["max_rgb_fallback_attempts"],
        weights_path=runtime["weights"]["wfcfp_fused"],
        rgb_fallback_weights_path=runtime["weights"]["wfcfp_rgb_fallback"],
        save_debug_artifacts=save_debug_artifacts,
        summary_filename="wfcfp_summary.json",
    )
    # OCTA detection is guided by the landmarks found on the wfCFP image above.
    octa_summary = api["run_octa_detection"](
        octa_image_path=octa_image,
        octa_vessel_path=octa_vessel,
        wfcfp_points_norm_yx=np.asarray(wfcfp_summary["points_norm_yx"], dtype=np.float32),
        output_dir=output_dir,
        fused_detector=runtime["octa_fused_detector"],
        plain_fallback_detector=runtime["octa_plain_fallback_detector"],
        device=api["infer_device"](runtime["device"]),
        nms_threshold=defaults["nms_threshold"],
        max_detection_attempts=defaults["max_octa_detection_attempts"],
        max_plain_fallback_attempts=defaults["max_octa_plain_fallback_attempts"],
        octa_fovea_center_max_offset_norm=defaults["octa_fovea_center_max_offset_norm"],
        weights_path=runtime["weights"]["octa_fused"],
        plain_fallback_weights_path=runtime["weights"]["octa_plain_fallback"],
        save_debug_artifacts=save_debug_artifacts,
        summary_filename="octa_summary.json",
    )
    # Combined summary: wfCFP fields at top level plus nested OCTA results.
    summary = dict(wfcfp_summary)
    summary["has_octa_detection"] = True
    summary["octa_summary_path"] = octa_summary["summary_path"]
    summary["octa"] = octa_summary
    summary_path = output_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    if log_path is not None:
        ensure_dir(log_path.parent)
        log_payload = {
            "status": "success",
            "mode": "inprocess_cached",
            "device": runtime["device"],
            "summary_path": str(summary_path),
            "weights": {k: str(v) for k, v in runtime["weights"].items()},
            "save_debug_artifacts": bool(save_debug_artifacts),
        }
        log_path.write_text(json.dumps(log_payload, indent=2), encoding="utf-8")
    return summary
def load_step3_api():
    """Import Step3's registration API from its own directory.

    Returns (run_registration_case, fit_polynomial_from_pairs,
    load_pair_points_pixels, warmup_predictor_cache).
    """
    if str(STEP3_DIR) not in sys.path:
        sys.path.insert(0, str(STEP3_DIR))
    from run_register_pair import fit_polynomial_from_pairs, load_pair_points_pixels, run_registration_case, warmup_predictor_cache
    return run_registration_case, fit_polynomial_from_pairs, load_pair_points_pixels, warmup_predictor_cache
def preload_runtime_models(
    device: str = DEFAULT_DEVICE,
    config_path: Path = DEFAULT_STEP3_CONFIG,
    model_path: Path | None = DEFAULT_STEP3_MODEL,
) -> dict[str, object]:
    """Eagerly build Step2 detectors and warm the Step3 predictor cache.

    Falls back to the managed (hf_weight_manager) checkpoint when
    *model_path* is None or points at the missing bundled default. Returns a
    dict describing the resolved device and per-step weight locations.

    Raises:
        FileNotFoundError: when no usable Step3 checkpoint can be found.
    """
    resolved_device = resolve_device(device)
    weight_paths = get_weight_paths()
    step2_runtime = get_step2_runtime(resolved_device)
    _, _, _, warmup_predictor_cache = load_step3_api()
    selected_model_path: Path | None
    if model_path is None:
        selected_model_path = weight_paths["step3_model"]
    else:
        selected_model_path = Path(model_path)
    if not selected_model_path.is_file():
        # Only substitute the managed weights when the caller asked for the
        # (missing) bundled default, not for an arbitrary user-supplied path.
        try:
            is_default = selected_model_path.resolve() == DEFAULT_STEP3_MODEL.resolve()
        except Exception:
            is_default = False
        if is_default:
            selected_model_path = weight_paths["step3_model"]
    if selected_model_path is None or not selected_model_path.is_file():
        raise FileNotFoundError(f"Step3 model not found: {selected_model_path}")
    step3_warmup = warmup_predictor_cache(
        str(config_path.resolve()),
        device=resolved_device,
        model_path=str(selected_model_path.resolve()),
    )
    return {
        "device": resolved_device,
        "step2": {
            "device": step2_runtime["device"],
            "weights": {k: str(v) for k, v in step2_runtime["weights"].items()},
        },
        "step3": {
            **step3_warmup,
            "model_path": str(selected_model_path.resolve()),
        },
    }
def route_template(route_name: str, output_dir: Path) -> dict[str, object]:
    """Return a fresh result record for one Step3 route, initialized as failed."""
    record: dict[str, object] = {
        "route_name": route_name,
        "output_dir": str(output_dir.resolve()),
        "status": "failed",
        "error_message": "",
        "pair_count": 0,
        "inlier_count": 0,
        "ransac_inlier_count": 0,
        "register_summary": None,
        "fit_eval": None,
    }
    # Routes that registered on inverted-polarity vessel maps are flagged up front.
    record["polarity_inverted"] = route_name in INVERTED_POLARITY_ROUTES
    return record
def run_single_route(
    *,
    route_name: str,
    output_dir: Path,
    run_registration_case,
    fit_polynomial_from_pairs,
    query_image: Path,
    query_vessel: Path,
    refer_image: Path,
    refer_vessel: Path,
    crop_box_1000: Path | None,
    config_path: Path,
    model_path: Path | None,
    device: str,
    inlier_method: str,
    legacy_script_crop: bool,
    save_debug_artifacts: bool,
    force_rerun: bool,
) -> dict[str, object]:
    """Run one Step3 registration route plus polynomial fitting/evaluation.

    Any exception is captured into the returned record's "error_message" and
    the record keeps ``status == "failed"`` rather than propagating, so
    callers can compare multiple routes and select the best one.
    """
    result = route_template(route_name, output_dir)
    try:
        if force_rerun and output_dir.exists():
            shutil.rmtree(output_dir)
        ensure_dir(output_dir.parent)
        summary = run_registration_case(
            query_image=str(query_image),
            query_vessel=str(query_vessel),
            refer_image=str(refer_image),
            refer_vessel=str(refer_vessel),
            crop_box_1000=None if crop_box_1000 is None else str(crop_box_1000),
            query_crop_box_1000=None,
            output_dir=str(output_dir),
            device=device,
            config_path=str(config_path),
            model_path=None if model_path is None else str(model_path),
            crop_mode="refer_box_only",
            inlier_method=inlier_method,
            legacy_script_crop=legacy_script_crop,
            save_debug_artifacts=save_debug_artifacts,
        )
        result["register_summary"] = summary
        if summary.get("status") != "success":
            raise RuntimeError(summary.get("error_message") or "Step3 failed")
        pair_txt_path = summary["pair_txt_path"]
        # Fit with the requested inlier method, plus a RANSAC fit for comparison metrics.
        fit_eval = fit_polynomial_from_pairs(pair_txt_path, inlier_method=inlier_method)
        fit_ransac = fit_polynomial_from_pairs(pair_txt_path, inlier_method="RANSAC")
        result["status"] = "success"
        result["pair_count"] = int(fit_eval["pair_count"])
        result["inlier_count"] = int(fit_eval["inlier_count"])
        result["ransac_inlier_count"] = int(fit_ransac["inlier_count"])
        result["fit_eval"] = fit_eval
    except Exception as exc:  # pragma: no cover
        result["error_message"] = str(exc)
    return result
def select_best_route(route_results: list[dict[str, object]]) -> dict[str, object] | None:
success = [item for item in route_results if item["status"] == "success"]
if not success:
return None
return max(success, key=lambda item: int(item["inlier_count"]))
def copy_file(src: Path, dst: Path) -> Path:
    """Copy *src* to *dst*, creating the destination directory first; return *dst*."""
    ensure_dir(dst.parent)
    shutil.copyfile(str(src), str(dst))
    return dst
def as_path(value: str | None) -> Path | None:
if value is None:
return None
return Path(value)
def dump_polynomial_params(fit_eval: dict[str, object], save_path: Path) -> dict[str, object]:
    """Serialize the fitted quadratic coordinate mapping to JSON; return the payload.

    *fit_eval* must contain objects exposing sklearn-style ``coef_`` and
    ``intercept_`` under keys "regr_x" and "regr_y", each fitted on the
    feature vector phi = [1, x, y, x^2, x*y, y^2] (FEATURE_NAMES). Since phi
    already includes a constant feature, the intercept and coef[0] are folded
    together into the effective a0/b0 terms.
    """
    regr_x = fit_eval["regr_x"]
    regr_y = fit_eval["regr_y"]
    coef_x = [float(v) for v in np.asarray(regr_x.coef_).ravel()]
    coef_y = [float(v) for v in np.asarray(regr_y.coef_).ravel()]
    intercept_x = float(regr_x.intercept_)
    intercept_y = float(regr_y.intercept_)
    # Fold the intercept into the constant-feature coefficient.
    effective_x = [coef_x[0] + intercept_x, coef_x[1], coef_x[2], coef_x[3], coef_x[4], coef_x[5]]
    effective_y = [coef_y[0] + intercept_y, coef_y[1], coef_y[2], coef_y[3], coef_y[4], coef_y[5]]
    payload = {
        "feature_order": FEATURE_NAMES,
        "raw_linear_regression": {
            "x_model": {"intercept": intercept_x, "coef": coef_x},
            "y_model": {"intercept": intercept_y, "coef": coef_y},
            "formula": "pred = intercept + sum_i coef[i] * phi_i, phi=[1,x,y,x^2,xy,y^2]",
        },
        "effective_quadratic_form": {
            "x_pred": {
                "a0": effective_x[0],
                "a1_x": effective_x[1],
                "a2_y": effective_x[2],
                "a3_x2": effective_x[3],
                "a4_xy": effective_x[4],
                "a5_y2": effective_x[5],
            },
            "y_pred": {
                "b0": effective_y[0],
                "b1_x": effective_y[1],
                "b2_y": effective_y[2],
                "b3_x2": effective_y[3],
                "b4_xy": effective_y[4],
                "b5_y2": effective_y[5],
            },
            "formula": "x' = a0+a1*x+a2*y+a3*x^2+a4*x*y+a5*y^2; y' = b0+b1*x+b2*y+b3*x^2+b4*x*y+b5*y^2",
        },
    }
    ensure_dir(save_path.parent)
    save_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    return payload
def relpath(path: Path, root: Path) -> str:
    """Best-effort *path* relative to *root*; fall back to the absolute path."""
    resolved = path.resolve()
    try:
        return str(resolved.relative_to(root.resolve()))
    except Exception:
        return str(resolved)
def add_label(image: np.ndarray, label: str) -> np.ndarray:
    """Return a copy of *image* with *label* drawn on a dark banner along the top."""
    labeled = image.copy()
    # Filled dark strip, 40 px tall, spanning the full width.
    cv2.rectangle(labeled, (0, 0), (labeled.shape[1], 40), (20, 20, 20), -1)
    cv2.putText(
        labeled,
        label,
        (12, 28),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    return labeled
def resize_keep(image: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Resize *image* to *size* given as (width, height), using bilinear interpolation."""
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
def create_visual_summary(
    *,
    save_path: Path,
    octa_raw: Path,
    wfcfp_raw: Path,
    octa_vessel: Path,
    wfcfp_vessel: Path,
    keypoint_vis: Path,
    all_keypoint_vis: Path,
    vessel_overlay: Path,
    raw_overlay: Path,
) -> Path:
    """Compose a 2x4 labelled panel board from the pipeline's key images; save and return *save_path*.

    Panel order (top row then bottom row): OCTA raw, wfCFP raw, OCTA vessel,
    wfCFP vessel, filtered keypoints, vessel overlay, raw overlay, all
    keypoints. Each tile is resized to 700x700 before the label is drawn.
    """
    tile_w, tile_h = 700, 700
    panels = [
        (read_as_bgr(octa_raw), "OCTA Raw"),
        (read_as_bgr(wfcfp_raw), "wfCFP Raw"),
        (read_as_bgr(octa_vessel), "OCTA Vessel"),
        (read_as_bgr(wfcfp_vessel), "wfCFP Vessel"),
        (read_as_bgr(keypoint_vis), "Filtered Keypoints"),
        (read_as_bgr(vessel_overlay), "Vessel Overlay"),
        (read_as_bgr(raw_overlay), "Raw Overlay"),
        (read_as_bgr(all_keypoint_vis), "All Keypoints"),
    ]
    prepared = []
    for image, label in panels[:8]:
        tile = resize_keep(image, (tile_w, tile_h))
        tile = add_label(tile, label) if label else tile
        prepared.append(tile)
    # Two rows of four tiles each.
    row1 = np.hstack(prepared[0:4])
    row2 = np.hstack(prepared[4:8])
    board = np.vstack([row1, row2])
    ensure_dir(save_path.parent)
    cv2.imwrite(str(save_path), board)
    return save_path
def to_black_bg_white_vessel(gray_or_bgr: np.ndarray) -> np.ndarray:
    """Normalize a vessel map to a white-vessel-on-black-background BGR image."""
    if gray_or_bgr.ndim == 3:
        conversion = cv2.COLOR_BGRA2GRAY if gray_or_bgr.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        gray = cv2.cvtColor(gray_or_bgr, conversion)
    else:
        gray = gray_or_bgr
    # If background is bright, invert once to keep vessel as white on black.
    if float(np.mean(gray)) > 127.0:
        gray = 255 - gray
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
def draw_keypoint_pairs_on_black_vessel(
    *,
    query_img_path: Path,
    refer_img_path: Path,
    pair_points: np.ndarray,
    save_path: Path,
    title: str,
    restore_from_inverted: bool = False,
    coord_scale: float = COORD_SCALE,
) -> Path:
    """Render matched keypoint pairs side-by-side on black-background vessel maps.

    *pair_points* rows are (query_x, query_y, refer_x, refer_y) in a
    ``coord_scale`` x ``coord_scale`` coordinate space. When
    *restore_from_inverted* is True both source images are bitwise-inverted
    first (used for routes that registered on inverted vessel maps). The
    composed image, with a centred title banner, is written to *save_path*.

    Raises:
        FileNotFoundError: when either background image cannot be read.
    """
    query_src = cv2.imread(str(query_img_path), cv2.IMREAD_UNCHANGED)
    if query_src is None:
        raise FileNotFoundError(f"Unable to read keypoint-visualization background image: {query_img_path}")
    refer_src = cv2.imread(str(refer_img_path), cv2.IMREAD_UNCHANGED)
    if refer_src is None:
        raise FileNotFoundError(f"Unable to read keypoint-visualization background image: {refer_img_path}")
    if restore_from_inverted:
        query_src = cv2.bitwise_not(query_src)
        refer_src = cv2.bitwise_not(refer_src)
    query_bgr = to_black_bg_white_vessel(query_src)
    refer_bgr = to_black_bg_white_vessel(refer_src)
    # Normalize both panels to an identical shape, so left/right visualization
    # is consistent before drawing pair connections.
    h_a0, w_a0 = query_bgr.shape[:2]
    h_b0, w_b0 = refer_bgr.shape[:2]
    target_h = max(h_a0, h_b0)
    target_w = max(w_a0, w_b0)
    if h_a0 != target_h or w_a0 != target_w:
        query_bgr = cv2.resize(query_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    if h_b0 != target_h or w_b0 != target_w:
        refer_bgr = cv2.resize(refer_bgr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    h_a, w_a = query_bgr.shape[:2]
    h_b, w_b = refer_bgr.shape[:2]
    # Side-by-side canvas: query panel on the left, refer panel on the right.
    canvas = np.zeros((max(h_a, h_b), w_a + w_b, 3), dtype=np.uint8)
    canvas[0:h_a, 0:w_a] = query_bgr
    canvas[0:h_b, w_a : w_a + w_b] = refer_bgr
    for query_pt, refer_pt in zip(pair_points[:, :2], pair_points[:, 2:]):
        # Map coordinates from the coord_scale space into pixel positions.
        query_x = float(query_pt[0]) * float(w_a) / float(coord_scale)
        query_y = float(query_pt[1]) * float(h_a) / float(coord_scale)
        refer_x = float(refer_pt[0]) * float(w_b) / float(coord_scale)
        refer_y = float(refer_pt[1]) * float(h_b) / float(coord_scale)
        pt_a = (int(round(query_x)), int(round(query_y)))
        # Refer points are shifted right by the query panel width.
        pt_b = (int(round(refer_x + w_a)), int(round(refer_y)))
        cv2.line(canvas, pt_a, pt_b, (0, 255, 0), 1, lineType=cv2.LINE_AA)
        cv2.circle(canvas, pt_a, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
        cv2.circle(canvas, pt_b, 2, (0, 0, 255), -1, lineType=cv2.LINE_AA)
    # Prepend a 70 px banner and draw the title horizontally centred inside it.
    title_height = 70
    vis = np.zeros((canvas.shape[0] + title_height, canvas.shape[1], 3), dtype=np.uint8)
    vis[title_height:, :] = canvas
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    text_size, _ = cv2.getTextSize(title, font, font_scale, thickness)
    text_x = max(0, (vis.shape[1] - text_size[0]) // 2)
    cv2.putText(
        vis,
        title,
        (text_x, 45),
        font,
        font_scale,
        (255, 255, 255),
        thickness,
        lineType=cv2.LINE_AA,
    )
    ensure_dir(save_path.parent)
    cv2.imwrite(str(save_path), vis)
    return save_path
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the single-pair full inference pipeline."""
    parser = argparse.ArgumentParser(description="Single-pair full inference pipeline (Step1+Step2+Step3)")
    parser.add_argument("--octa_image", required=True, type=Path, help="Path to OCTA image")
    parser.add_argument("--wfcfp_image", required=True, type=Path, help="Path to wfCFP image")
    # Default output directory is timestamped at parser-construction time.
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=THIS_DIR / "Output" / f"single_pair_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}",
        help="Output directory",
    )
    parser.add_argument(
        "--seg_model",
        choices=[SEG_MODEL_CARE, SEG_MODEL_UNET_DCP],
        default=SEG_MODEL_CARE,
        help=f"Step1 vessel segmentation model: {SEG_MODEL_CARE} or {SEG_MODEL_UNET_DCP}",
    )
    parser.add_argument(
        "--inlier_method",
        choices=["RANSAC", "LMEDS"],
        default="LMEDS",
        help="Step3 inlier filtering/fitting method",
    )
    parser.add_argument(
        "--retry_filtered_inlier_threshold",
        type=int,
        default=DEFAULT_RETRY_FILTERED_INLIER_THRESHOLD,
        help="Trigger retry when primary filtered inlier count is below this threshold",
    )
    parser.add_argument(
        "--registration_mode",
        choices=["interface_b_retry", "direct_no_crop"],
        default="interface_b_retry",
        help=(
            "interface_b_retry: Interface-B logic (includes Step2 and retry routes); "
            "direct_no_crop: Skip Step2 and cropping; run Step3 directly after Step1 segmentation."
        ),
    )
    parser.add_argument(
        "--query_modality",
        choices=["AUTO", "CFP", "OCTA"],
        default="AUTO",
        help=f"Query modality used by {SEG_MODEL_UNET_DCP}. AUTO infers by filename (contains OCTA/_001 -> OCTA, otherwise CFP).",
    )
    parser.add_argument(
        "--refer_modality",
        choices=["AUTO", "CFP", "OCTA"],
        default="AUTO",
        help=f"Refer modality used by {SEG_MODEL_UNET_DCP}. AUTO infers by filename (default CFP).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=DEFAULT_DEVICE,
        help=f"Device: auto/cpu/cuda:0. Default auto (cuda:0 if available, else cpu). Used by Step1({SEG_MODEL_CARE}), Step2, and Step3.",
    )
    parser.add_argument("--config_path", type=Path, default=DEFAULT_STEP3_CONFIG, help="Path to Step3 config file")
    parser.add_argument("--model_path", type=Path, default=DEFAULT_STEP3_MODEL, help="Path to Step3 model weights")
    parser.add_argument("--force_rerun", action="store_true", help="Delete output directory first if it already exists")
    parser.add_argument("--continue_from_step1", action="store_true", help="Continue from current-sample Step1 outputs and run only Step2+Step3")
    parser.add_argument("--lite_outputs", action="store_true", help="Generate only core outputs and skip non-essential heavy visualizations")
    parser.add_argument("--keep_attempt_dirs", action="store_true", help="Keep intermediate Step3 route attempt directories")
    parser.add_argument(
        "--disable_nonessential_io",
        action="store_true",
        help="Disable non-essential debug artifact generation/writing in Step2/Step3 and pipeline copies.",
    )
    return parser
def main() -> None:
args = build_parser().parse_args()
resolved_device = resolve_device(args.device)
weight_paths = get_weight_paths()
octa_image = args.octa_image.resolve()
wfcfp_image = args.wfcfp_image.resolve()
output_dir = args.output_dir.resolve()
config_path = args.config_path.resolve()
model_path = args.model_path.resolve()
if not octa_image.is_file():
raise FileNotFoundError(f"OCTA image not found: {octa_image}")
if not wfcfp_image.is_file():
raise FileNotFoundError(f"wfCFP image not found: {wfcfp_image}")
if not config_path.is_file():
raise FileNotFoundError(f"Step3 config not found: {config_path}")
if not model_path.is_file():
try:
uses_default_model = model_path == DEFAULT_STEP3_MODEL.resolve()
except Exception:
uses_default_model = False
if uses_default_model:
model_path = weight_paths["step3_model"].resolve()
if not model_path.is_file():
raise FileNotFoundError(
f"Step3 model not found: {model_path}. "
"Set CARE_WEIGHTS_REPO_ID for snapshot-based weight loading if local files are absent."
)
if output_dir.exists() and args.force_rerun:
shutil.rmtree(output_dir)
ensure_dir(output_dir)
logs_dir = ensure_dir(output_dir / "logs")
work_dir = ensure_dir(output_dir / "work")
input_dir = ensure_dir(work_dir / "input_images")
step1_out = ensure_dir(work_dir / "step1")
step2_out = ensure_dir(work_dir / "step2")
step3_attempts = ensure_dir(work_dir / "step3" / "attempts")
step3_assets = ensure_dir(work_dir / "step3" / "assets")
step3_selected_dir = work_dir / "step3" / "selected"
results_dir = ensure_dir(output_dir / "results")
results_vessels = ensure_dir(results_dir / "vessels")
results_keypoints = ensure_dir(results_dir / "keypoints")
results_overlays = ensure_dir(results_dir / "overlays")
results_poly = ensure_dir(results_dir / "polynomial")
results_vis = ensure_dir(results_dir / "visualizations")
preload_runtime_models(device=resolved_device, config_path=config_path, model_path=model_path)
run_registration_case, fit_polynomial_from_pairs, load_pair_points_pixels, _ = load_step3_api()
query_modality, refer_modality = resolve_modalities(
query_modality_arg=args.query_modality,
refer_modality_arg=args.refer_modality,
query_image_path=octa_image,
refer_image_path=wfcfp_image,
)
requested_registration_mode = args.registration_mode
effective_registration_mode = requested_registration_mode
registration_mode_auto_forced = False
registration_mode_auto_reason = ""
if requested_registration_mode == "interface_b_retry" and query_modality == "CFP" and refer_modality == "CFP":
effective_registration_mode = "direct_no_crop"
registration_mode_auto_forced = True
registration_mode_auto_reason = "both_images_are_cfp"
octa_png = normalize_to_png(octa_image, input_dir / "octa.png")
wfcfp_png = normalize_to_png(wfcfp_image, input_dir / "wfcfp.png")
step1_reused_in_current_run = False
if args.continue_from_step1:
vessel_dir = step1_out / "vessels"
octa_vessel = resolve_vessel_path(vessel_dir, "octa")
wfcfp_vessel = resolve_vessel_path(vessel_dir, "wfcfp")
step1_reused_in_current_run = True
(logs_dir / "step1.log").write_text("Continue from current-run Step1 outputs\n", encoding="utf-8")
elif args.seg_model == SEG_MODEL_CARE:
step1_script = STEP1_DIR / "run_vessel_seg.py"
step1_cmd = [
sys.executable,
str(step1_script),
"--input_dir",
str(input_dir),
"--output_dir",
str(step1_out),
"--weights",
str(weight_paths["step1_care"].resolve()),
"--device",
resolved_device,
]
run_cmd(step1_cmd, logs_dir / "step1.log")
vessel_dir = step1_out / "vessels"
octa_vessel = resolve_vessel_path(vessel_dir, "octa")
wfcfp_vessel = resolve_vessel_path(vessel_dir, "wfcfp")
else:
octa_vessel, wfcfp_vessel = run_step1_broad_with_explicit_modalities(
query_image=octa_png,
refer_image=wfcfp_png,
query_stem="octa",
refer_stem="wfcfp",
query_modality=query_modality,
refer_modality=refer_modality,
device=resolved_device,
model_dir=weight_paths["step1_broad_dir"].resolve(),
output_dir=step1_out,
log_path=logs_dir / "step1.log",
)
step2_log_path = logs_dir / "step2.log"
wfcfp_txt: Path | None = None
if effective_registration_mode == "interface_b_retry":
run_step2_inprocess(
wfcfp_image=wfcfp_png,
wfcfp_vessel=wfcfp_vessel,
octa_image=octa_png,
octa_vessel=octa_vessel,
output_dir=step2_out,
device=resolved_device,
log_path=None if args.disable_nonessential_io else step2_log_path,
save_debug_artifacts=not args.disable_nonessential_io,
)
wfcfp_txt = step2_out / "wfcfp_od_fovea.txt"
octa_txt = step2_out / "octa_od_fovea.txt"
if not wfcfp_txt.is_file():
raise FileNotFoundError(f"Missing Step2 output: {wfcfp_txt}")
if not octa_txt.is_file():
raise FileNotFoundError(f"Missing Step2 output: {octa_txt}")
legacy_octa_txt = octa_vessel.with_name("octa.txt")
legacy_wfcfp_txt = wfcfp_vessel.with_name("wfcfp.txt")
copy_file(octa_txt, legacy_octa_txt)
copy_file(wfcfp_txt, legacy_wfcfp_txt)
common_kwargs = {
"run_registration_case": run_registration_case,
"fit_polynomial_from_pairs": fit_polynomial_from_pairs,
"query_image": octa_png,
"query_vessel": octa_vessel,
"refer_image": wfcfp_png,
"refer_vessel": wfcfp_vessel,
"config_path": config_path,
"model_path": model_path,
"device": resolved_device,
"inlier_method": args.inlier_method,
"save_debug_artifacts": not args.disable_nonessential_io,
"force_rerun": True,
}
route_results: list[dict[str, object]] = []
retry_reasons: list[str] = []
retry_triggered = False
if effective_registration_mode == "direct_no_crop":
full_box_path = step3_assets / "full_1000_box_1000.txt"
np.savetxt(full_box_path, np.array([[0, 0], [1000, 1000]], dtype=np.int32), fmt="%d")
direct_result = run_single_route(
route_name="direct_no_crop",
output_dir=step3_attempts / "direct_no_crop",
crop_box_1000=full_box_path,
legacy_script_crop=False,
**common_kwargs,
)
route_results.append(direct_result)
else:
primary = run_single_route(
route_name="primary",
output_dir=step3_attempts / "primary",
crop_box_1000=None,
legacy_script_crop=True,
**common_kwargs,
)
route_results.append(primary)
if primary["status"] != "success":
retry_reasons.append("primary_failed")
else:
if int(primary["inlier_count"]) < int(args.retry_filtered_inlier_threshold):
retry_reasons.append(f"filtered_inlier_below_{int(args.retry_filtered_inlier_threshold)}")
retry_triggered = len(retry_reasons) > 0
if retry_triggered:
if wfcfp_txt is None:
raise RuntimeError("wfcfp_txt must not be None in interface_b_retry mode")
wfcfp_only_box = compute_wfcfp_only_crop_box_1000(wfcfp_txt)
wfcfp_only_box_path = step3_assets / "wfcfp_only_crop_box_1000.txt"
np.savetxt(wfcfp_only_box_path, wfcfp_only_box, fmt="%d")
full_box_path = step3_assets / "full_1000_box_1000.txt"
np.savetxt(full_box_path, np.array([[0, 0], [1000, 1000]], dtype=np.int32), fmt="%d")
route_results.append(
run_single_route(
route_name="fallback_wfcfp_only_crop",
output_dir=step3_attempts / "fallback_wfcfp_only_crop",
crop_box_1000=wfcfp_only_box_path,
legacy_script_crop=False,
**common_kwargs,
)
)
route_results.append(
run_single_route(
route_name="fallback_invert_no_crop",
output_dir=step3_attempts / "fallback_invert_no_crop",
crop_box_1000=full_box_path,
legacy_script_crop=False,
**common_kwargs,
)
)
# Pick the best attempt across all routes; None means every route failed.
selected = select_best_route(route_results)
if selected is None:
messages = [f"{item['route_name']}: {item.get('error_message')}" for item in route_results]
raise RuntimeError("All routes failed: " + " | ".join(messages))
selected_route_dir = Path(str(selected["output_dir"]))
# Mirror the winning route's attempt directory into the stable "selected"
# location (skipped when non-essential I/O is disabled).
if not args.disable_nonessential_io:
if step3_selected_dir.exists():
shutil.rmtree(step3_selected_dir)
shutil.copytree(selected_route_dir, step3_selected_dir)
selected_register_summary = selected["register_summary"]
selected_fit_eval = selected["fit_eval"]
selected_route_polarity_inverted = bool(selected.get("polarity_inverted", False))
# Load the matched keypoint pairs (input space and restored space) recorded
# by the winning route; both files are required downstream.
pair_input_path = as_path(selected_register_summary.get("pair_input_space_pixels_path"))
pair_restored_path = as_path(selected_register_summary.get("pair_restored_space_pixels_path"))
if pair_input_path is None or pair_restored_path is None:
raise RuntimeError("Selected route summary is missing keypoint pixel paths")
pair_input = load_pair_points_pixels(str(pair_input_path))
pair_restored = load_pair_points_pixels(str(pair_restored_path))
# Sanity check: the inlier mask from fit evaluation must align 1:1 with the
# loaded pairs before it is used to filter them.
inlier_mask = np.asarray(selected_fit_eval["inlier_mask"], dtype=bool)
if len(pair_input) != len(inlier_mask) or len(pair_restored) != len(inlier_mask):
raise RuntimeError("Keypoint pair count does not match inlier-mask length")
filtered_input = pair_input[inlier_mask]
filtered_restored = pair_restored[inlier_mask]
filtered_norm = np.asarray(selected_fit_eval["inlier_pairs_norm"], dtype=np.float64)
# Persist the filtered pairs: normalized coordinates at high precision,
# pixel coordinates at 6 decimal places.
filtered_norm_path = results_keypoints / "filtered_pairs_norm.txt"
filtered_input_path = results_keypoints / "filtered_pairs_input_space_pixels.txt"
filtered_restored_path = results_keypoints / "filtered_pairs_restored_space_pixels.txt"
np.savetxt(filtered_norm_path, filtered_norm, fmt="%.12f")
np.savetxt(filtered_input_path, filtered_input, fmt="%.6f")
np.savetxt(filtered_restored_path, filtered_restored, fmt="%.6f")
octa_vessel_out = copy_file(octa_vessel, results_vessels / "octa_vessel.png")
wfcfp_vessel_out = copy_file(wfcfp_vessel, results_vessels / "wfcfp_vessel.png")
vessel_overlay_src = Path(selected_register_summary["vessel_merged_display_path"])
raw_overlay_src = Path(selected_register_summary["actual_overlay_path"])
vessel_overlay_out = copy_file(vessel_overlay_src, results_overlays / "vessel_overlay.png")
raw_overlay_out = copy_file(raw_overlay_src, results_overlays / "raw_overlay.png")
warped_octa_raw_out = None
warped_octa_vessel_out = None
if not args.lite_outputs:
warped_octa_raw_out = copy_file(Path(selected_register_summary["warped_actual_path"]), results_overlays / "warped_octa_raw.png")
warped_octa_vessel_out = copy_file(Path(selected_register_summary["warped_vessel_path"]), results_overlays / "warped_octa_vessel.png")
debug_dir = Path(selected_register_summary["registration_debug_dir"])
query_restored = debug_dir / "match_query_restored.png"
refer_restored = debug_dir / "match_refer_restored.png"
query_input = debug_dir / "match_query_input.png"
refer_input = debug_dir / "match_refer_input.png"
if not query_restored.is_file():
query_restored = octa_vessel
if not refer_restored.is_file():
refer_restored = wfcfp_vessel
if not query_input.is_file():
query_input = octa_vessel
if not refer_input.is_file():
refer_input = wfcfp_vessel
kp_vis_input_out = None
if not args.lite_outputs:
kp_vis_input_out = draw_keypoint_pairs_on_black_vessel(
query_img_path=query_input,
refer_img_path=refer_input,
pair_points=filtered_input,
save_path=results_keypoints / "filtered_pairs_input_space_vis.png",
title=f"{args.inlier_method} Inlier Match (Input Space), #inlier: {len(filtered_input)}",
restore_from_inverted=selected_route_polarity_inverted,
)
kp_vis_restored_out = draw_keypoint_pairs_on_black_vessel(
query_img_path=query_restored,
refer_img_path=refer_restored,
pair_points=filtered_restored,
save_path=results_keypoints / "filtered_pairs_restored_space_vis.png",
title=f"{args.inlier_method} Inlier Match (Restored Space), #inlier: {len(filtered_restored)}",
restore_from_inverted=selected_route_polarity_inverted,
)
kp_vis_all_out = None
if not args.lite_outputs:
kp_vis_all_out = draw_keypoint_pairs_on_black_vessel(
query_img_path=query_restored,
refer_img_path=refer_restored,
pair_points=pair_restored,
save_path=results_keypoints / "all_pairs_vis.png",
title=f"All Match (Restored Space), #pair: {len(pair_restored)}",
restore_from_inverted=selected_route_polarity_inverted,
)
poly_payload = dump_polynomial_params(selected_fit_eval, results_poly / "quadratic_polynomial_params.json")
octa_raw_out = copy_file(octa_png, results_dir / "octa_raw.png")
wfcfp_raw_out = copy_file(wfcfp_png, results_dir / "wfcfp_raw.png")
visual_summary_path = None
if not args.lite_outputs:
visual_summary_path = create_visual_summary(
save_path=results_vis / "summary_board.png",
octa_raw=octa_raw_out,
wfcfp_raw=wfcfp_raw_out,
octa_vessel=octa_vessel_out,
wfcfp_vessel=wfcfp_vessel_out,
keypoint_vis=kp_vis_restored_out,
all_keypoint_vis=kp_vis_all_out,
vessel_overlay=vessel_overlay_out,
raw_overlay=raw_overlay_out,
)
route_records = []
for item in route_results:
route_records.append(
{
"route_name": item["route_name"],
"status": item["status"],
"error_message": item.get("error_message", ""),
"pair_count": int(item.get("pair_count", 0)),
"inlier_count": int(item.get("inlier_count", 0)),
"ransac_inlier_count": int(item.get("ransac_inlier_count", 0)),
"polarity_inverted": bool(item.get("polarity_inverted", False)),
"output_dir": relpath(Path(str(item["output_dir"])), output_dir),
}
)
# Assemble the run summary: effective inputs/options, retry bookkeeping,
# headline metrics, and relative paths to every produced artifact.
summary = {
"status": "success",
"run_time": dt.datetime.now().isoformat(),
"input": {
"octa_image": str(octa_image),
"wfcfp_image": str(wfcfp_image),
"seg_model": args.seg_model,
"continue_from_step1": bool(args.continue_from_step1),
"step1_reused_in_current_run": bool(step1_reused_in_current_run),
"lite_outputs": bool(args.lite_outputs),
"disable_nonessential_io": bool(args.disable_nonessential_io),
"registration_mode_requested": requested_registration_mode,
"registration_mode_effective": effective_registration_mode,
"registration_mode_auto_forced": registration_mode_auto_forced,
"registration_mode_auto_reason": registration_mode_auto_reason,
"query_modality": query_modality,
"refer_modality": refer_modality,
"inlier_method": args.inlier_method,
"retry_filtered_inlier_threshold": int(args.retry_filtered_inlier_threshold),
"device": resolved_device,
"step3_config_path": str(config_path),
"step3_model_path": str(model_path),
},
"retry": {
"retry_triggered": retry_triggered,
"retry_reasons": retry_reasons,
"selected_route": selected["route_name"],
"selected_route_polarity_inverted": selected_route_polarity_inverted,
"routes": route_records,
},
"metrics": {
"pair_count": int(selected["pair_count"]),
"filtered_inlier_count": int(selected["inlier_count"]),
"ransac_inlier_count": int(selected["ransac_inlier_count"]),
},
# Optional outputs are None when lite mode / missing files suppressed them.
"outputs": {
"octa_raw": relpath(octa_raw_out, output_dir),
"wfcfp_raw": relpath(wfcfp_raw_out, output_dir),
"octa_vessel": relpath(octa_vessel_out, output_dir),
"wfcfp_vessel": relpath(wfcfp_vessel_out, output_dir),
"filtered_pairs_norm": relpath(filtered_norm_path, output_dir),
"filtered_pairs_input_space_pixels": relpath(filtered_input_path, output_dir),
"filtered_pairs_restored_space_pixels": relpath(filtered_restored_path, output_dir),
"filtered_pairs_input_vis": relpath(kp_vis_input_out, output_dir) if kp_vis_input_out else None,
"filtered_pairs_restored_vis": relpath(kp_vis_restored_out, output_dir),
"all_pairs_vis": relpath(kp_vis_all_out, output_dir) if kp_vis_all_out else None,
"vessel_overlay": relpath(vessel_overlay_out, output_dir),
"raw_overlay": relpath(raw_overlay_out, output_dir),
"warped_octa_raw": relpath(warped_octa_raw_out, output_dir) if warped_octa_raw_out else None,
"warped_octa_vessel": relpath(warped_octa_vessel_out, output_dir) if warped_octa_vessel_out else None,
"quadratic_polynomial_params": relpath(results_poly / "quadratic_polynomial_params.json", output_dir),
"visual_summary_board": relpath(visual_summary_path, output_dir) if visual_summary_path else None,
},
"step_logs": {
"step1": relpath(logs_dir / "step1.log", output_dir),
"step2": relpath(step2_log_path, output_dir) if step2_log_path.exists() else None,
},
"poly_feature_order": poly_payload["feature_order"],
}
# Write summary.json and emit a single machine-readable success line on stdout.
summary_path = output_dir / "summary.json"
summary_path.write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
print(json.dumps({"status": "success", "summary": str(summary_path)}, ensure_ascii=False))
# Per-attempt scratch directories are removed unless the caller asked to keep them.
if not args.keep_attempt_dirs:
shutil.rmtree(step3_attempts, ignore_errors=True)
if __name__ == "__main__":
try:
main()
except Exception as exc:
message = {
"status": "failed",
"error": str(exc),
"traceback": traceback.format_exc(),
}
print(json.dumps(message, ensure_ascii=False, indent=2))
raise