#!/usr/bin/env python3 """Run test-set inference for multitask model (lumen segmentation + bifurcation classification).""" from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import tensorflow as tf from deepivus.config import resolve_bifurcation_threshold_path from scripts.finetune.shared.common import ( load_bifurcation_annotations, load_lumen_annotations, load_preprocessed_stack, load_split_ids, polygon_to_mask, ) IMG_MEAN = tf.constant([60.3486], dtype=tf.float32) def _extract_logits(model_output: object) -> tf.Tensor: if isinstance(model_output, dict): if not model_output: raise RuntimeError("Model returned an empty dict output.") return next(iter(model_output.values())) if isinstance(model_output, (tuple, list)): if not model_output: raise RuntimeError("Model returned an empty sequence output.") return model_output[0] if tf.is_tensor(model_output): return model_output raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}") def _prepare_batch(images: np.ndarray) -> tf.Tensor: x = tf.convert_to_tensor(images, dtype=tf.float32) x = x - IMG_MEAN x = tf.expand_dims(x, axis=-1) x = tf.tile(x, [1, 1, 1, 3]) return x def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor: num_classes = int(logits.shape[-1]) if lumen_class < 0 or lumen_class >= num_classes: raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={num_classes}") positive = logits[..., lumen_class] negatives = tf.concat([logits[..., :lumen_class], logits[..., lumen_class + 1 :]], axis=-1) negative_logsumexp = tf.reduce_logsumexp(negatives, axis=-1) return positive - negative_logsumexp def _make_cls_head(num_seg_classes: int) -> tf.keras.Model: inp = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits") x = tf.keras.layers.GlobalAveragePooling2D()(inp) x = tf.keras.layers.Dense(96, activation="relu")(x) x = tf.keras.layers.Dropout(0.3)(x) out = tf.keras.layers.Dense(1, activation="sigmoid", name="bifurcation_prob")(x) return tf.keras.Model(inp, out, name="bifurcation_head") def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int): grouped: dict[Path, list] = {} for ann in annotations_bif: grouped.setdefault(ann.dicom_path, []).append(ann) images = [] masks = [] bif_labels = [] has_mask = [] sample_ids = [] for dicom_path, ann_list in grouped.items(): stack = load_preprocessed_stack(dicom_path, diameter=diameter) h, w = int(stack.shape[1]), int(stack.shape[2]) for ann in ann_list: fidx = ann.frame_idx if fidx < 0 or fidx >= stack.shape[0]: continue sid = ann.sample_id images.append(stack[fidx]) bif_labels.append(1.0 if ann.bifurcation else 0.0) sample_ids.append(sid) lann = lumen_map.get(sid) if lann is None: masks.append(np.zeros((h, w), dtype=np.float32)) has_mask.append(0.0) else: m = polygon_to_mask(lann.lumen_x, lann.lumen_y, (h, w)).astype(np.float32) masks.append(m) has_mask.append(1.0 if np.any(m > 0.5) else 0.0) if not images: raise RuntimeError("No valid samples produced from frame bank.") return ( np.stack(images, axis=0), np.stack(masks, axis=0), np.asarray(bif_labels, dtype=np.float32), np.asarray(has_mask, dtype=np.float32), sample_ids, ) def _predict_probs( base_model, cls_head, images: np.ndarray, lumen_class: int, batch_size: int, ) -> tuple[np.ndarray, np.ndarray]: seg_probs = [] cls_probs = [] for start in range(0, len(images), batch_size): end = min(start + batch_size, len(images)) x = _prepare_batch(images[start:end]) logits = _extract_logits(base_model(x, training=False)) logits = tf.image.resize(logits, (images.shape[1], images.shape[2])) bin_logit = _binary_logit_from_multiclass(logits, lumen_class=lumen_class) seg = tf.math.sigmoid(bin_logit).numpy() cls = tf.reshape(cls_head(logits, training=False), [-1]).numpy() seg_probs.append(seg) cls_probs.append(cls) return np.concatenate(seg_probs, axis=0), np.concatenate(cls_probs, axis=0) def _seg_metrics(seg_probs: np.ndarray, masks: np.ndarray, has_mask: np.ndarray, threshold: float = 0.5) -> dict[str, float]: pred = seg_probs >= threshold gt = masks >= 0.5 valid = has_mask > 0.5 inter = 0.0 union = 0.0 pred_sum = 0.0 gt_sum = 0.0 count = 0 for i in range(pred.shape[0]): if not valid[i]: continue pi = pred[i] gi = gt[i] inter += float(np.logical_and(pi, gi).sum()) union += float(np.logical_or(pi, gi).sum()) pred_sum += float(pi.sum()) gt_sum += float(gi.sum()) count += 1 return { "seg_count": int(count), "seg_iou": float(inter / max(union, 1.0)), "seg_dice": float((2.0 * inter) / max(pred_sum + gt_sum, 1.0)), } def _cls_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, float]: y_true_i = y_true.astype(np.int32) y_pred = (y_prob >= threshold).astype(np.int32) tp = int(np.sum((y_pred == 1) & (y_true_i == 1))) fp = int(np.sum((y_pred == 1) & (y_true_i == 0))) fn = int(np.sum((y_pred == 0) & (y_true_i == 1))) tn = int(np.sum((y_pred == 0) & (y_true_i == 0))) acc = float((tp + tn) / max(tp + tn + fp + fn, 1)) prec = float(tp / max(tp + fp, 1)) rec = float(tp / max(tp + fn, 1)) f1 = float((2.0 * prec * rec) / max(prec + rec, 1e-12)) if y_true_i.size > 1 and len(np.unique(y_true_i)) > 1: auc_metric = tf.keras.metrics.AUC(curve="ROC") auc_metric.update_state(y_true, y_prob) auc = float(auc_metric.result().numpy()) else: auc = float("nan") return { "threshold": float(threshold), "cls_accuracy": acc, "cls_precision": prec, "cls_recall": rec, "cls_f1": f1, "cls_auc": auc, "tp": tp, "fp": fp, "fn": fn, "tn": tn, } def _select_threshold(y_true_val: np.ndarray, y_prob_val: np.ndarray, metric: str) -> tuple[float, dict[str, float]]: candidates = np.linspace(0.05, 0.95, 91) best = None for th in candidates: row = _cls_metrics(y_true_val, y_prob_val, threshold=float(th)) score = row[metric] if best is None or score > best[metric]: best = row return float(best["threshold"]), best def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged")) parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json")) parser.add_argument("--base-model-dir", type=Path, default=Path("models/multitask/lumen_multitask_base")) parser.add_argument("--cls-head-path", type=Path, default=Path("models/multitask/bifurcation_head.keras")) parser.add_argument("--lumen-class", type=int, default=1) parser.add_argument("--diameter", type=int, default=67) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--threshold", type=float, default=None, help="Fixed bif threshold; if omitted, selected on val.") parser.add_argument("--select-metric", type=str, default="cls_f1", choices=["cls_f1", "cls_accuracy", "cls_precision", "cls_recall", "cls_auc"]) parser.add_argument("--output-json", type=Path, default=Path("output/multitask_test_inference.json")) parser.add_argument( "--save-threshold", action=argparse.BooleanOptionalAction, default=True, help="Persist selected threshold to /threshold.json for runtime inference.", ) args = parser.parse_args() bif_anns = load_bifurcation_annotations(args.frame_bank_root) lumen_anns = load_lumen_annotations(args.frame_bank_root) if not bif_anns: raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}") split_ids = load_split_ids(args.split_json) train_ids = split_ids["train"] val_ids = split_ids["val"] test_ids = split_ids["test"] keep_ids = train_ids | val_ids | test_ids bif_anns = [a for a in bif_anns if a.sample_id in keep_ids] lumen_map = {a.sample_id: a for a in lumen_anns} images, masks, bif_labels, has_mask, sample_ids = _build_arrays(bif_anns, lumen_map=lumen_map, diameter=args.diameter) idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64) idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64) if len(idx_test) == 0: raise RuntimeError("No test samples found in multitask arrays.") base_model = tf.saved_model.load(str(args.base_model_dir)) cls_head = tf.keras.models.load_model(args.cls_head_path) seg_val, cls_val = (None, None) if args.threshold is None: if len(idx_val) == 0: raise RuntimeError("Threshold selection requested but val split is empty.") _, cls_val = _predict_probs(base_model, cls_head, images[idx_val], lumen_class=args.lumen_class, batch_size=args.batch_size) selected_threshold, val_best = _select_threshold(bif_labels[idx_val], cls_val, metric=args.select_metric) threshold_info = { "method": "validation_sweep", "metric": args.select_metric, "selected_threshold": float(selected_threshold), "val_best": val_best, } else: selected_threshold = float(args.threshold) threshold_info = {"method": "fixed", "selected_threshold": float(selected_threshold)} seg_test, cls_test = _predict_probs(base_model, cls_head, images[idx_test], lumen_class=args.lumen_class, batch_size=args.batch_size) seg_metrics = _seg_metrics(seg_test, masks[idx_test], has_mask[idx_test], threshold=0.5) cls_metrics = _cls_metrics(bif_labels[idx_test], cls_test, threshold=selected_threshold) payload = { "base_model_dir": str(args.base_model_dir), "cls_head_path": str(args.cls_head_path), "split_json": str(args.split_json), "num_test_samples": int(len(idx_test)), "num_test_with_lumen": int(np.sum(has_mask[idx_test] > 0.5)), "threshold_info": threshold_info, "segmentation_metrics": seg_metrics, "bifurcation_metrics": cls_metrics, } args.output_json.parent.mkdir(parents=True, exist_ok=True) with args.output_json.open("w", encoding="utf-8") as fp: json.dump(payload, fp, indent=2) if args.save_threshold: threshold_path = resolve_bifurcation_threshold_path(model_path=args.cls_head_path) threshold_path.parent.mkdir(parents=True, exist_ok=True) threshold_payload = { "selected_threshold": float(selected_threshold), "selection": threshold_info, "source_split_json": str(args.split_json), "source_model_path": str(args.cls_head_path), "source_base_model_dir": str(args.base_model_dir), } with threshold_path.open("w", encoding="utf-8") as fp: json.dump(threshold_payload, fp, indent=2) print(f"Test samples: {len(idx_test)} (with lumen gt: {int(np.sum(has_mask[idx_test] > 0.5))})") print( f"Seg Dice={seg_metrics['seg_dice']:.4f} IoU={seg_metrics['seg_iou']:.4f} | " f"Cls Acc={cls_metrics['cls_accuracy']:.4f} AUC={cls_metrics['cls_auc']:.4f} F1={cls_metrics['cls_f1']:.4f} " f"(th={selected_threshold:.3f})" ) print(f"Saved: {args.output_json}") if args.save_threshold: print(f"Saved threshold: {threshold_path}") if __name__ == "__main__": main()