Instructions to use Aditya2162/ivus-segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use Aditya2162/ivus-segmentation with Keras:
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
model = keras.saving.load_model("hf://Aditya2162/ivus-segmentation")
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Generate figures and numeric summaries for the project fine-tuning memo.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Project root: parents[2] of this file's resolved path, i.e. two levels above
# the directory that contains this script.
ROOT = Path(__file__).resolve().parents[2]
# All generated figures and the summary JSON are written here.
ASSET_DIR = ROOT.joinpath("docs", "memo_assets")
# Create the output directory (and any missing parents) at import time so the
# later savefig / open("w") calls cannot fail on a missing folder.
ASSET_DIR.mkdir(parents=True, exist_ok=True)
| # ---------------------------- Helpers ---------------------------- | |
def read_json(path: Path) -> dict:
    """Parse the UTF-8 encoded JSON file at *path* and return the decoded value."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def read_json_first(paths: list[Path], required: bool = True) -> dict:
    """Load the first existing JSON file among *paths*.

    Candidates are checked in order; the first one that exists is parsed with
    ``read_json`` and returned.

    Args:
        paths: Candidate file locations, in priority order.
        required: When true (the default), raise if no candidate exists;
            when false, return an empty dict instead.

    Raises:
        FileNotFoundError: If no candidate exists and *required* is true.
    """
    found = next((candidate for candidate in paths if candidate.exists()), None)
    if found is not None:
        return read_json(found)
    if not required:
        return {}
    raise FileNotFoundError(f"None of the candidate files exist: {paths}")
def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Return the Wilson score interval for a binomial proportion.

    Args:
        successes: Number of successes observed.
        n: Number of trials.
        z: Normal quantile for the desired confidence level (1.96 by default).

    Returns:
        ``(lower, upper)`` clipped to ``[0.0, 1.0]``, or ``(nan, nan)`` when
        *n* is not positive.
    """
    if n <= 0:
        return float("nan"), float("nan")

    z_sq = z * z
    p_hat = successes / n
    denom = 1.0 + (z_sq / n)
    # Centre of the interval: the observed rate shifted toward 0.5.
    center = (p_hat + (z_sq / (2 * n))) / denom
    spread = math.sqrt((p_hat * (1 - p_hat) / n) + (z_sq / (4 * n * n)))
    half = (z / denom) * spread

    lower = max(0.0, center - half)
    upper = min(1.0, center + half)
    return lower, upper
def ensure_style():
    """Update matplotlib's global rcParams with the memo's font, grid and background settings."""
    memo_style = {
        "font.size": 10,
        "axes.grid": True,
        "grid.alpha": 0.22,
        "figure.facecolor": "white",
        "axes.facecolor": "#fafafa",
    }
    plt.rcParams.update(memo_style)
# ---------------------------- Load inputs ----------------------------
# Every payload is searched for in a prioritized list of candidate locations;
# read_json_first returns the first file that exists.
_models_dir = ROOT / "models"
_output_dir = ROOT / "output"
_training_dir = _output_dir / "training_outputs"

# Optional input: an empty dict is used when neither candidate exists.
multitask_summary = read_json_first(
    [
        _models_dir / "multitask" / "multitask_summary.json",
        _training_dir / "multitask_summary.json",
    ],
    required=False,
)

# The remaining inputs are required: a missing file raises FileNotFoundError.
multitask_test = read_json_first(
    [
        _output_dir / "multitask_test_inference.json",
        _training_dir / "multitask_test_inference.json",
    ]
)
lumen_summary = read_json_first(
    [
        _models_dir / "standalone" / "lumen" / "finetune_summary.json",
        _training_dir / "lumen_finetune_summary.json",
    ]
)
standalone_test = read_json_first(
    [
        _output_dir / "bifurcation_classifier" / "test_inference.json",
        _training_dir / "bifurcation_classifier" / "test_inference.json",
    ]
)
threshold_payload = read_json_first(
    [
        _models_dir / "multitask" / "threshold.json",
        _models_dir / "standalone" / "bifurcation" / "threshold.json",
        _models_dir / "bifurcation" / "threshold.json",
        _output_dir / "bifurcation_classifier" / "threshold.json",
    ]
)
dataset_balance = read_json_first(
    [
        _output_dir / "dataset_balance" / "dataset_balance_summary.json",
        ROOT / "docs" / "memo_assets" / "dataset_balance_summary.json",
    ]
)
# ---------------------------- Figure 1: Training curves (multitask) ----------------------------
ensure_style()
history = multitask_summary.get("history", [])
# The figure is only produced when the (optional) multitask summary has history records.
if history:

    def _history_series(key):
        """Return the per-record values stored under *key*, using NaN where a record lacks it."""
        return [float(record.get(key, np.nan)) for record in history]

    epochs = [int(record["epoch"]) for record in history]

    # One entry per subplot: grid position, title, y-label, optional y-limits,
    # and the curves as (history key, legend label, colour).
    panel_specs = [
        (
            (0, 0),
            "Multitask Total Loss",
            "Loss",
            None,
            [
                ("total_loss", "train total loss", "#1f77b4"),
                ("val_total_loss", "val total loss", "#d62728"),
            ],
        ),
        (
            (0, 1),
            "Validation Segmentation Metrics",
            "Score",
            (0.75, 0.97),
            [
                ("val_seg_dice", "val seg dice", "#2ca02c"),
                ("val_seg_iou", "val seg IoU", "#17becf"),
            ],
        ),
        (
            (1, 0),
            "Validation Classification Metrics",
            "Score",
            (0.0, 1.0),
            [
                ("val_cls_f1", "val cls F1", "#9467bd"),
                ("val_cls_auc", "val cls AUC", "#ff7f0e"),
            ],
        ),
        (
            # Stability plot: train seg/cls losses (if available)
            (1, 1),
            "Train Task Loss Components",
            "Loss",
            None,
            [
                ("seg_loss", "train seg loss", "#8c564b"),
                ("cls_loss", "train cls loss", "#e377c2"),
            ],
        ),
    ]
    # Compute every series before any drawing starts.
    panel_data = [
        (
            position,
            title,
            y_label,
            y_limits,
            [(_history_series(key), label, color) for key, label, color in curves],
        )
        for position, title, y_label, y_limits, curves in panel_specs
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    for position, title, y_label, y_limits, series in panel_data:
        panel_ax = axes[position]
        for values, label, color in series:
            panel_ax.plot(epochs, values, label=label, color=color)
        panel_ax.set_title(title)
        panel_ax.set_xlabel("Epoch")
        panel_ax.set_ylabel(y_label)
        if y_limits is not None:
            panel_ax.set_ylim(*y_limits)
        panel_ax.legend()

    fig.suptitle("Multitask Fine-tuning Dynamics", fontsize=13)
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "multitask_training_dynamics.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 2: Lumen fine-tune curves ----------------------------
lh = lumen_summary.get("history", [])


def _lumen_series(key):
    """Return the per-record values stored under *key* in the lumen history, NaN when absent."""
    return [float(record.get(key, np.nan)) for record in lh]


# Records without an explicit "epoch" fall back to their 1-based position.
le = [int(float(record.get("epoch", position))) for position, record in enumerate(lh, start=1)]
tr_loss = _lumen_series("train_loss")
val_loss = _lumen_series("val_loss")
val_dice_l = _lumen_series("val_dice")
val_iou_l = _lumen_series("val_iou")

fig, axes = plt.subplots(1, 2, figsize=(11, 4.2))
loss_ax, metric_ax = axes[0], axes[1]

# Left panel: train vs. validation loss.
for values, label, color in ((tr_loss, "train loss", "#1f77b4"), (val_loss, "val loss", "#d62728")):
    loss_ax.plot(le, values, label=label, color=color)
loss_ax.set_title("Lumen Fine-tune Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Right panel: validation Dice and IoU on a fixed y-range.
for values, label, color in ((val_dice_l, "val dice", "#2ca02c"), (val_iou_l, "val IoU", "#17becf")):
    metric_ax.plot(le, values, label=label, color=color)
metric_ax.set_title("Lumen Fine-tune Validation Metrics")
metric_ax.set_xlabel("Epoch")
metric_ax.set_ylabel("Score")
metric_ax.set_ylim(0.82, 0.97)
metric_ax.legend()

fig.tight_layout()
fig.savefig(ASSET_DIR / "lumen_finetune_dynamics.png", dpi=220)
plt.close(fig)
# ---------------------------- Figure 3: Standalone threshold sweep ----------------------------
# Plot F1 / recall / precision against the decision threshold for the validation
# sweep stored in the standalone payload, and mark the selected threshold.
# `or {}` / `or []` also cover keys that are present but null in the JSON.
sweep = (standalone_test.get("threshold_selection") or {}).get("val_threshold_sweep") or []
if sweep:
    ths = np.array([float(r["threshold"]) for r in sweep], dtype=float)
    f1s = np.array([float(r.get("f1", np.nan)) for r in sweep], dtype=float)
    recs = np.array([float(r.get("recall", np.nan)) for r in sweep], dtype=float)
    precs = np.array([float(r.get("precision", np.nan)) for r in sweep], dtype=float)

    # Only compute the best-F1 fallback when no threshold was recorded.
    # np.nanargmax raises ValueError on an all-NaN F1 column, so evaluating it
    # eagerly as a dict.get default would crash even when "selected_threshold"
    # is present.
    selected = standalone_test.get("selected_threshold")
    if selected is None:
        best_t = float(ths[int(np.nanargmax(f1s))])
    else:
        best_t = float(selected)
    # Snap to the nearest swept threshold so the marker sits on the F1 curve.
    best_i = int(np.argmin(np.abs(ths - best_t)))

    fig, ax = plt.subplots(figsize=(8, 4.8))
    ax.plot(ths, f1s, label="F1", color="#1f77b4", linewidth=2)
    ax.plot(ths, recs, label="Recall", color="#2ca02c", alpha=0.75)
    ax.plot(ths, precs, label="Precision", color="#d62728", alpha=0.75)
    ax.scatter([ths[best_i]], [f1s[best_i]], color="black", s=60, zorder=3, label=f"Selected t={ths[best_i]:.3f}")
    ax.axvline(ths[best_i], linestyle="--", color="black", linewidth=1)
    ax.set_title("Validation Threshold Sweep (Standalone Bifurcation Model)")
    ax.set_xlabel("Threshold")
    ax.set_ylabel("Score")
    ax.set_ylim(0.0, 1.0)
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "standalone_threshold_sweep.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 4: Probability distributions ----------------------------
# Overlay the predicted-probability histograms of the two true classes and mark
# the operating threshold.
pred_rows = standalone_test.get("predictions", [])
if pred_rows:
    y_true = np.array([int(r["label"]) for r in pred_rows], dtype=int)
    y_prob = np.array([float(r["probability"]) for r in pred_rows], dtype=float)

    # Fall back to 0.5 when the threshold is missing or null in the JSON.
    selected = standalone_test.get("selected_threshold")
    th = 0.5 if selected is None else float(selected)

    # Use one set of 24 bin edges computed over all predictions. Passing
    # bins=24 to each hist call separately would bin each class over its own
    # value range, so the overlaid bar heights would not be comparable.
    bin_edges = np.histogram_bin_edges(y_prob, bins=24)

    fig, ax = plt.subplots(figsize=(8, 4.6))
    ax.hist(y_prob[y_true == 0], bins=bin_edges, alpha=0.6, label="True non-bifurcation", color="#1f77b4")
    ax.hist(y_prob[y_true == 1], bins=bin_edges, alpha=0.6, label="True bifurcation", color="#d62728")
    ax.axvline(th, color="black", linestyle="--", linewidth=1.3, label=f"Operating threshold={th:.3f}")
    ax.set_title("Predicted Probability Distributions (Standalone Test Set)")
    ax.set_xlabel("Predicted bifurcation probability")
    ax.set_ylabel("Count")
    ax.legend()
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "standalone_probability_hist.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 5: Pipeline diagram ----------------------------
fig, ax = plt.subplots(figsize=(13, 5.8))
ax.axis("off")

# Each box is (x, y, width, height, text); it is drawn by hand as a rectangle
# with centred text.
boxes = {
    "data": (0.03, 0.62, 0.2, 0.25, "Frame-bank JSONL + DICOM\n(train/val/test split)"),
    "pre": (0.28, 0.62, 0.2, 0.25, "Preprocess\n(center black circle,\n3-channel replication)"),
    "base": (0.53, 0.62, 0.2, 0.25, "Pretrained lumen base\n(SavedModel logits)"),
    "seg": (0.78, 0.72, 0.19, 0.14, "Segmentation branch\n(BCE + Dice)"),
    "cls": (0.78, 0.52, 0.19, 0.14, "Classification head\n(GAP + Dense + Sigmoid)"),
    "loss": (0.53, 0.18, 0.26, 0.22, "Total loss\n= w_seg * seg_loss\n+ w_cls * cls_loss\n(mask seg loss by has_mask)"),
    "eval": (0.84, 0.18, 0.13, 0.22, "Val threshold sweep\n→ threshold.json\n→ runtime inference"),
}
for left, bottom, width, height, caption in boxes.values():
    box_patch = plt.Rectangle((left, bottom), width, height, facecolor="#f7f7f7", edgecolor="#333333", linewidth=1.2)
    ax.add_patch(box_patch)
    ax.text(left + width / 2, bottom + height / 2, caption, ha="center", va="center", fontsize=10)


def arrow(x0, y0, x1, y1):
    """Draw an arrow on the diagram axes from (x0, y0) to (x1, y1)."""
    arrow_style = dict(arrowstyle="->", linewidth=1.4, color="#333333")
    ax.annotate("", xy=(x1, y1), xytext=(x0, y0), arrowprops=arrow_style)


# Connector endpoints as (x0, y0, x1, y1), drawn in this order.
connectors = [
    (0.23, 0.745, 0.28, 0.745),
    (0.48, 0.745, 0.53, 0.745),
    (0.73, 0.745, 0.78, 0.79),
    (0.73, 0.71, 0.78, 0.59),
    (0.875, 0.72, 0.66, 0.40),
    (0.875, 0.52, 0.66, 0.40),
    (0.79, 0.29, 0.84, 0.29),
]
for endpoints in connectors:
    arrow(*endpoints)

ax.set_title("Training and Inference Design: Pretrained Base + Multitask Head", fontsize=13)
fig.tight_layout()
fig.savefig(ASSET_DIR / "multitask_pipeline_diagram.png", dpi=220)
plt.close(fig)
# ---------------------------- Figure 6: Metric comparison ----------------------------
# Bar chart of the multitask test-set segmentation and classification metrics.
# Only the multitask payload is read here. A strict lumen_summary["final_test_metrics"]
# lookup would abort the script on a missing key even though this figure does
# not use that value.
mt = multitask_test["bifurcation_metrics"]
ms = multitask_test["segmentation_metrics"]
labels = ["Lumen IoU", "Lumen Dice", "Bif Acc", "Bif F1", "Bif AUC"]
# Metrics absent from the payload become NaN.
values = [
    float(ms.get("seg_iou", np.nan)),
    float(ms.get("seg_dice", np.nan)),
    float(mt.get("cls_accuracy", np.nan)),
    float(mt.get("cls_f1", np.nan)),
    float(mt.get("cls_auc", np.nan)),
]
fig, ax = plt.subplots(figsize=(8.2, 4.4))
bars = ax.bar(labels, values, color=["#17becf", "#2ca02c", "#1f77b4", "#9467bd", "#ff7f0e"])
ax.set_ylim(0.0, 1.0)
ax.set_ylabel("Score")
ax.set_title("Multitask Test Metrics Snapshot")
for b, v in zip(bars, values):
    # A missing metric has no finite position to anchor a value label to.
    if not np.isfinite(v):
        continue
    ax.text(b.get_x() + b.get_width() / 2, v + 0.015, f"{v:.3f}", ha="center", va="bottom", fontsize=9)
fig.tight_layout()
fig.savefig(ASSET_DIR / "multitask_test_metric_snapshot.png", dpi=220)
plt.close(fig)
# ---------------------------- Summary JSON ----------------------------
cls = multitask_test["bifurcation_metrics"]
seg = multitask_test["segmentation_metrics"]

# Confusion-matrix counts of the multitask classifier on the test set.
tp, fp, fn, tn = (cls[name] for name in ("tp", "fp", "fn", "tn"))
n_cls = int(tp + fp + fn + tn)

# Wilson intervals (default z=1.96) for accuracy, recall and precision.
acc_ci = wilson_interval(int(tp + tn), n_cls)
rec_ci = wilson_interval(int(tp), int(tp + fn))
prec_ci = wilson_interval(int(tp), int(tp + fp))

# Epoch with the lowest validation total loss; records without that key rank last.
best_epoch = (
    min(history, key=lambda record: float(record.get("val_total_loss", np.inf))).get("epoch")
    if history
    else None
)

data_section = {
    "split_counts": dataset_balance.get("counts", {}),
    "bifurcation_positive_rate": dataset_balance.get("bifurcation_positive_rate", {}),
    "lumen_coverage_rate": dataset_balance.get("lumen_coverage_rate", {}),
}
multitask_section = {
    "epochs_completed": len(history),
    "best_epoch_by_val_total_loss": best_epoch,
    "best_val_total_loss": multitask_summary.get("best_val_total_loss"),
    "test_seg_metrics": seg,
    "test_cls_metrics": cls,
    "cls_accuracy_wilson95": acc_ci,
    "cls_precision_wilson95": prec_ci,
    "cls_recall_wilson95": rec_ci,
    "threshold": multitask_test.get("threshold_info", {}).get("selected_threshold"),
}
lumen_section = {
    "num_samples": lumen_summary.get("num_samples"),
    "test_metrics": lumen_summary.get("final_test_metrics", {}),
}
standalone_section = {
    "selected_threshold": standalone_test.get("selected_threshold"),
    "test_metrics": standalone_test.get("metrics", {}),
}
summary = {
    "data": data_section,
    "multitask": multitask_section,
    "lumen_finetune": lumen_section,
    "standalone_bifurcation": standalone_section,
    "runtime_threshold": threshold_payload.get("selected_threshold"),
}

summary_path = ASSET_DIR / "memo_summary.json"
with summary_path.open("w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

# List everything now present in the asset directory.
print("Saved assets to", ASSET_DIR)
for asset_path in sorted(ASSET_DIR.glob("*")):
    print(asset_path.name)