ivus-segmentation / scripts /analysis /plot_data_scale_vs_accuracy.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Plot data scale vs accuracy from observed runs, with optional rough extrapolation.
This script intentionally separates observed points from extrapolated points.
Use extrapolation only for presentation drafts, not model selection.
"""
from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import matplotlib.pyplot as plt
import numpy as np
@dataclass
class Point:
    """A single observed metric value at a given labeled-data size.

    Attributes:
        task: Plot panel the point belongs to (this script uses
            ``"lumen"`` and ``"bifurcation"``).
        metric: Metric name, e.g. ``"dice"``, ``"iou"``, ``"f1"``,
            ``"accuracy"``.
        data_size: Number of labeled samples associated with the run.
        value: Metric value reported by the run.
        source: Path of the summary JSON the value was read from.
    """

    task: str
    metric: str
    data_size: int
    value: float
    source: str
def _read_json(path: Path) -> dict:
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def collect_default_points(root: Path) -> list[Point]:
    """Gather observed metric points from the known summary JSON files.

    Three summary files are looked up relative to *root*; any that does not
    exist is skipped silently. A point is only emitted when the run's data
    size is positive and the metric key is present in the summary.

    Args:
        root: Repository root that the summary paths are resolved against.

    Returns:
        Points in discovery order: lumen fine-tune, multitask, then the
        standalone bifurcation classifier.
    """
    collected: list[Point] = []

    # Lumen fine-tune run (~295 labeled frames)
    lumen_path = root / "models/standalone/lumen/finetune_summary.json"
    if lumen_path.exists():
        summary = _read_json(lumen_path)
        n_samples = int(summary.get("num_samples", 0))
        metrics = summary.get("final_test_metrics", {})
        if n_samples > 0:
            for name in ("dice", "iou"):
                if name in metrics:
                    collected.append(
                        Point("lumen", name, n_samples, float(metrics[name]), str(lumen_path))
                    )

    # Multitask run (split built from merged_600)
    multitask_path = root / "models/multitask/multitask_summary.json"
    if multitask_path.exists():
        summary = _read_json(multitask_path)
        split_name = str(summary.get("split_json", ""))
        # A split file named *_600.json pins the size to 600; otherwise
        # fall back to the recorded training-set size.
        if split_name.endswith("_600.json"):
            n_total = 600
        else:
            n_total = int(summary.get("num_train", 0))
        metrics = summary.get("test_metrics", {})
        if n_total > 0:
            # (summary key, task, metric) - segmentation metrics are
            # reported under the lumen task, classification under bifurcation.
            multitask_fields = (
                ("seg_dice", "lumen", "dice"),
                ("seg_iou", "lumen", "iou"),
                ("cls_f1", "bifurcation", "f1"),
            )
            for key, task, metric in multitask_fields:
                if key in metrics:
                    collected.append(
                        Point(task, metric, n_total, float(metrics[key]), str(multitask_path))
                    )

    # Standalone bifurcation classifier run (split built from merged_600)
    bif_path = root / "output/training_outputs/bifurcation_classifier/training_summary.json"
    if bif_path.exists():
        summary = _read_json(bif_path)
        run_dir = str(summary.get("tensorboard_run_dir", ""))
        if "merged600" in run_dir:
            n_total = 600
        else:
            n_total = int(summary.get("num_train", 0))
        metrics = summary.get("test_metrics", {})
        if n_total > 0 and "accuracy" in metrics:
            collected.append(
                Point("bifurcation", "accuracy", n_total, float(metrics["accuracy"]), str(bif_path))
            )

    return collected
def fit_log_linear(x: np.ndarray, y: np.ndarray, x_new: np.ndarray) -> np.ndarray:
    """Fit ``y = a + b * log(x)`` by least squares and evaluate it at *x_new*.

    Args:
        x: Observed data sizes (1-D).
        y: Observed metric values, same length as *x*.
        x_new: Data sizes at which to evaluate the fitted model.

    Returns:
        Model predictions ``a + b * log(x_new)``.
    """
    # Design matrix with an intercept column and a log(x) column.
    design = np.column_stack((np.ones_like(x), np.log(x)))
    solution = np.linalg.lstsq(design, y, rcond=None)[0]
    intercept, slope = solution[0], solution[1]
    return intercept + slope * np.log(x_new)
def illustrative_upward_curve(
    x_end: int,
    y_start: float,
    y_end: float,
    k: float,
    num: int = 120,
) -> tuple[np.ndarray, np.ndarray]:
    """Create a smooth, monotonic curve from ``(0, y_start)`` to ``(x_end, y_end)``.

    The curve is ``y_start + (y_end - y_start) * g(t)`` with ``t = x / x_end``
    and ``g(t) = (1 - exp(-k * t)) / (1 - exp(-k))``. For ``k > 0`` it rises
    quickly and saturates; ``k == 0`` is handled as the linear limit
    ``g(t) = t``; ``k < 0`` gives a convex (slow-start) curve. In all cases
    ``g(0) == 0`` and ``g(1) == 1``, so both end points are honoured.

    Args:
        x_end: Right-hand end of the x range. Values below 1 are divided by 1
            instead, so the curve then stops short of *y_end*.
        y_start: Curve value at ``x = 0``.
        y_end: Curve value at ``x = x_end``.
        k: Curvature; larger positive values saturate earlier.
        num: Number of evenly spaced samples.

    Returns:
        ``(x, y)`` arrays of length *num*.
    """
    x = np.linspace(0.0, float(x_end), num=num)
    # Guard the normalisation against x_end == 0.
    t = x / max(float(x_end), 1.0)
    if abs(k) < 1e-12:
        # Limit of the exponential ratio as k -> 0; avoids 0/0.
        growth = t
    else:
        # expm1 keeps precision for small |k|; the ratio is valid for either
        # sign of k, so no clamping of the denominator is needed.
        growth = np.expm1(-k * t) / np.expm1(-k)
    y = y_start + (y_end - y_start) * growth
    return x, y
def projected_points_from_curve(
    x_end: int,
    y_start: float,
    y_end: float,
    k: float,
    n_points: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Sample a few marker points along a monotonic curve ending at ``(x_end, y_end)``.

    The x positions are evenly spaced from ``max(1, x_end / n_points)`` to
    ``x_end``. The y values follow ``y_start + (y_end - y_start) * g(t)`` with
    ``t = x / x_end`` and ``g(t) = (1 - exp(-k * t)) / (1 - exp(-k))``;
    ``k == 0`` is handled as the linear limit ``g(t) = t`` and negative *k*
    is evaluated without clamping, so the last point always lands on *y_end*
    when ``x_end >= 1``.

    Args:
        x_end: Right-hand end of the x range.
        y_start: Curve value at ``x = 0`` (not itself sampled).
        y_end: Curve value at ``x = x_end``.
        k: Curvature; larger positive values saturate earlier.
        n_points: Number of points to sample; values below 2 are raised to 2.

    Returns:
        ``(x, y)`` arrays of length ``max(n_points, 2)``.
    """
    n_points = max(int(n_points), 2)
    x = np.linspace(max(1.0, x_end / n_points), float(x_end), num=n_points)
    # Guard the normalisation against x_end == 0.
    t = x / max(float(x_end), 1.0)
    if abs(k) < 1e-12:
        # Limit of the exponential ratio as k -> 0; avoids 0/0.
        growth = t
    else:
        # expm1 keeps precision for small |k|; the ratio is valid for either
        # sign of k, so no clamping of the denominator is needed.
        growth = np.expm1(-k * t) / np.expm1(-k)
    y = y_start + (y_end - y_start) * growth
    return x, y
def group_points(points: Iterable[Point]) -> dict[tuple[str, str], list[Point]]:
    """Bucket points by ``(task, metric)`` and order each bucket by data size.

    Keys appear in the order their first point was encountered. Within a
    bucket the sort is stable, so points with equal ``data_size`` keep their
    input order.
    """
    buckets: dict[tuple[str, str], list[Point]] = {}
    for point in points:
        key = (point.task, point.metric)
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(point)
    return {
        key: sorted(members, key=lambda pt: pt.data_size)
        for key, members in buckets.items()
    }
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the plotting script."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--root", type=Path, default=Path("."), help="repo root")
    cli.add_argument("--output", type=Path, default=Path("output/data_scale_vs_accuracy.png"))
    cli.add_argument("--extrapolate-to", type=int, nargs="*", default=[800, 1000, 1500])
    cli.add_argument("--no-extrapolation", action="store_true")
    cli.add_argument(
        "--illustrative-lumen-prior",
        action="store_true",
        help="Draw an illustrative upward curve from n=0 to --lumen-current-n for lumen metrics.",
    )
    cli.add_argument("--lumen-current-n", type=int, default=300)
    cli.add_argument("--lumen-start", type=float, default=0.60)
    cli.add_argument("--lumen-curve-k", type=float, default=3.0)
    cli.add_argument("--projected-points", type=int, default=4)
    cli.add_argument("--bif-current-n", type=int, default=600)
    cli.add_argument("--bif-start", type=float, default=0.55)
    cli.add_argument("--bif-curve-k", type=float, default=2.4)
    cli.add_argument("--show", action="store_true")
    return cli


def _draw_illustrative_prior(
    ax,
    x: np.ndarray,
    y: np.ndarray,
    *,
    current_n: int,
    y_start: float,
    k: float,
    n_points: int,
    trend_label: str,
    points_label: str,
) -> None:
    """Overlay the illustrative trend curve and its projected markers on *ax*."""
    # The curve ends at the value of the observed point closest to current_n.
    nearest = int(np.argmin(np.abs(x - float(current_n))))
    y_anchor = float(y[nearest])
    curve_x, curve_y = illustrative_upward_curve(
        x_end=current_n,
        y_start=y_start,
        y_end=y_anchor,
        k=k,
    )
    ax.plot(curve_x, curve_y, ":", linewidth=2, label=trend_label)
    marks_x, marks_y = projected_points_from_curve(
        x_end=current_n,
        y_start=y_start,
        y_end=y_anchor,
        k=k,
        n_points=n_points,
    )
    ax.plot(marks_x, marks_y, "s-", linewidth=1.3, markersize=5, label=points_label)


def _write_audit_trail(txt_path: Path, points: list[Point], args: argparse.Namespace) -> None:
    """Write a plain-text record of the plotted points and options used."""
    with txt_path.open("w", encoding="utf-8") as report:
        report.write("Observed points:\n")
        for p in points:
            report.write(f"- task={p.task} metric={p.metric} n={p.data_size} value={p.value:.6f} source={p.source}\n")
        if not args.no_extrapolation:
            report.write("\nExtrapolation model: y = a + b*log(n) per task+metric (least squares).\n")
            report.write("Use extrapolated points only as a draft figure, not as evidence.\n")
        if args.illustrative_lumen_prior:
            report.write(
                "\nIllustrative lumen prior: monotonic saturating curve from n=0 "
                f"to n={args.lumen_current_n}, anchored to nearest observed lumen point.\n"
            )
            report.write("Projected point count per metric: "
                         f"{args.projected_points} (lumen to n={args.lumen_current_n}, bifurcation to n={args.bif_current_n}).\n")


def main() -> None:
    """Parse CLI options, plot observed (and optional projected) points, and save outputs.

    Writes the figure to ``--output`` and a text audit trail next to it with
    the same stem and a ``.txt`` suffix. Exits with a message when no
    observed points can be collected.
    """
    args = _build_parser().parse_args()

    points = collect_default_points(args.root)
    if not points:
        raise SystemExit("No points found. Check expected summary JSON files.")

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
    # One panel per task; groups for any other task are skipped below.
    task_to_ax = {"lumen": axes[0], "bifurcation": axes[1]}

    for (task, metric), pts in group_points(points).items():
        ax = task_to_ax.get(task)
        if ax is None:
            continue

        x = np.array([p.data_size for p in pts], dtype=float)
        y = np.array([p.value for p in pts], dtype=float)
        ax.plot(x, y, "o-", label=f"{metric} (observed)")

        if args.illustrative_lumen_prior:
            if task == "lumen":
                _draw_illustrative_prior(
                    ax,
                    x,
                    y,
                    current_n=args.lumen_current_n,
                    y_start=args.lumen_start,
                    k=args.lumen_curve_k,
                    n_points=args.projected_points,
                    trend_label=f"{metric} (lumen trend)",
                    points_label=f"{metric} (projected points)",
                )
            elif task == "bifurcation":
                _draw_illustrative_prior(
                    ax,
                    x,
                    y,
                    current_n=args.bif_current_n,
                    y_start=args.bif_start,
                    k=args.bif_curve_k,
                    n_points=args.projected_points,
                    trend_label=f"{metric} (bifurcation trend)",
                    points_label=f"{metric} (projected points)",
                )

        # Extrapolate only with at least two observations and only to sizes
        # beyond the largest observed one.
        if not args.no_extrapolation and len(pts) >= 2 and args.extrapolate_to:
            targets = np.array(sorted(set(args.extrapolate_to)), dtype=float)
            beyond = targets[targets > x.max()]
            if beyond.size:
                # Metrics are bounded, so clamp the fitted values to [0, 1].
                predicted = np.clip(fit_log_linear(x, y, beyond), 0.0, 1.0)
                ax.plot(beyond, predicted, "x--", label=f"{metric} (rough extrapolation)")

        ax.set_title(task)
        ax.set_xlabel("Labeled data size")
        ax.set_ylabel("Metric")
        ax.set_ylim(0.5, 1.0)
        ax.grid(True, alpha=0.25)
        ax.legend()

    fig.suptitle("Data Scale vs Accuracy (Observed + Optional Rough Extrapolation)")
    fig.tight_layout()
    args.output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.output, dpi=180)

    # Also write a tiny audit trail for reproducibility.
    txt = args.output.with_suffix(".txt")
    _write_audit_trail(txt, points, args)

    print(f"Saved: {args.output}")
    print(f"Saved: {txt}")
    if args.show:
        plt.show()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()