Instructions to use Aditya2162/ivus-segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use Aditya2162/ivus-segmentation with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
model = keras.saving.load_model("hf://Aditya2162/ivus-segmentation")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Run test-set inference for multitask model (lumen segmentation + bifurcation classification).""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import tensorflow as tf | |
| from deepivus.config import resolve_bifurcation_threshold_path | |
| from scripts.finetune.shared.common import ( | |
| load_bifurcation_annotations, | |
| load_lumen_annotations, | |
| load_preprocessed_stack, | |
| load_split_ids, | |
| polygon_to_mask, | |
| ) | |
# Scalar intensity offset subtracted from every pixel in _prepare_batch, before the
# grayscale frame is replicated to 3 channels. Presumably the mean pixel value of
# the preprocessed training frames — TODO confirm it matches the training pipeline.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)
| def _extract_logits(model_output: object) -> tf.Tensor: | |
| if isinstance(model_output, dict): | |
| if not model_output: | |
| raise RuntimeError("Model returned an empty dict output.") | |
| return next(iter(model_output.values())) | |
| if isinstance(model_output, (tuple, list)): | |
| if not model_output: | |
| raise RuntimeError("Model returned an empty sequence output.") | |
| return model_output[0] | |
| if tf.is_tensor(model_output): | |
| return model_output | |
| raise RuntimeError(f"Unsupported model output type: {type(model_output)!r}") | |
def _prepare_batch(images: np.ndarray) -> tf.Tensor:
    """Turn a stack of grayscale frames into the base model's 3-channel input.

    The frames are cast to float32 and ``IMG_MEAN`` is subtracted. A trailing
    channel axis is then added and repeated three times. Because the tile
    multiples have length 4, ``images`` must be rank 3 (presumably
    (N, H, W) — TODO confirm with callers).
    """
    centered = tf.convert_to_tensor(images, dtype=tf.float32) - IMG_MEAN
    # Add a channel axis and replicate it so the grayscale frame fills RGB slots.
    return tf.tile(centered[..., tf.newaxis], multiples=[1, 1, 1, 3])
def _binary_logit_from_multiclass(logits: tf.Tensor, lumen_class: int) -> tf.Tensor:
    """Collapse multiclass logits into one lumen-vs-rest logit per pixel.

    Returns ``z[k] - logsumexp(z[j] for j != k)``, where ``k`` is
    ``lumen_class``. If ``logits`` are softmax logits, this is the log-odds
    ``log(p_k / (1 - p_k))``, so ``sigmoid`` of the result equals the
    softmax probability of the lumen class.

    Raises:
        ValueError: if ``lumen_class`` is outside ``[0, num_classes)``, where
            ``num_classes`` is the static size of the last axis.
    """
    num_classes = int(logits.shape[-1])
    if not 0 <= lumen_class < num_classes:
        raise ValueError(f"Invalid lumen_class={lumen_class} for num_classes={num_classes}")
    target = logits[..., lumen_class]
    below = logits[..., :lumen_class]
    above = logits[..., lumen_class + 1 :]
    rest_lse = tf.reduce_logsumexp(tf.concat([below, above], axis=-1), axis=-1)
    return target - rest_lse
def _make_cls_head(num_seg_classes: int) -> tf.keras.Model:
    """Build the bifurcation classification head that consumes segmentation logits.

    Architecture: a global average pool over the spatial axes of the
    ``(H, W, num_seg_classes)`` logits, then Dense(96, relu), then
    Dropout(0.3), then Dense(1, sigmoid) named ``bifurcation_prob``.
    It is not called anywhere in the visible part of this script; the trained
    head is loaded from disk instead.
    """
    seg_logits = tf.keras.Input(shape=(None, None, num_seg_classes), name="seg_logits")
    # Layers are created in the same order as applied, so auto-generated names stay stable.
    stack = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(96, activation="relu"),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation="sigmoid", name="bifurcation_prob"),
    ]
    features = seg_logits
    for layer in stack:
        features = layer(features)
    return tf.keras.Model(seg_logits, features, name="bifurcation_head")
def _build_arrays(annotations_bif, lumen_map: dict[str, object], diameter: int):
    """Assemble image, mask and label arrays for the given bifurcation annotations.

    Annotations are grouped by ``dicom_path`` so that each pullback stack is
    loaded only once via ``load_preprocessed_stack``. Annotations whose
    ``frame_idx`` falls outside the loaded stack are skipped.

    Masks are handled as follows:
    - If ``lumen_map`` has an entry for the sample id, the lumen polygon is
      rasterized to the frame size.
    - Otherwise an all-zero mask is used.
    ``has_mask`` is 1.0 only when the mask has at least one pixel > 0.5.

    Returns:
        ``(images, masks, bif_labels, has_mask, sample_ids)``, all aligned:
        images and masks are stacked along a new axis 0, labels and flags are
        float32 vectors, and sample ids are a Python list. Ordering follows the
        first appearance of each ``dicom_path``, then annotation order within it.

    Raises:
        RuntimeError: if no annotation yields a valid frame.
    """
    by_pullback: dict[Path, list] = {}
    for annotation in annotations_bif:
        by_pullback.setdefault(annotation.dicom_path, []).append(annotation)

    frames, lumen_masks, labels, mask_flags, ids = [], [], [], [], []
    for pullback_path, pullback_anns in by_pullback.items():
        stack = load_preprocessed_stack(pullback_path, diameter=diameter)
        frame_hw = (int(stack.shape[1]), int(stack.shape[2]))
        for annotation in pullback_anns:
            frame_idx = annotation.frame_idx
            if not 0 <= frame_idx < stack.shape[0]:
                continue
            sample_id = annotation.sample_id
            frames.append(stack[frame_idx])
            labels.append(1.0 if annotation.bifurcation else 0.0)
            ids.append(sample_id)

            lumen = lumen_map.get(sample_id)
            if lumen is None:
                lumen_masks.append(np.zeros(frame_hw, dtype=np.float32))
                mask_flags.append(0.0)
                continue
            mask = polygon_to_mask(lumen.lumen_x, lumen.lumen_y, frame_hw).astype(np.float32)
            lumen_masks.append(mask)
            # A polygon that rasterizes to nothing is treated as "no ground truth".
            mask_flags.append(1.0 if np.any(mask > 0.5) else 0.0)

    if not frames:
        raise RuntimeError("No valid samples produced from frame bank.")
    return (
        np.stack(frames, axis=0),
        np.stack(lumen_masks, axis=0),
        np.asarray(labels, dtype=np.float32),
        np.asarray(mask_flags, dtype=np.float32),
        ids,
    )
| def _predict_probs( | |
| base_model, | |
| cls_head, | |
| images: np.ndarray, | |
| lumen_class: int, | |
| batch_size: int, | |
| ) -> tuple[np.ndarray, np.ndarray]: | |
| seg_probs = [] | |
| cls_probs = [] | |
| for start in range(0, len(images), batch_size): | |
| end = min(start + batch_size, len(images)) | |
| x = _prepare_batch(images[start:end]) | |
| logits = _extract_logits(base_model(x, training=False)) | |
| logits = tf.image.resize(logits, (images.shape[1], images.shape[2])) | |
| bin_logit = _binary_logit_from_multiclass(logits, lumen_class=lumen_class) | |
| seg = tf.math.sigmoid(bin_logit).numpy() | |
| cls = tf.reshape(cls_head(logits, training=False), [-1]).numpy() | |
| seg_probs.append(seg) | |
| cls_probs.append(cls) | |
| return np.concatenate(seg_probs, axis=0), np.concatenate(cls_probs, axis=0) | |
| def _seg_metrics(seg_probs: np.ndarray, masks: np.ndarray, has_mask: np.ndarray, threshold: float = 0.5) -> dict[str, float]: | |
| pred = seg_probs >= threshold | |
| gt = masks >= 0.5 | |
| valid = has_mask > 0.5 | |
| inter = 0.0 | |
| union = 0.0 | |
| pred_sum = 0.0 | |
| gt_sum = 0.0 | |
| count = 0 | |
| for i in range(pred.shape[0]): | |
| if not valid[i]: | |
| continue | |
| pi = pred[i] | |
| gi = gt[i] | |
| inter += float(np.logical_and(pi, gi).sum()) | |
| union += float(np.logical_or(pi, gi).sum()) | |
| pred_sum += float(pi.sum()) | |
| gt_sum += float(gi.sum()) | |
| count += 1 | |
| return { | |
| "seg_count": int(count), | |
| "seg_iou": float(inter / max(union, 1.0)), | |
| "seg_dice": float((2.0 * inter) / max(pred_sum + gt_sum, 1.0)), | |
| } | |
| def _cls_metrics(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, float]: | |
| y_true_i = y_true.astype(np.int32) | |
| y_pred = (y_prob >= threshold).astype(np.int32) | |
| tp = int(np.sum((y_pred == 1) & (y_true_i == 1))) | |
| fp = int(np.sum((y_pred == 1) & (y_true_i == 0))) | |
| fn = int(np.sum((y_pred == 0) & (y_true_i == 1))) | |
| tn = int(np.sum((y_pred == 0) & (y_true_i == 0))) | |
| acc = float((tp + tn) / max(tp + tn + fp + fn, 1)) | |
| prec = float(tp / max(tp + fp, 1)) | |
| rec = float(tp / max(tp + fn, 1)) | |
| f1 = float((2.0 * prec * rec) / max(prec + rec, 1e-12)) | |
| if y_true_i.size > 1 and len(np.unique(y_true_i)) > 1: | |
| auc_metric = tf.keras.metrics.AUC(curve="ROC") | |
| auc_metric.update_state(y_true, y_prob) | |
| auc = float(auc_metric.result().numpy()) | |
| else: | |
| auc = float("nan") | |
| return { | |
| "threshold": float(threshold), | |
| "cls_accuracy": acc, | |
| "cls_precision": prec, | |
| "cls_recall": rec, | |
| "cls_f1": f1, | |
| "cls_auc": auc, | |
| "tp": tp, | |
| "fp": fp, | |
| "fn": fn, | |
| "tn": tn, | |
| } | |
def _select_threshold(y_true_val: np.ndarray, y_prob_val: np.ndarray, metric: str) -> tuple[float, dict[str, float]]:
    """Sweep 91 thresholds over [0.05, 0.95] (step 0.01) and keep the best one.

    ``metric`` is a key of the ``_cls_metrics`` result dict. A row replaces the
    current best only when its score is strictly greater. As a result, ties
    keep the lowest threshold, and a NaN score never wins.

    ``cls_auc`` is computed from the probabilities alone, so selecting on it
    always returns the first (lowest) candidate.

    Returns:
        ``(threshold, metrics_row)`` for the winning candidate.
    """
    candidate_rows = (
        _cls_metrics(y_true_val, y_prob_val, threshold=float(th))
        for th in np.linspace(0.05, 0.95, 91)
    )
    # max() keeps the first maximal item and replaces only on a strict ">" —
    # the same rule as an explicit "score > best" loop.
    best = max(candidate_rows, key=lambda row: row[metric])
    return float(best["threshold"]), best
| def main() -> None: | |
| parser = argparse.ArgumentParser(description=__doc__) | |
| parser.add_argument("--frame-bank-root", type=Path, default=Path("evals/frame_bank_merged")) | |
| parser.add_argument("--split-json", type=Path, default=Path("evals/splits/ivus_split_merged_600.json")) | |
| parser.add_argument("--base-model-dir", type=Path, default=Path("models/multitask/lumen_multitask_base")) | |
| parser.add_argument("--cls-head-path", type=Path, default=Path("models/multitask/bifurcation_head.keras")) | |
| parser.add_argument("--lumen-class", type=int, default=1) | |
| parser.add_argument("--diameter", type=int, default=67) | |
| parser.add_argument("--batch-size", type=int, default=16) | |
| parser.add_argument("--threshold", type=float, default=None, help="Fixed bif threshold; if omitted, selected on val.") | |
| parser.add_argument("--select-metric", type=str, default="cls_f1", choices=["cls_f1", "cls_accuracy", "cls_precision", "cls_recall", "cls_auc"]) | |
| parser.add_argument("--output-json", type=Path, default=Path("output/multitask_test_inference.json")) | |
| parser.add_argument( | |
| "--save-threshold", | |
| action=argparse.BooleanOptionalAction, | |
| default=True, | |
| help="Persist selected threshold to <cls_head_dir>/threshold.json for runtime inference.", | |
| ) | |
| args = parser.parse_args() | |
| bif_anns = load_bifurcation_annotations(args.frame_bank_root) | |
| lumen_anns = load_lumen_annotations(args.frame_bank_root) | |
| if not bif_anns: | |
| raise RuntimeError(f"No bifurcation annotations under: {args.frame_bank_root}") | |
| split_ids = load_split_ids(args.split_json) | |
| train_ids = split_ids["train"] | |
| val_ids = split_ids["val"] | |
| test_ids = split_ids["test"] | |
| keep_ids = train_ids | val_ids | test_ids | |
| bif_anns = [a for a in bif_anns if a.sample_id in keep_ids] | |
| lumen_map = {a.sample_id: a for a in lumen_anns} | |
| images, masks, bif_labels, has_mask, sample_ids = _build_arrays(bif_anns, lumen_map=lumen_map, diameter=args.diameter) | |
| idx_val = np.asarray([i for i, sid in enumerate(sample_ids) if sid in val_ids], dtype=np.int64) | |
| idx_test = np.asarray([i for i, sid in enumerate(sample_ids) if sid in test_ids], dtype=np.int64) | |
| if len(idx_test) == 0: | |
| raise RuntimeError("No test samples found in multitask arrays.") | |
| base_model = tf.saved_model.load(str(args.base_model_dir)) | |
| cls_head = tf.keras.models.load_model(args.cls_head_path) | |
| seg_val, cls_val = (None, None) | |
| if args.threshold is None: | |
| if len(idx_val) == 0: | |
| raise RuntimeError("Threshold selection requested but val split is empty.") | |
| _, cls_val = _predict_probs(base_model, cls_head, images[idx_val], lumen_class=args.lumen_class, batch_size=args.batch_size) | |
| selected_threshold, val_best = _select_threshold(bif_labels[idx_val], cls_val, metric=args.select_metric) | |
| threshold_info = { | |
| "method": "validation_sweep", | |
| "metric": args.select_metric, | |
| "selected_threshold": float(selected_threshold), | |
| "val_best": val_best, | |
| } | |
| else: | |
| selected_threshold = float(args.threshold) | |
| threshold_info = {"method": "fixed", "selected_threshold": float(selected_threshold)} | |
| seg_test, cls_test = _predict_probs(base_model, cls_head, images[idx_test], lumen_class=args.lumen_class, batch_size=args.batch_size) | |
| seg_metrics = _seg_metrics(seg_test, masks[idx_test], has_mask[idx_test], threshold=0.5) | |
| cls_metrics = _cls_metrics(bif_labels[idx_test], cls_test, threshold=selected_threshold) | |
| payload = { | |
| "base_model_dir": str(args.base_model_dir), | |
| "cls_head_path": str(args.cls_head_path), | |
| "split_json": str(args.split_json), | |
| "num_test_samples": int(len(idx_test)), | |
| "num_test_with_lumen": int(np.sum(has_mask[idx_test] > 0.5)), | |
| "threshold_info": threshold_info, | |
| "segmentation_metrics": seg_metrics, | |
| "bifurcation_metrics": cls_metrics, | |
| } | |
| args.output_json.parent.mkdir(parents=True, exist_ok=True) | |
| with args.output_json.open("w", encoding="utf-8") as fp: | |
| json.dump(payload, fp, indent=2) | |
| if args.save_threshold: | |
| threshold_path = resolve_bifurcation_threshold_path(model_path=args.cls_head_path) | |
| threshold_path.parent.mkdir(parents=True, exist_ok=True) | |
| threshold_payload = { | |
| "selected_threshold": float(selected_threshold), | |
| "selection": threshold_info, | |
| "source_split_json": str(args.split_json), | |
| "source_model_path": str(args.cls_head_path), | |
| "source_base_model_dir": str(args.base_model_dir), | |
| } | |
| with threshold_path.open("w", encoding="utf-8") as fp: | |
| json.dump(threshold_payload, fp, indent=2) | |
| print(f"Test samples: {len(idx_test)} (with lumen gt: {int(np.sum(has_mask[idx_test] > 0.5))})") | |
| print( | |
| f"Seg Dice={seg_metrics['seg_dice']:.4f} IoU={seg_metrics['seg_iou']:.4f} | " | |
| f"Cls Acc={cls_metrics['cls_accuracy']:.4f} AUC={cls_metrics['cls_auc']:.4f} F1={cls_metrics['cls_f1']:.4f} " | |
| f"(th={selected_threshold:.3f})" | |
| ) | |
| print(f"Saved: {args.output_json}") | |
| if args.save_threshold: | |
| print(f"Saved threshold: {threshold_path}") | |
| if __name__ == "__main__": | |
| main() | |