Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
5.01 kB
#!/usr/bin/env python
"""
Focused hyperparameter sweep for the thyroid ResNet-18 classifier, optimized
for **validation AUROC**. Each trial is a full train.py run (logged to Trackio).
Trial val AUROCs are collected; the best config is reported.

Search dimensions (kept deliberately small to avoid overfitting the val set):
  - backbone variant (torchvision, timm a1/a2)
  - learning rate
  - weight decay
  - batch size
  - augmentation policy
  - class-imbalance strategy
  - fine-tune depth (freeze_stage)
  - loss (bce vs focal)
  - dropout (varied only in the combined trial c13)

The sweep is staged: a base set of one-factor-at-a-time trials around a sensible
center (informed by literature), so most dimensions are isolated. Two trials
are deliberate exceptions: c12 switches to focal loss AND sets imbalance="none",
and c13 combines a lower learning rate, higher weight decay and dropout.

Each trial writes to <out_root>/<trial name>. Collected results are
checkpointed to <out_root>/sweep_results.json after every trial, and a rerun
skips trials already recorded there, so an interrupted sweep can be resumed.
The val-AUROC-sorted leaderboard is written to <out_root>/sweep_leaderboard.json.

train.py is invoked by relative path, so run this script from the directory
that contains train.py. Use --only with a comma-separated list of trial names
to run a subset.
"""
import argparse
import json
import math
import subprocess
import sys
from pathlib import Path
# Dataset location and Trackio logging targets shared by every trial.
DATA_DIR = "/app/TN5000"
SPACE = "Johnyquest7/Trakio_agentic_thyroid"
DSET = "Johnyquest7/Trakio_agentic_thyroid_dataset"
PROJECT = "agentic_thyroid_resnet18"

# Literature-informed center configuration. Every trial starts from a copy of
# this dict and applies its overrides on top; each key is forwarded to
# train.py as a "--<key> <value>" argument, in this order.
CENTER = {
    "backbone": "timm:resnet18.a1_in1k",
    "lr": 2e-4,
    "weight_decay": 1e-4,
    "batch_size": 32,
    "aug_policy": "medical_default",
    "imbalance": "pos_weight",
    "freeze_stage": 0,
    "loss": "bce",
    "optimizer": "adamw",
    "scheduler": "cosine",
    "epochs": 40,
    "early_stop_patience": 8,
    "dropout": 0.0,
}
# Sweep trials as (name, overrides) pairs; overrides are applied on top of
# CENTER. c01-c11 each change a single factor. c12 changes two at once (focal
# loss with gamma 1.0, plus imbalance="none"), and c13 is a combined variant
# (lower lr, higher weight decay, dropout 0.2).
TRIALS = [
    # baseline: CENTER unchanged
    ("c00_center_a1", {}),
    # pretrained backbone weights
    ("c01_backbone_torchvision", {"backbone": "torchvision"}),
    ("c02_backbone_a2", {"backbone": "timm:resnet18.a2_in1k"}),
    # optimisation
    ("c03_lr_1e-4", {"lr": 1e-4}),
    ("c04_lr_5e-4", {"lr": 5e-4}),
    ("c05_wd_1e-3", {"weight_decay": 1e-3}),
    ("c06_bs_64", {"batch_size": 64}),
    # augmentation policy
    ("c07_aug_flip_only", {"aug_policy": "flip_only"}),
    ("c08_aug_strong", {"aug_policy": "medical_strong"}),
    # class-imbalance handling
    ("c09_imb_none", {"imbalance": "none"}),
    ("c10_imb_sampler", {"imbalance": "sampler"}),
    # fine-tune depth
    ("c11_freeze1", {"freeze_stage": 1}),
    # multi-factor trials
    ("c12_loss_focal", {"loss": "focal", "focal_gamma": 1.0, "imbalance": "none"}),
    ("c13_lr1e-4_wd1e-3_drop", {"lr": 1e-4, "weight_decay": 1e-3, "dropout": 0.2}),
]
def run_trial(name, overrides, out_root, seed):
cfg = dict(CENTER); cfg.update(overrides)
out_dir = str(Path(out_root) / name)
cmd = [sys.executable, "train.py",
"--data_dir", DATA_DIR, "--output_dir", out_dir,
"--run_name", name, "--seed", str(seed),
"--trackio_project", PROJECT, "--trackio_space_id", SPACE,
"--trackio_dataset_id", DSET,
"--num_workers", "4"]
for k, v in cfg.items():
if isinstance(v, bool):
if v:
cmd.append(f"--{k}")
else:
cmd += [f"--{k}", str(v)]
print(f"\n===== TRIAL {name} =====\n{' '.join(cmd)}", flush=True)
r = subprocess.run(cmd, capture_output=True, text=True)
# echo tail for visibility
print(r.stdout[-1500:], flush=True)
if r.returncode != 0:
print("STDERR tail:", r.stderr[-1500:], flush=True)
summ_path = Path(out_dir) / "final_summary.json"
best = -1.0; best_epoch = -1
if summ_path.exists():
s = json.load(open(summ_path))
best = s.get("best_val_auroc", -1.0); best_epoch = s.get("best_epoch", -1)
return {"name": name, "config": cfg, "best_val_auroc": best,
"best_epoch": best_epoch, "out_dir": out_dir, "returncode": r.returncode}
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--out_root", default="/app/sweep_runs")
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--only", default=None, help="comma list of trial names to run")
args = ap.parse_args()
Path(args.out_root).mkdir(parents=True, exist_ok=True)
trials = TRIALS
if args.only:
names = set(args.only.split(","))
trials = [t for t in TRIALS if t[0] in names]
results = []
res_path = Path(args.out_root) / "sweep_results.json"
if res_path.exists():
results = json.load(open(res_path))
done = {r["name"] for r in results}
trials = [t for t in trials if t[0] not in done]
for name, ov in trials:
res = run_trial(name, ov, args.out_root, args.seed)
results.append(res)
json.dump(results, open(res_path, "w"), indent=2)
print(f">>> {name}: val_auroc={res['best_val_auroc']:.4f} rc={res['returncode']}", flush=True)
results_sorted = sorted(results, key=lambda r: r["best_val_auroc"], reverse=True)
print("\n===== SWEEP LEADERBOARD =====")
for r in results_sorted:
print(f"{r['best_val_auroc']:.4f} {r['name']:28s} (epoch {r['best_epoch']}, rc {r['returncode']})")
json.dump(results_sorted, open(Path(args.out_root) / "sweep_leaderboard.json", "w"), indent=2)
if results_sorted:
print("\nBEST:", results_sorted[0]["name"], results_sorted[0]["best_val_auroc"])
if __name__ == "__main__":
main()