ivus-segmentation / scripts /analysis /generate_project_memo_assets.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Generate figures and numeric summaries for the project fine-tuning memo."""
from __future__ import annotations
import json
import math
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Repository root: parents[2] of this file, i.e. two directories above the folder
# that contains this script (<root>/scripts/analysis/<this file>).
ROOT = Path(__file__).resolve().parents[2]
# All generated figures and the summary JSON are written here; created eagerly.
ASSET_DIR = ROOT.joinpath("docs", "memo_assets")
ASSET_DIR.mkdir(exist_ok=True, parents=True)
# ---------------------------- Helpers ----------------------------
def read_json(path: Path) -> dict:
    """Parse and return the UTF-8 encoded JSON document stored at *path*."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def read_json_first(paths: list[Path], required: bool = True) -> dict:
    """Return the parsed contents of the first file in *paths* that exists.

    Candidates are probed in order and probing stops at the first hit.

    Args:
        paths: candidate JSON file locations, highest priority first.
        required: when no candidate exists, raise instead of returning ``{}``.

    Raises:
        FileNotFoundError: if no candidate exists and *required* is true.
    """
    found = next((candidate for candidate in paths if candidate.exists()), None)
    if found is not None:
        return read_json(found)
    if not required:
        return {}
    raise FileNotFoundError(f"None of the candidate files exist: {paths}")


def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score confidence interval for a binomial proportion.

    Args:
        successes: number of successes observed.
        n: number of trials.
        z: standard-normal quantile (the default 1.96 gives a ~95% interval).

    Returns:
        ``(lower, upper)`` clamped to ``[0, 1]``, or ``(nan, nan)`` when ``n <= 0``.
    """
    if n > 0:
        z_sq = z * z
        p_hat = successes / n
        scale = 1.0 + (z_sq / n)
        centre = (p_hat + (z_sq / (2 * n))) / scale
        margin = (z / scale) * math.sqrt((p_hat * (1 - p_hat) / n) + (z_sq / (4 * n * n)))
        return max(0.0, centre - margin), min(1.0, centre + margin)
    return float("nan"), float("nan")


def ensure_style():
    """Apply the shared matplotlib rc settings used by every memo figure."""
    settings = (
        ("font.size", 10),
        ("axes.grid", True),
        ("grid.alpha", 0.22),
        ("figure.facecolor", "white"),
        ("axes.facecolor", "#fafafa"),
    )
    plt.rcParams.update(dict(settings))


# ---------------------------- Load inputs ----------------------------
# Each artifact is looked up in several candidate locations (first hit wins), because
# training runs may have written it under models/ or under output/.
_models_dir = ROOT / "models"
_output_dir = ROOT / "output"
_training_outputs_dir = _output_dir / "training_outputs"

# Optional: Figure 1 and the summary's training fields degrade gracefully without it.
multitask_summary = read_json_first(
    [
        _models_dir / "multitask" / "multitask_summary.json",
        _training_outputs_dir / "multitask_summary.json",
    ],
    required=False,
)
multitask_test = read_json_first(
    [
        _output_dir / "multitask_test_inference.json",
        _training_outputs_dir / "multitask_test_inference.json",
    ]
)
lumen_summary = read_json_first(
    [
        _models_dir / "standalone" / "lumen" / "finetune_summary.json",
        _training_outputs_dir / "lumen_finetune_summary.json",
    ]
)
standalone_test = read_json_first(
    [
        _output_dir / "bifurcation_classifier" / "test_inference.json",
        _training_outputs_dir / "bifurcation_classifier" / "test_inference.json",
    ]
)
threshold_payload = read_json_first(
    [
        _models_dir / "multitask" / "threshold.json",
        _models_dir / "standalone" / "bifurcation" / "threshold.json",
        _models_dir / "bifurcation" / "threshold.json",
        _output_dir / "bifurcation_classifier" / "threshold.json",
    ]
)
dataset_balance = read_json_first(
    [
        _output_dir / "dataset_balance" / "dataset_balance_summary.json",
        ROOT / "docs" / "memo_assets" / "dataset_balance_summary.json",
    ]
)
# ---------------------------- Figure 1: Training curves (multitask) ----------------------------
ensure_style()
# Per-epoch records from the (optional) multitask summary; empty when it was not found.
# ``history`` is also read by the summary section at the end of the script.
history = multitask_summary.get("history", [])
if history:
    # Records without an explicit epoch are numbered by their 1-based position instead of
    # raising KeyError; float() first so values such as "3.0" are accepted too.
    epochs = [int(float(r.get("epoch", i + 1))) for i, r in enumerate(history)]
    train_total = [float(r.get("total_loss", np.nan)) for r in history]
    val_total = [float(r.get("val_total_loss", np.nan)) for r in history]
    val_dice = [float(r.get("val_seg_dice", np.nan)) for r in history]
    val_iou = [float(r.get("val_seg_iou", np.nan)) for r in history]
    val_f1 = [float(r.get("val_cls_f1", np.nan)) for r in history]
    val_auc = [float(r.get("val_cls_auc", np.nan)) for r in history]
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    axes[0, 0].plot(epochs, train_total, label="train total loss", color="#1f77b4")
    axes[0, 0].plot(epochs, val_total, label="val total loss", color="#d62728")
    axes[0, 0].set_title("Multitask Total Loss")
    axes[0, 0].set_xlabel("Epoch")
    axes[0, 0].set_ylabel("Loss")
    axes[0, 0].legend()

    axes[0, 1].plot(epochs, val_dice, label="val seg dice", color="#2ca02c")
    axes[0, 1].plot(epochs, val_iou, label="val seg IoU", color="#17becf")
    axes[0, 1].set_title("Validation Segmentation Metrics")
    axes[0, 1].set_xlabel("Epoch")
    axes[0, 1].set_ylabel("Score")
    # Keep the zoomed 0.75-0.97 window by default, but widen it whenever a finite
    # value would otherwise be clipped out of view.
    seg_finite = [v for v in val_dice + val_iou if math.isfinite(v)]
    axes[0, 1].set_ylim(
        min([0.75] + [v - 0.01 for v in seg_finite]),
        max([0.97] + [v + 0.01 for v in seg_finite]),
    )
    axes[0, 1].legend()

    axes[1, 0].plot(epochs, val_f1, label="val cls F1", color="#9467bd")
    axes[1, 0].plot(epochs, val_auc, label="val cls AUC", color="#ff7f0e")
    axes[1, 0].set_title("Validation Classification Metrics")
    axes[1, 0].set_xlabel("Epoch")
    axes[1, 0].set_ylabel("Score")
    axes[1, 0].set_ylim(0.0, 1.0)
    axes[1, 0].legend()

    # Stability plot: train seg/cls losses (NaN where a record lacks them).
    train_seg = [float(r.get("seg_loss", np.nan)) for r in history]
    train_cls = [float(r.get("cls_loss", np.nan)) for r in history]
    axes[1, 1].plot(epochs, train_seg, label="train seg loss", color="#8c564b")
    axes[1, 1].plot(epochs, train_cls, label="train cls loss", color="#e377c2")
    axes[1, 1].set_title("Train Task Loss Components")
    axes[1, 1].set_xlabel("Epoch")
    axes[1, 1].set_ylabel("Loss")
    axes[1, 1].legend()

    fig.suptitle("Multitask Fine-tuning Dynamics", fontsize=13)
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "multitask_training_dynamics.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 2: Lumen fine-tune curves ----------------------------
# Per-epoch records from the lumen fine-tune summary. Like the other figures, nothing is
# drawn when there is no history (previously an empty plot was saved).
lh = lumen_summary.get("history", [])
if lh:
    # Records without an explicit epoch are numbered by their 1-based position.
    le = [int(float(r.get("epoch", i + 1))) for i, r in enumerate(lh)]
    tr_loss = [float(r.get("train_loss", np.nan)) for r in lh]
    val_loss = [float(r.get("val_loss", np.nan)) for r in lh]
    val_dice_l = [float(r.get("val_dice", np.nan)) for r in lh]
    val_iou_l = [float(r.get("val_iou", np.nan)) for r in lh]
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.2))

    axes[0].plot(le, tr_loss, label="train loss", color="#1f77b4")
    axes[0].plot(le, val_loss, label="val loss", color="#d62728")
    axes[0].set_title("Lumen Fine-tune Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].legend()

    axes[1].plot(le, val_dice_l, label="val dice", color="#2ca02c")
    axes[1].plot(le, val_iou_l, label="val IoU", color="#17becf")
    axes[1].set_title("Lumen Fine-tune Validation Metrics")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Score")
    # Keep the zoomed 0.82-0.97 window by default, but widen it whenever a finite
    # value would otherwise be clipped out of view.
    lumen_finite = [v for v in val_dice_l + val_iou_l if math.isfinite(v)]
    axes[1].set_ylim(
        min([0.82] + [v - 0.01 for v in lumen_finite]),
        max([0.97] + [v + 0.01 for v in lumen_finite]),
    )
    axes[1].legend()

    fig.tight_layout()
    fig.savefig(ASSET_DIR / "lumen_finetune_dynamics.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 3: Standalone threshold sweep ----------------------------
sweep = standalone_test.get("threshold_selection", {}).get("val_threshold_sweep", [])
if sweep:
    # Draw curves in increasing-threshold order regardless of how the sweep was stored.
    rows = sorted(sweep, key=lambda r: float(r["threshold"]))
    ths = np.array([float(r["threshold"]) for r in rows], dtype=float)
    f1s = np.array([float(r.get("f1", np.nan)) for r in rows], dtype=float)
    recs = np.array([float(r.get("recall", np.nan)) for r in rows], dtype=float)
    precs = np.array([float(r.get("precision", np.nan)) for r in rows], dtype=float)

    # Mark the stored operating threshold (nearest sweep point) when one is present and
    # non-null; otherwise fall back to the best-F1 point. The fallback is computed only
    # when needed, because np.nanargmax raises ValueError when every F1 value is NaN.
    selected = standalone_test.get("selected_threshold")
    if selected is not None:
        best_i = int(np.argmin(np.abs(ths - float(selected))))
    elif np.isfinite(f1s).any():
        best_i = int(np.nanargmax(f1s))
    else:
        best_i = None  # nothing meaningful to mark

    fig, ax = plt.subplots(figsize=(8, 4.8))
    ax.plot(ths, f1s, label="F1", color="#1f77b4", linewidth=2)
    ax.plot(ths, recs, label="Recall", color="#2ca02c", alpha=0.75)
    ax.plot(ths, precs, label="Precision", color="#d62728", alpha=0.75)
    if best_i is not None:
        ax.scatter([ths[best_i]], [f1s[best_i]], color="black", s=60, zorder=3, label=f"Selected t={ths[best_i]:.3f}")
        ax.axvline(ths[best_i], linestyle="--", color="black", linewidth=1)
    ax.set_title("Validation Threshold Sweep (Standalone Bifurcation Model)")
    ax.set_xlabel("Threshold")
    ax.set_ylabel("Score")
    ax.set_ylim(0.0, 1.0)
    ax.legend(loc="best")
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "standalone_threshold_sweep.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 4: Probability distributions ----------------------------
pred_rows = standalone_test.get("predictions", [])
if pred_rows:
    y_true = np.array([int(r["label"]) for r in pred_rows], dtype=int)
    y_prob = np.array([float(r["probability"]) for r in pred_rows], dtype=float)
    # A stored null threshold falls back to 0.5 exactly like a missing one
    # (float(None) would otherwise raise TypeError).
    selected = standalone_test.get("selected_threshold")
    th = 0.5 if selected is None else float(selected)

    # One shared set of 24 bins for both classes so the two histograms line up bar for bar
    # (independent bins=24 per call gives each class its own edges). The range is [0, 1],
    # stretched only if some probability lies outside it.
    bin_edges = np.linspace(min(0.0, float(y_prob.min())), max(1.0, float(y_prob.max())), 25)

    fig, ax = plt.subplots(figsize=(8, 4.6))
    ax.hist(y_prob[y_true == 0], bins=bin_edges, alpha=0.6, label="True non-bifurcation", color="#1f77b4")
    ax.hist(y_prob[y_true == 1], bins=bin_edges, alpha=0.6, label="True bifurcation", color="#d62728")
    ax.axvline(th, color="black", linestyle="--", linewidth=1.3, label=f"Operating threshold={th:.3f}")
    ax.set_title("Predicted Probability Distributions (Standalone Test Set)")
    ax.set_xlabel("Predicted bifurcation probability")
    ax.set_ylabel("Count")
    ax.legend()
    fig.tight_layout()
    fig.savefig(ASSET_DIR / "standalone_probability_hist.png", dpi=220)
    plt.close(fig)
# ---------------------------- Figure 5: Pipeline diagram ----------------------------
fig, ax = plt.subplots(figsize=(13, 5.8))
ax.axis("off")
# Boxes are placed by hand in axes-fraction coordinates: key -> (x, y, width, height, label).
boxes = {
    "data": (0.03, 0.62, 0.2, 0.25, "Frame-bank JSONL + DICOM\n(train/val/test split)"),
    "pre": (0.28, 0.62, 0.2, 0.25, "Preprocess\n(center black circle,\n3-channel replication)"),
    "base": (0.53, 0.62, 0.2, 0.25, "Pretrained lumen base\n(SavedModel logits)"),
    "seg": (0.78, 0.72, 0.19, 0.14, "Segmentation branch\n(BCE + Dice)"),
    "cls": (0.78, 0.52, 0.19, 0.14, "Classification head\n(GAP + Dense + Sigmoid)"),
    "loss": (0.53, 0.18, 0.26, 0.22, "Total loss\n= w_seg * seg_loss\n+ w_cls * cls_loss\n(mask seg loss by has_mask)"),
    "eval": (0.84, 0.18, 0.13, 0.22, "Val threshold sweep\n→ threshold.json\n→ runtime inference"),
}
for x, y, w, h, txt in boxes.values():
    ax.add_patch(
        plt.Rectangle((x, y), w, h, facecolor="#f7f7f7", edgecolor="#333333", linewidth=1.2)
    )
    # Label centred inside its box.
    ax.text(x + w / 2, y + h / 2, txt, ha="center", va="center", fontsize=10)


def arrow(x0, y0, x1, y1):
    """Draw a plain connector arrow from (x0, y0) to (x1, y1) on the diagram axes."""
    style = {"arrowstyle": "->", "linewidth": 1.4, "color": "#333333"}
    ax.annotate("", xy=(x1, y1), xytext=(x0, y0), arrowprops=style)


# Connectors as ((start), (end)) pairs, drawn in this order.
connectors = (
    ((0.23, 0.745), (0.28, 0.745)),  # data -> preprocess
    ((0.48, 0.745), (0.53, 0.745)),  # preprocess -> pretrained base
    ((0.73, 0.745), (0.78, 0.79)),  # base -> segmentation branch
    ((0.73, 0.71), (0.78, 0.59)),  # base -> classification head
    ((0.875, 0.72), (0.66, 0.40)),  # segmentation branch -> total loss
    ((0.875, 0.52), (0.66, 0.40)),  # classification head -> total loss
    ((0.79, 0.29), (0.84, 0.29)),  # total loss -> threshold sweep box
)
for (start_x, start_y), (end_x, end_y) in connectors:
    arrow(start_x, start_y, end_x, end_y)

ax.set_title("Training and Inference Design: Pretrained Base + Multitask Head", fontsize=13)
fig.tight_layout()
fig.savefig(ASSET_DIR / "multitask_pipeline_diagram.png", dpi=220)
plt.close(fig)
# ---------------------------- Figure 6: Metric comparison ----------------------------
# Bar chart of the multitask model's test metrics. The former unused
# ``lumen_summary["final_test_metrics"]`` lookup was removed: it only served to raise
# KeyError when that key was absent, while the summary section reads it with .get().
mt = multitask_test["bifurcation_metrics"]
ms = multitask_test["segmentation_metrics"]
labels = ["Lumen IoU", "Lumen Dice", "Bif Acc", "Bif F1", "Bif AUC"]
values = [
    float(ms.get("seg_iou", np.nan)),
    float(ms.get("seg_dice", np.nan)),
    float(mt.get("cls_accuracy", np.nan)),
    float(mt.get("cls_f1", np.nan)),
    float(mt.get("cls_auc", np.nan)),
]
fig, ax = plt.subplots(figsize=(8.2, 4.4))
bars = ax.bar(labels, values, color=["#17becf", "#2ca02c", "#1f77b4", "#9467bd", "#ff7f0e"])
ax.set_ylim(0.0, 1.0)
ax.set_ylabel("Score")
ax.set_title("Multitask Test Metrics Snapshot")
for b, v in zip(bars, values):
    if not math.isfinite(v):
        continue  # missing metric: no bar height to annotate
    ax.text(b.get_x() + b.get_width() / 2, v + 0.015, f"{v:.3f}", ha="center", va="bottom", fontsize=9)
fig.tight_layout()
fig.savefig(ASSET_DIR / "multitask_test_metric_snapshot.png", dpi=220)
plt.close(fig)
# ---------------------------- Summary JSON ----------------------------
cls = multitask_test["bifurcation_metrics"]
seg = multitask_test["segmentation_metrics"]
# Confusion-matrix counts drive Wilson intervals for accuracy, recall and precision.
n_cls = int(cls["tp"] + cls["fp"] + cls["fn"] + cls["tn"])
acc_ci = wilson_interval(int(cls["tp"] + cls["tn"]), n_cls)
rec_ci = wilson_interval(int(cls["tp"]), int(cls["tp"] + cls["fn"]))
prec_ci = wilson_interval(int(cls["tp"]), int(cls["tp"] + cls["fp"]))


def _val_total_loss_key(record: dict) -> float:
    """Ranking key for the best epoch: missing, null or NaN validation losses rank last.

    NaN must be mapped explicitly: once a NaN key is the running minimum, ``min`` never
    replaces it (every comparison against NaN is False), so the wrong epoch would win.
    """
    raw = record.get("val_total_loss")
    loss = math.inf if raw is None else float(raw)
    return math.inf if math.isnan(loss) else loss


def _json_safe(value):
    """Return *value* with every NaN/inf float replaced by None, recursing into containers.

    ``json.dump`` would otherwise emit bare ``NaN``/``Infinity`` tokens, which are not valid
    JSON; ``wilson_interval`` returns NaN bounds when its denominator is zero, for example.
    Tuples are returned as lists, matching how ``json.dump`` serializes them anyway.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


best_epoch = None
if history:
    best_epoch = min(history, key=_val_total_loss_key).get("epoch")
summary = {
    "data": {
        "split_counts": dataset_balance.get("counts", {}),
        "bifurcation_positive_rate": dataset_balance.get("bifurcation_positive_rate", {}),
        "lumen_coverage_rate": dataset_balance.get("lumen_coverage_rate", {}),
    },
    "multitask": {
        "epochs_completed": len(history),
        "best_epoch_by_val_total_loss": best_epoch,
        "best_val_total_loss": multitask_summary.get("best_val_total_loss"),
        "test_seg_metrics": seg,
        "test_cls_metrics": cls,
        "cls_accuracy_wilson95": acc_ci,
        "cls_precision_wilson95": prec_ci,
        "cls_recall_wilson95": rec_ci,
        "threshold": multitask_test.get("threshold_info", {}).get("selected_threshold"),
    },
    "lumen_finetune": {
        "num_samples": lumen_summary.get("num_samples"),
        "test_metrics": lumen_summary.get("final_test_metrics", {}),
    },
    "standalone_bifurcation": {
        "selected_threshold": standalone_test.get("selected_threshold"),
        "test_metrics": standalone_test.get("metrics", {}),
    },
    "runtime_threshold": threshold_payload.get("selected_threshold"),
}
with (ASSET_DIR / "memo_summary.json").open("w", encoding="utf-8") as f:
    # Non-finite floats are already mapped to null; allow_nan=False enforces strict JSON.
    json.dump(_json_safe(summary), f, indent=2, allow_nan=False)
print("Saved assets to", ASSET_DIR)
for p in sorted(ASSET_DIR.glob("*")):
    print(p.name)