#!/usr/bin/env python3
"""Generate figures and numeric summaries for the project fine-tuning memo.

Reads the training/inference JSON artefacts found under ``ROOT``, renders six
PNG figures into ``ASSET_DIR`` and writes ``memo_summary.json`` next to them.
Run as a script; importing the module only defines helpers and constants.
"""

from __future__ import annotations

import json
import math
from pathlib import Path

import numpy as np

# NOTE: matplotlib is imported inside the plotting functions rather than here
# so that the numeric helpers (wilson_interval, build_summary, ...) can be
# imported without a plotting stack being installed.


def _project_root(script: Path, levels_up: int = 2) -> Path:
    """Return the ancestor directory ``levels_up`` levels above ``script``'s folder.

    If the script sits in a shallower location than expected, the outermost
    available ancestor is returned instead of raising ``IndexError``; the
    required inputs will then fail to load with a ``FileNotFoundError`` that
    lists the paths that were tried.
    """
    ancestors = script.resolve().parents
    return ancestors[min(levels_up, len(ancestors) - 1)]


ROOT = _project_root(Path(__file__))
ASSET_DIR = ROOT / "docs" / "memo_assets"
FIGURE_DPI = 220


# ---------------------------- Helpers ----------------------------
def read_json(path: Path) -> dict:
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def read_json_first(paths: list[Path], required: bool = True) -> dict:
    """Return the parsed JSON of the first path in ``paths`` that exists.

    Args:
        paths: Candidate files, checked in order.
        required: When true, a missing file is an error.

    Returns:
        The parsed JSON, or ``{}`` if nothing exists and ``required`` is false.

    Raises:
        FileNotFoundError: If no candidate exists and ``required`` is true.
    """
    for path in paths:
        if path.exists():
            return read_json(path)
    if required:
        raise FileNotFoundError(f"None of the candidate files exist: {paths}")
    return {}


def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Return the Wilson score interval ``(low, high)`` for ``successes / n``.

    The bounds are clipped to ``[0, 1]``. When ``n <= 0`` there is no
    proportion to estimate and ``(nan, nan)`` is returned. The default ``z``
    of 1.96 corresponds to a two-sided 95% interval.
    """
    if n <= 0:
        return float("nan"), float("nan")
    p = successes / n
    denom = 1.0 + (z * z / n)
    center = (p + (z * z / (2 * n))) / denom
    half = (z / denom) * math.sqrt((p * (1 - p) / n) + (z * z / (4 * n * n)))
    return max(0.0, center - half), min(1.0, center + half)


def ensure_style():
    """Apply the shared matplotlib rcParams used by every memo figure."""
    import matplotlib.pyplot as plt

    plt.rcParams.update(
        {
            "font.size": 10,
            "axes.grid": True,
            "grid.alpha": 0.22,
            "figure.facecolor": "white",
            "axes.facecolor": "#fafafa",
        }
    )


def as_float(value) -> float:
    """Convert ``value`` to float, mapping ``None`` (JSON null) to NaN."""
    return float("nan") if value is None else float(value)


def metric_series(rows: list[dict], key: str) -> list[float]:
    """Return ``key`` from every row as floats; missing/null entries become NaN."""
    return [as_float(row.get(key)) for row in rows]


def epoch_axis(rows: list[dict]) -> list[int]:
    """Return the epoch number of every row, falling back to its 1-based position."""
    epochs = []
    for position, row in enumerate(rows, start=1):
        raw = row.get("epoch")
        epochs.append(position if raw is None else int(float(raw)))
    return epochs


def best_epoch_by_val_loss(history: list[dict]):
    """Return the ``epoch`` of the row with the lowest finite ``val_total_loss``.

    Rows whose loss is missing, null, NaN or infinite are ignored, because NaN
    compares false against everything and would otherwise make the result
    depend on row order. Ties keep the earliest row. Returns ``None`` when no
    row has a usable loss.
    """
    best_row = None
    best_loss = math.inf
    for row in history:
        loss = as_float(row.get("val_total_loss"))
        if math.isfinite(loss) and loss < best_loss:
            best_row, best_loss = row, loss
    return None if best_row is None else best_row.get("epoch")


def to_json_safe(value):
    """Return a copy of ``value`` in which non-finite floats are ``None``.

    ``json.dump`` would otherwise emit bare ``NaN``/``Infinity`` tokens, which
    are not valid JSON. Dicts, lists and tuples are walked recursively; tuples
    come back as lists, which is how ``json`` serialises them anyway.
    """
    if isinstance(value, dict):
        return {key: to_json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_json_safe(item) for item in value]
    if isinstance(value, float) and not math.isfinite(value):
        return None
    return value


def _finish_axis(ax, title: str, ylabel: str, xlabel: str = "Epoch", ylim=None, **legend_kwargs):
    """Set title, axis labels, optional y-limits and the legend on ``ax``."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(**legend_kwargs)


def _save_figure(fig, out_dir: Path, filename: str) -> None:
    """Tighten the layout, write ``fig`` to ``out_dir / filename`` and close it."""
    import matplotlib.pyplot as plt

    fig.tight_layout()
    fig.savefig(out_dir / filename, dpi=FIGURE_DPI)
    plt.close(fig)


# ---------------------------- Load inputs ----------------------------
def load_inputs() -> dict[str, dict]:
    """Load every JSON artefact the memo needs, keyed by a short name.

    Only ``multitask_summary`` is optional (``{}`` when absent); every other
    artefact raises ``FileNotFoundError`` if none of its candidates exist.
    """
    return {
        "multitask_summary": read_json_first(
            [
                ROOT / "models" / "multitask" / "multitask_summary.json",
                ROOT / "output" / "training_outputs" / "multitask_summary.json",
            ],
            required=False,
        ),
        "multitask_test": read_json_first(
            [
                ROOT / "output" / "multitask_test_inference.json",
                ROOT / "output" / "training_outputs" / "multitask_test_inference.json",
            ]
        ),
        "lumen_summary": read_json_first(
            [
                ROOT / "models" / "standalone" / "lumen" / "finetune_summary.json",
                ROOT / "output" / "training_outputs" / "lumen_finetune_summary.json",
            ]
        ),
        "standalone_test": read_json_first(
            [
                ROOT / "output" / "bifurcation_classifier" / "test_inference.json",
                ROOT / "output" / "training_outputs" / "bifurcation_classifier" / "test_inference.json",
            ]
        ),
        "threshold_payload": read_json_first(
            [
                ROOT / "models" / "multitask" / "threshold.json",
                ROOT / "models" / "standalone" / "bifurcation" / "threshold.json",
                ROOT / "models" / "bifurcation" / "threshold.json",
                ROOT / "output" / "bifurcation_classifier" / "threshold.json",
            ]
        ),
        "dataset_balance": read_json_first(
            [
                ROOT / "output" / "dataset_balance" / "dataset_balance_summary.json",
                ROOT / "docs" / "memo_assets" / "dataset_balance_summary.json",
            ]
        ),
    }


# ---------------------------- Figure 1: Training curves (multitask) ----------------------------
def plot_multitask_training(history: list[dict], out_dir: Path) -> None:
    """Save the 2x2 multitask training-dynamics figure; no-op for empty history."""
    if not history:
        return
    import matplotlib.pyplot as plt

    epochs = epoch_axis(history)
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    ax = axes[0, 0]
    ax.plot(epochs, metric_series(history, "total_loss"), label="train total loss", color="#1f77b4")
    ax.plot(epochs, metric_series(history, "val_total_loss"), label="val total loss", color="#d62728")
    _finish_axis(ax, "Multitask Total Loss", "Loss")

    ax = axes[0, 1]
    ax.plot(epochs, metric_series(history, "val_seg_dice"), label="val seg dice", color="#2ca02c")
    ax.plot(epochs, metric_series(history, "val_seg_iou"), label="val seg IoU", color="#17becf")
    _finish_axis(ax, "Validation Segmentation Metrics", "Score", ylim=(0.75, 0.97))

    ax = axes[1, 0]
    ax.plot(epochs, metric_series(history, "val_cls_f1"), label="val cls F1", color="#9467bd")
    ax.plot(epochs, metric_series(history, "val_cls_auc"), label="val cls AUC", color="#ff7f0e")
    _finish_axis(ax, "Validation Classification Metrics", "Score", ylim=(0.0, 1.0))

    # Stability plot: per-task training losses (NaN where a row lacks them).
    ax = axes[1, 1]
    ax.plot(epochs, metric_series(history, "seg_loss"), label="train seg loss", color="#8c564b")
    ax.plot(epochs, metric_series(history, "cls_loss"), label="train cls loss", color="#e377c2")
    _finish_axis(ax, "Train Task Loss Components", "Loss")

    fig.suptitle("Multitask Fine-tuning Dynamics", fontsize=13)
    _save_figure(fig, out_dir, "multitask_training_dynamics.png")


# ---------------------------- Figure 2: Lumen fine-tune curves ----------------------------
def plot_lumen_finetune(history: list[dict], out_dir: Path) -> None:
    """Save the lumen fine-tune loss/metric curves; no-op for empty history."""
    if not history:
        return
    import matplotlib.pyplot as plt

    epochs = epoch_axis(history)
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.2))

    axes[0].plot(epochs, metric_series(history, "train_loss"), label="train loss", color="#1f77b4")
    axes[0].plot(epochs, metric_series(history, "val_loss"), label="val loss", color="#d62728")
    _finish_axis(axes[0], "Lumen Fine-tune Loss", "Loss")

    axes[1].plot(epochs, metric_series(history, "val_dice"), label="val dice", color="#2ca02c")
    axes[1].plot(epochs, metric_series(history, "val_iou"), label="val IoU", color="#17becf")
    _finish_axis(axes[1], "Lumen Fine-tune Validation Metrics", "Score", ylim=(0.82, 0.97))

    _save_figure(fig, out_dir, "lumen_finetune_dynamics.png")


# ---------------------------- Figure 3: Standalone threshold sweep ----------------------------
def plot_threshold_sweep(standalone_test: dict, out_dir: Path) -> None:
    """Save the validation threshold-sweep figure; no-op when there is no sweep.

    The marked point is the sweep entry closest to ``selected_threshold``; if
    that is absent or null, the entry with the highest F1 is marked. No marker
    is drawn when neither is available.
    """
    sweep = (standalone_test.get("threshold_selection") or {}).get("val_threshold_sweep") or []
    if not sweep:
        return
    import matplotlib.pyplot as plt

    thresholds = np.array([float(row["threshold"]) for row in sweep], dtype=float)
    f1s = np.array(metric_series(sweep, "f1"), dtype=float)
    recalls = np.array(metric_series(sweep, "recall"), dtype=float)
    precisions = np.array(metric_series(sweep, "precision"), dtype=float)

    # Only fall back to the arg-max when needed: nanargmax raises on all-NaN input.
    selected = standalone_test.get("selected_threshold")
    if selected is None and not np.all(np.isnan(f1s)):
        selected = thresholds[int(np.nanargmax(f1s))]

    fig, ax = plt.subplots(figsize=(8, 4.8))
    ax.plot(thresholds, f1s, label="F1", color="#1f77b4", linewidth=2)
    ax.plot(thresholds, recalls, label="Recall", color="#2ca02c", alpha=0.75)
    ax.plot(thresholds, precisions, label="Precision", color="#d62728", alpha=0.75)
    if selected is not None:
        nearest = int(np.argmin(np.abs(thresholds - float(selected))))
        ax.scatter(
            [thresholds[nearest]],
            [f1s[nearest]],
            color="black",
            s=60,
            zorder=3,
            label=f"Selected t={thresholds[nearest]:.3f}",
        )
        ax.axvline(thresholds[nearest], linestyle="--", color="black", linewidth=1)
    _finish_axis(
        ax,
        "Validation Threshold Sweep (Standalone Bifurcation Model)",
        "Score",
        xlabel="Threshold",
        ylim=(0.0, 1.0),
        loc="best",
    )
    _save_figure(fig, out_dir, "standalone_threshold_sweep.png")


# ---------------------------- Figure 4: Probability distributions ----------------------------
def plot_probability_hist(standalone_test: dict, out_dir: Path) -> None:
    """Save per-class histograms of predicted probability; no-op without predictions."""
    rows = standalone_test.get("predictions") or []
    if not rows:
        return
    import matplotlib.pyplot as plt

    y_true = np.array([int(row["label"]) for row in rows], dtype=int)
    y_prob = np.array([float(row["probability"]) for row in rows], dtype=float)
    stored = standalone_test.get("selected_threshold")
    threshold = 0.5 if stored is None else float(stored)

    # Shared edges keep the two class histograms directly comparable.
    edges = np.histogram_bin_edges(y_prob, bins=24)

    fig, ax = plt.subplots(figsize=(8, 4.6))
    ax.hist(y_prob[y_true == 0], bins=edges, alpha=0.6, label="True non-bifurcation", color="#1f77b4")
    ax.hist(y_prob[y_true == 1], bins=edges, alpha=0.6, label="True bifurcation", color="#d62728")
    ax.axvline(
        threshold,
        color="black",
        linestyle="--",
        linewidth=1.3,
        label=f"Operating threshold={threshold:.3f}",
    )
    _finish_axis(
        ax,
        "Predicted Probability Distributions (Standalone Test Set)",
        "Count",
        xlabel="Predicted bifurcation probability",
    )
    _save_figure(fig, out_dir, "standalone_probability_hist.png")


# ---------------------------- Figure 5: Pipeline diagram ----------------------------
# Each box is (x, y, width, height, text) in axes coordinates.
PIPELINE_BOXES = (
    (0.03, 0.62, 0.2, 0.25, "Frame-bank JSONL + DICOM\n(train/val/test split)"),
    (0.28, 0.62, 0.2, 0.25, "Preprocess\n(center black circle,\n3-channel replication)"),
    (0.53, 0.62, 0.2, 0.25, "Pretrained lumen base\n(SavedModel logits)"),
    (0.78, 0.72, 0.19, 0.14, "Segmentation branch\n(BCE + Dice)"),
    (0.78, 0.52, 0.19, 0.14, "Classification head\n(GAP + Dense + Sigmoid)"),
    (0.53, 0.18, 0.26, 0.22, "Total loss\n= w_seg * seg_loss\n+ w_cls * cls_loss\n(mask seg loss by has_mask)"),
    (0.84, 0.18, 0.13, 0.22, "Val threshold sweep\n→ threshold.json\n→ runtime inference"),
)

# Each arrow is (x_start, y_start, x_end, y_end) in axes coordinates.
PIPELINE_ARROWS = (
    (0.23, 0.745, 0.28, 0.745),
    (0.48, 0.745, 0.53, 0.745),
    (0.73, 0.745, 0.78, 0.79),
    (0.73, 0.71, 0.78, 0.59),
    (0.875, 0.72, 0.66, 0.40),
    (0.875, 0.52, 0.66, 0.40),
    (0.79, 0.29, 0.84, 0.29),
)


def plot_pipeline_diagram(out_dir: Path) -> None:
    """Save the hand-laid-out box-and-arrow diagram of the training pipeline."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(13, 5.8))
    ax.axis("off")

    for x, y, w, h, txt in PIPELINE_BOXES:
        rect = plt.Rectangle((x, y), w, h, facecolor="#f7f7f7", edgecolor="#333333", linewidth=1.2)
        ax.add_patch(rect)
        ax.text(x + w / 2, y + h / 2, txt, ha="center", va="center", fontsize=10)

    for x0, y0, x1, y1 in PIPELINE_ARROWS:
        ax.annotate(
            "",
            xy=(x1, y1),
            xytext=(x0, y0),
            arrowprops=dict(arrowstyle="->", linewidth=1.4, color="#333333"),
        )

    ax.set_title("Training and Inference Design: Pretrained Base + Multitask Head", fontsize=13)
    _save_figure(fig, out_dir, "multitask_pipeline_diagram.png")


# ---------------------------- Figure 6: Metric comparison ----------------------------
def plot_metric_snapshot(multitask_test: dict, out_dir: Path) -> None:
    """Save a bar chart of the multitask test metrics.

    Raises:
        KeyError: If ``multitask_test`` lacks ``bifurcation_metrics`` or
            ``segmentation_metrics``.
    """
    import matplotlib.pyplot as plt

    cls_metrics = multitask_test["bifurcation_metrics"]
    seg_metrics = multitask_test["segmentation_metrics"]

    labels = ["Lumen IoU", "Lumen Dice", "Bif Acc", "Bif F1", "Bif AUC"]
    values = [
        as_float(seg_metrics.get("seg_iou")),
        as_float(seg_metrics.get("seg_dice")),
        as_float(cls_metrics.get("cls_accuracy")),
        as_float(cls_metrics.get("cls_f1")),
        as_float(cls_metrics.get("cls_auc")),
    ]

    fig, ax = plt.subplots(figsize=(8.2, 4.4))
    bars = ax.bar(labels, values, color=["#17becf", "#2ca02c", "#1f77b4", "#9467bd", "#ff7f0e"])
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Score")
    ax.set_title("Multitask Test Metrics Snapshot")
    for bar, value in zip(bars, values):
        # A missing metric has no bar height to sit on; label it at the baseline.
        known = math.isfinite(value)
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            (value if known else 0.0) + 0.015,
            f"{value:.3f}" if known else "n/a",
            ha="center",
            va="bottom",
            fontsize=9,
        )
    _save_figure(fig, out_dir, "multitask_test_metric_snapshot.png")


# ---------------------------- Summary JSON ----------------------------
def build_summary(
    multitask_summary: dict,
    multitask_test: dict,
    lumen_summary: dict,
    standalone_test: dict,
    threshold_payload: dict,
    dataset_balance: dict,
) -> dict:
    """Assemble the numeric memo summary from the loaded artefacts.

    Wilson 95% intervals are derived from the ``tp``/``fp``/``fn``/``tn``
    counts in ``multitask_test["bifurcation_metrics"]``. The result may contain
    NaN (for empty denominators) and tuples; pass it through ``to_json_safe``
    before serialising.

    Raises:
        KeyError: If ``multitask_test`` lacks the metric sections or any of
            the four confusion-matrix counts.
    """
    cls_metrics = multitask_test["bifurcation_metrics"]
    seg_metrics = multitask_test["segmentation_metrics"]
    tp, fp = cls_metrics["tp"], cls_metrics["fp"]
    fn, tn = cls_metrics["fn"], cls_metrics["tn"]

    history = multitask_summary.get("history") or []

    return {
        "data": {
            "split_counts": dataset_balance.get("counts", {}),
            "bifurcation_positive_rate": dataset_balance.get("bifurcation_positive_rate", {}),
            "lumen_coverage_rate": dataset_balance.get("lumen_coverage_rate", {}),
        },
        "multitask": {
            "epochs_completed": len(history),
            "best_epoch_by_val_total_loss": best_epoch_by_val_loss(history),
            "best_val_total_loss": multitask_summary.get("best_val_total_loss"),
            "test_seg_metrics": seg_metrics,
            "test_cls_metrics": cls_metrics,
            "cls_accuracy_wilson95": wilson_interval(int(tp + tn), int(tp + fp + fn + tn)),
            "cls_precision_wilson95": wilson_interval(int(tp), int(tp + fp)),
            "cls_recall_wilson95": wilson_interval(int(tp), int(tp + fn)),
            "threshold": (multitask_test.get("threshold_info") or {}).get("selected_threshold"),
        },
        "lumen_finetune": {
            "num_samples": lumen_summary.get("num_samples"),
            "test_metrics": lumen_summary.get("final_test_metrics", {}),
        },
        "standalone_bifurcation": {
            "selected_threshold": standalone_test.get("selected_threshold"),
            "test_metrics": standalone_test.get("metrics", {}),
        },
        "runtime_threshold": threshold_payload.get("selected_threshold"),
    }


def write_summary(summary: dict, path: Path) -> None:
    """Write ``summary`` to ``path`` as indented JSON with non-finite floats as null."""
    with path.open("w", encoding="utf-8") as f:
        json.dump(to_json_safe(summary), f, indent=2)


def main() -> None:
    """Load the inputs, render every figure and write the summary JSON."""
    inputs = load_inputs()
    multitask_summary = inputs["multitask_summary"]
    multitask_test = inputs["multitask_test"]
    lumen_summary = inputs["lumen_summary"]
    standalone_test = inputs["standalone_test"]

    # Created only after the required inputs loaded, so a misplaced script
    # does not leave an empty asset directory behind.
    ASSET_DIR.mkdir(parents=True, exist_ok=True)

    ensure_style()
    plot_multitask_training(multitask_summary.get("history") or [], ASSET_DIR)
    plot_lumen_finetune(lumen_summary.get("history") or [], ASSET_DIR)
    plot_threshold_sweep(standalone_test, ASSET_DIR)
    plot_probability_hist(standalone_test, ASSET_DIR)
    plot_pipeline_diagram(ASSET_DIR)
    plot_metric_snapshot(multitask_test, ASSET_DIR)

    summary = build_summary(
        multitask_summary,
        multitask_test,
        lumen_summary,
        standalone_test,
        inputs["threshold_payload"],
        inputs["dataset_balance"],
    )
    write_summary(summary, ASSET_DIR / "memo_summary.json")

    print("Saved assets to", ASSET_DIR)
    for p in sorted(ASSET_DIR.glob("*")):
        print(p.name)


if __name__ == "__main__":
    main()