ivus-segmentation / scripts /finetune /multitask /run_multitask_test_inference.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Run test-set inference for multitask model (lumen segmentation + bifurcation classification)."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import tensorflow as tf
from deepivus.config import resolve_bifurcation_threshold_path
from scripts.finetune.shared.common import (
load_bifurcation_annotations,
load_lumen_annotations,
load_preprocessed_stack,
load_split_ids,
polygon_to_mask,
)
# Scalar offset that _prepare_batch subtracts from every pixel before inference.
# NOTE(review): presumably the mean intensity of the training images — confirm
# against the training pipeline before changing it.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)
def _extract_logits(model_output: object) -> tf.Tensor:
if isinstance(model_output, dict):
if not model_output:
raise RuntimeError("Model returned an empty dict output.")
return next(iter(model_output.values()))
if isinstance(model_output, (tuple, list)):
if not model_output:
raise RuntimeError("Model returned an empty sequence output.")
return model_output[0]
if tf.is_tensor(model_output):
return model_output
raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}")
def _prepare_batch(images: np.ndarray) -> tf.Tensor:
    """Turn a batch of single-channel frames into the model's input tensor.

    The frames are cast to float32, shifted by ``IMG_MEAN``, given a trailing
    channel axis, and that channel is replicated three times.

    Args:
        images: rank-3 array (the four-entry tile spec below requires exactly
            one axis to be added); assumed (batch, height, width) — TODO confirm.

    Returns:
        A float32 tensor with a trailing axis of size 3.
    """
    centered = tf.convert_to_tensor(images, dtype=tf.float32) - IMG_MEAN
    # Append the channel axis, then copy it so all three channels are equal.
    with_channel = centered[..., tf.newaxis]
    return tf.tile(with_channel, multiples=[1, 1, 1, 3])
def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor:
    """Collapse multi-class logits into a one-vs-rest logit for ``lumen_class``.

    The result is ``logit[lumen_class] - logsumexp(all other logits)`` along
    the last axis, so applying a sigmoid to it gives the softmax probability
    of ``lumen_class``.

    Args:
        logits: tensor whose last axis indexes classes; that axis must have a
            static size.
        lumen_class: index of the positive class on the last axis.

    Returns:
        Tensor with the class axis removed.

    Raises:
        ValueError: if ``lumen_class`` is outside ``[0, num_classes)``.
    """
    channel_count = int(logits.shape[-1])
    if not 0 <= lumen_class < channel_count:
        raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={channel_count}")

    # Every channel except the positive one, stitched back together.
    # NOTE(review): with a single-channel input this "rest" set is empty —
    # confirm callers never pass such a tensor.
    rest = tf.concat(
        [logits[..., :lumen_class], logits[..., lumen_class + 1 :]],
        axis=-1,
    )
    return logits[..., lumen_class] - tf.reduce_logsumexp(rest, axis=-1)
def _make_cls_head(num_seg_classes: int) -> tf.keras.Model:
    """Build the bifurcation classification head that sits on segmentation logits.

    Architecture: global average pooling over the two spatial axes ->
    Dense(96, ReLU) -> Dropout(0.3) -> Dense(1, sigmoid).

    Args:
        num_seg_classes: channel count of the segmentation logits fed in.

    Returns:
        A functional Keras model named ``bifurcation_head`` with input
        ``seg_logits`` (spatial size unconstrained) and a single sigmoid
        output produced by the layer ``bifurcation_prob``.
    """
    layers = tf.keras.layers
    seg_logits = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits")

    pooled = layers.GlobalAveragePooling2D()(seg_logits)
    hidden = layers.Dense(96, activation="relu")(pooled)
    hidden = layers.Dropout(0.3)(hidden)
    probability = layers.Dense(1, activation="sigmoid", name="bifurcation_prob")(hidden)

    return tf.keras.Model(seg_logits, probability, name="bifurcation_head")
def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int):
grouped: dict[Path, list] = {}
for ann in annotations_bif:
grouped.setdefault(ann.dicom_path, []).append(ann)
images = []
masks = []
bif_labels = []
has_mask = []
sample_ids = []
for dicom_path, ann_list in grouped.items():
stack = load_preprocessed_stack(dicom_path, diameter=diameter)
h, w = int(stack.shape[1]), int(stack.shape[2])
for ann in ann_list:
fidx = ann.frame_idx
if fidx < 0 or fidx >= stack.shape[0]:
continue
sid = ann.sample_id
images.append(stack[fidx])
bif_labels.append(1.0 if ann.bifurcation else 0.0)
sample_ids.append(sid)
lann = lumen_map.get(sid)
if lann is None:
masks.append(np.zeros((h, w), dtype=np.float32))
has_mask.append(0.0)
else:
m = polygon_to_mask(lann.lumen_x, lann.lumen_y, (h, w)).astype(np.float32)
masks.append(m)
has_mask.append(1.0 if np.any(m > 0.5) else 0.0)
if not images:
raise RuntimeError("No valid samples produced from frame bank.")
return (
np.stack(images, axis=0),
np.stack(masks, axis=0),
np.asarray(bif_labels, dtype=np.float32),
np.asarray(has_mask, dtype=np.float32),
sample_ids,
)
def _predict_probs(
base_model,
cls_head,
images: np.ndarray,
lumen_class: int,
batch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
seg_probs = []
cls_probs = []
for start in range(0, len(images), batch_size):
end = min(start + batch_size, len(images))
x = _prepare_batch(images[start:end])
logits = _extract_logits(base_model(x, training=False))
logits = tf.image.resize(logits, (images.shape[1], images.shape[2]))
bin_logit = _binary_logit_from_multiclass(logits, lumen_class=lumen_class)
seg = tf.math.sigmoid(bin_logit).numpy()
cls = tf.reshape(cls_head(logits, training=False), [-1]).numpy()
seg_probs.append(seg)
cls_probs.append(cls)
return np.concatenate(seg_probs, axis=0), np.concatenate(cls_probs, axis=0)
def _seg_metrics(seg_probs: np.ndarray, masks: np.ndarray, has_mask: np.ndarray, threshold: float = 0.5) -> dict[str, float]:
pred = seg_probs >= threshold
gt = masks >= 0.5
valid = has_mask > 0.5
inter = 0.0
union = 0.0
pred_sum = 0.0
gt_sum = 0.0
count = 0
for i in range(pred.shape[0]):
if not valid[i]:
continue
pi = pred[i]
gi = gt[i]
inter += float(np.logical_and(pi, gi).sum())
union += float(np.logical_or(pi, gi).sum())
pred_sum += float(pi.sum())
gt_sum += float(gi.sum())
count += 1
return {
"seg_count": int(count),
"seg_iou": float(inter / max(union, 1.0)),
"seg_dice": float((2.0 * inter) / max(pred_sum + gt_sum, 1.0)),
}
def _cls_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, float]:
y_true_i = y_true.astype(np.int32)
y_pred = (y_prob >= threshold).astype(np.int32)
tp = int(np.sum((y_pred == 1) & (y_true_i == 1)))
fp = int(np.sum((y_pred == 1) & (y_true_i == 0)))
fn = int(np.sum((y_pred == 0) & (y_true_i == 1)))
tn = int(np.sum((y_pred == 0) & (y_true_i == 0)))
acc = float((tp + tn) / max(tp + tn + fp + fn, 1))
prec = float(tp / max(tp + fp, 1))
rec = float(tp / max(tp + fn, 1))
f1 = float((2.0 * prec * rec) / max(prec + rec, 1e-12))
if y_true_i.size > 1 and len(np.unique(y_true_i)) > 1:
auc_metric = tf.keras.metrics.AUC(curve="ROC")
auc_metric.update_state(y_true, y_prob)
auc = float(auc_metric.result().numpy())
else:
auc = float("nan")
return {
"threshold": float(threshold),
"cls_accuracy": acc,
"cls_precision": prec,
"cls_recall": rec,
"cls_f1": f1,
"cls_auc": auc,
"tp": tp,
"fp": fp,
"fn": fn,
"tn": tn,
}
def _select_threshold(y_true_val: np.ndarray, y_prob_val: np.ndarray, metric: str) -> tuple[float, dict[str, float]]:
    """Pick the decision threshold that maximises ``metric`` on validation data.

    Candidates are 91 evenly spaced thresholds from 0.05 to 0.95 (step 0.01).
    The lowest threshold reaching the maximum score wins, except when the
    sweep is uninformative — every candidate scores the same, or every score
    is NaN. That is always the case for ``cls_auc``, because ROC AUC does not
    depend on the threshold. In that situation the candidate closest to 0.5 is
    returned instead of the arbitrary first grid point.

    Args:
        y_true_val: validation labels.
        y_prob_val: validation probabilities.
        metric: key of the ``_cls_metrics`` result to maximise.

    Returns:
        ``(threshold, metrics_row)`` for the selected candidate.

    Raises:
        KeyError: if ``metric`` is not a key of the metrics dict.
    """
    candidates = np.linspace(0.05, 0.95, 91)
    rows = [_cls_metrics(y_true_val, y_prob_val, threshold=float(th)) for th in candidates]
    scores = np.asarray([row[metric] for row in rows], dtype=np.float64)

    # NaN never compares greater than anything, which used to let a NaN first
    # row block every later candidate; rank NaN below all real scores instead.
    comparable = np.where(np.isnan(scores), -np.inf, scores)

    if comparable.max() == comparable.min():
        # No candidate is better than another: use the conventional midpoint.
        best_idx = int(np.argmin(np.abs(candidates - 0.5)))
    else:
        # argmax returns the first maximum, matching the original strict ">" sweep.
        best_idx = int(np.argmax(comparable))

    best = rows[best_idx]
    return float(best["threshold"]), best
def main() -> None:
    """CLI entry point: evaluate the multitask model on the test split.

    Steps:
      1. Load bifurcation and lumen annotations and the split definition.
      2. Build image/mask/label arrays for the test split, plus the val split
         when the threshold has to be selected. Train samples are never
         evaluated here, so their stacks are not loaded.
      3. Load the segmentation SavedModel and the Keras classification head.
      4. Use ``--threshold`` if given, otherwise sweep thresholds on val.
      5. Compute segmentation and bifurcation metrics on test, write them to
         ``--output-json`` and, unless ``--no-save-threshold`` is passed,
         persist the threshold next to the classification head.

    Raises:
        RuntimeError: if there are no bifurcation annotations, no usable
            samples, no test samples, or threshold selection is requested with
            an empty val split.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged"))
    parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json"))
    parser.add_argument("--base-model-dir", type=Path, default=Path("models/multitask/lumen_multitask_base"))
    parser.add_argument("--cls-head-path", type=Path, default=Path("models/multitask/bifurcation_head.keras"))
    parser.add_argument("--lumen-class", type=int, default=1)
    parser.add_argument("--diameter", type=int, default=67)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--threshold", type=float, default=None, help="Fixed bif threshold; if omitted, selected on val.")
    parser.add_argument(
        "--select-metric",
        type=str,
        default="cls_f1",
        choices=["cls_f1", "cls_accuracy", "cls_precision", "cls_recall", "cls_auc"],
    )
    parser.add_argument("--output-json", type=Path, default=Path("output/multitask_test_inference.json"))
    parser.add_argument(
        "--save-threshold",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Persist selected threshold to <cls_head_dir>/threshold.json for runtime inference.",
    )
    args = parser.parse_args()

    bif_anns = load_bifurcation_annotations(args.frame_bank_root)
    lumen_anns = load_lumen_annotations(args.frame_bank_root)
    if not bif_anns:
        raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}")

    split_ids = load_split_ids(args.split_json)
    val_ids = split_ids["val"]
    test_ids = split_ids["test"]

    # Only load what is evaluated: test always, val only when the threshold
    # must be selected. Train stacks used to be loaded and preprocessed too,
    # although nothing below ever reads them.
    select_on_val = args.threshold is None
    keep_ids = (test_ids | val_ids) if select_on_val else test_ids
    bif_anns = [a for a in bif_anns if a.sample_id in keep_ids]
    lumen_map = {a.sample_id: a for a in lumen_anns}

    images, masks, bif_labels, has_mask, sample_ids = _build_arrays(
        bif_anns, lumen_map=lumen_map, diameter=args.diameter
    )
    idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64)
    idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64)
    if len(idx_test) == 0:
        raise RuntimeError("No test samples found in multitask arrays.")

    base_model = tf.saved_model.load(str(args.base_model_dir))
    cls_head = tf.keras.models.load_model(args.cls_head_path)

    if select_on_val:
        if len(idx_val) == 0:
            raise RuntimeError("Threshold selection requested but val split is empty.")
        # Only the classification probabilities matter for threshold selection.
        _, cls_val = _predict_probs(
            base_model, cls_head, images[idx_val], lumen_class=args.lumen_class, batch_size=args.batch_size
        )
        selected_threshold, val_best = _select_threshold(bif_labels[idx_val], cls_val, metric=args.select_metric)
        threshold_info = {
            "method": "validation_sweep",
            "metric": args.select_metric,
            "selected_threshold": float(selected_threshold),
            "val_best": val_best,
        }
    else:
        selected_threshold = float(args.threshold)
        threshold_info = {"method": "fixed", "selected_threshold": float(selected_threshold)}

    seg_test, cls_test = _predict_probs(
        base_model, cls_head, images[idx_test], lumen_class=args.lumen_class, batch_size=args.batch_size
    )
    seg_metrics = _seg_metrics(seg_test, masks[idx_test], has_mask[idx_test], threshold=0.5)
    cls_metrics = _cls_metrics(bif_labels[idx_test], cls_test, threshold=selected_threshold)

    num_test_with_lumen = int(np.sum(has_mask[idx_test] > 0.5))
    payload = {
        "base_model_dir": str(args.base_model_dir),
        "cls_head_path": str(args.cls_head_path),
        "split_json": str(args.split_json),
        "num_test_samples": int(len(idx_test)),
        "num_test_with_lumen": num_test_with_lumen,
        "threshold_info": threshold_info,
        "segmentation_metrics": seg_metrics,
        "bifurcation_metrics": cls_metrics,
    }
    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    with args.output_json.open("w", encoding="utf-8") as fp:
        json.dump(payload, fp, indent=2)

    if args.save_threshold:
        # NOTE(review): the destination is decided by
        # resolve_bifurcation_threshold_path; per the --save-threshold help it
        # is expected to be <cls_head_dir>/threshold.json — confirm.
        threshold_path = resolve_bifurcation_threshold_path(model_path=args.cls_head_path)
        threshold_path.parent.mkdir(parents=True, exist_ok=True)
        threshold_payload = {
            "selected_threshold": float(selected_threshold),
            "selection": threshold_info,
            "source_split_json": str(args.split_json),
            "source_model_path": str(args.cls_head_path),
            "source_base_model_dir": str(args.base_model_dir),
        }
        with threshold_path.open("w", encoding="utf-8") as fp:
            json.dump(threshold_payload, fp, indent=2)

    print(f"Test samples: {len(idx_test)} (with lumen gt: {num_test_with_lumen})")
    print(
        f"Seg Dice={seg_metrics['seg_dice']:.4f} IoU={seg_metrics['seg_iou']:.4f} | "
        f"Cls Acc={cls_metrics['cls_accuracy']:.4f} AUC={cls_metrics['cls_auc']:.4f} F1={cls_metrics['cls_f1']:.4f} "
        f"(th={selected_threshold:.3f})"
    )
    print(f"Saved: {args.output_json}")
    if args.save_threshold:
        print(f"Saved threshold: {threshold_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()