#!/usr/bin/env python
"""
Focused hyperparameter sweep for the thyroid ResNet-18 classifier, optimized
for **validation AUROC**.

Each trial is a full train.py run (logged to Trackio). Trial val AUROCs are
collected; the best config is reported.

Search dimensions (kept deliberately small to avoid overfitting the val set):
  - backbone variant (torchvision, timm a1/a2)
  - learning rate
  - weight decay
  - batch size
  - augmentation policy
  - class-imbalance strategy
  - fine-tune depth (freeze_stage)
  - loss (bce vs focal)

The sweep is staged: a base set of one-factor-at-a-time trials around a
sensible center (informed by literature), so each dimension is isolated.
"""
import argparse
import json
import math
import subprocess
import sys
from pathlib import Path

DATA_DIR = "/app/TN5000"
SPACE = "Johnyquest7/Trakio_agentic_thyroid"
DSET = "Johnyquest7/Trakio_agentic_thyroid_dataset"
PROJECT = "agentic_thyroid_resnet18"

# Score recorded for a trial whose summary is missing or unusable. It sorts
# below any real AUROC, so such trials land at the bottom of the leaderboard.
MISSING_SCORE = -1.0

# center config (literature-informed)
CENTER = dict(
    backbone="timm:resnet18.a1_in1k",
    lr=2e-4,
    weight_decay=1e-4,
    batch_size=32,
    aug_policy="medical_default",
    imbalance="pos_weight",
    freeze_stage=0,
    loss="bce",
    optimizer="adamw",
    scheduler="cosine",
    epochs=40,
    early_stop_patience=8,
    dropout=0.0,
)

# Each trial = (name, overrides-dict). One-factor-at-a-time around CENTER.
TRIALS = [
    ("c00_center_a1", {}),
    ("c01_backbone_torchvision", {"backbone": "torchvision"}),
    ("c02_backbone_a2", {"backbone": "timm:resnet18.a2_in1k"}),
    ("c03_lr_1e-4", {"lr": 1e-4}),
    ("c04_lr_5e-4", {"lr": 5e-4}),
    ("c05_wd_1e-3", {"weight_decay": 1e-3}),
    ("c06_bs_64", {"batch_size": 64}),
    ("c07_aug_flip_only", {"aug_policy": "flip_only"}),
    ("c08_aug_strong", {"aug_policy": "medical_strong"}),
    ("c09_imb_none", {"imbalance": "none"}),
    ("c10_imb_sampler", {"imbalance": "sampler"}),
    ("c11_freeze1", {"freeze_stage": 1}),
    ("c12_loss_focal", {"loss": "focal", "focal_gamma": 1.0, "imbalance": "none"}),
    ("c13_lr1e-4_wd1e-3_drop", {"lr": 1e-4, "weight_decay": 1e-3, "dropout": 0.2}),
]


def _as_score(value):
    """Coerce a stored AUROC to a float usable for sorting and formatting.

    ``None``, non-numeric values and NaN all map to ``MISSING_SCORE``: NaN in
    particular would make the leaderboard sort order undefined.
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return MISSING_SCORE
    return MISSING_SCORE if math.isnan(score) else score


def _build_command(name, cfg, out_dir, seed):
    """Return the argv list that launches train.py for one trial.

    Every ``cfg`` entry becomes a ``--key value`` pair, except booleans,
    which are emitted as a bare ``--key`` flag when true and omitted when
    false.
    """
    cmd = [
        sys.executable, "train.py",
        "--data_dir", DATA_DIR,
        "--output_dir", out_dir,
        "--run_name", name,
        "--seed", str(seed),
        "--trackio_project", PROJECT,
        "--trackio_space_id", SPACE,
        "--trackio_dataset_id", DSET,
        "--num_workers", "4",
    ]
    for key, value in cfg.items():
        if isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
        else:
            cmd += [f"--{key}", str(value)]
    return cmd


def _read_summary(summ_path):
    """Return ``(best_val_auroc, best_epoch)`` from a trial's summary file.

    A missing or unparseable file yields ``(MISSING_SCORE, -1)`` so that one
    broken trial cannot abort the rest of the sweep.
    """
    if not summ_path.exists():
        return MISSING_SCORE, -1
    try:
        with open(summ_path, encoding="utf-8") as fh:
            summary = json.load(fh)
    except json.JSONDecodeError as err:
        print(f"WARNING: could not parse {summ_path}: {err}", flush=True)
        return MISSING_SCORE, -1
    best_epoch = summary.get("best_epoch")
    return _as_score(summary.get("best_val_auroc")), (-1 if best_epoch is None else best_epoch)


def run_trial(name, overrides, out_root, seed):
    """Run one train.py trial and return its result record.

    Args:
        name: Trial name; used as the run name and the output sub-directory.
        overrides: Config entries that replace the corresponding CENTER ones.
        out_root: Directory under which the trial's output directory lives.
        seed: Seed forwarded to train.py via ``--seed``.

    Returns:
        A dict with keys ``name``, ``config``, ``best_val_auroc``,
        ``best_epoch``, ``out_dir`` and ``returncode``. The AUROC and epoch
        are read from ``<out_dir>/final_summary.json`` and fall back to
        -1.0 / -1 when that file is missing or unusable.
    """
    cfg = {**CENTER, **overrides}
    out_dir = str(Path(out_root) / name)
    cmd = _build_command(name, cfg, out_dir, seed)
    print(f"\n===== TRIAL {name} =====\n{' '.join(cmd)}", flush=True)

    r = subprocess.run(cmd, capture_output=True, text=True)
    # echo tail for visibility
    print(r.stdout[-1500:], flush=True)
    if r.returncode != 0:
        print("STDERR tail:", r.stderr[-1500:], flush=True)

    best, best_epoch = _read_summary(Path(out_dir) / "final_summary.json")
    return {
        "name": name,
        "config": cfg,
        "best_val_auroc": best,
        "best_epoch": best_epoch,
        "out_dir": out_dir,
        "returncode": r.returncode,
    }


def main():
    """Run the pending trials, persist results, and print the leaderboard.

    Results are appended to ``<out_root>/sweep_results.json`` after every
    trial. On start-up, any trial whose name already appears in that file is
    skipped (regardless of its recorded return code), so an interrupted sweep
    resumes where it stopped; remove an entry from the file to re-run it.
    The sorted results are written to ``<out_root>/sweep_leaderboard.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_root", default="/app/sweep_runs")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--only", default=None, help="comma list of trial names to run")
    args = ap.parse_args()

    out_root = Path(args.out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    trials = TRIALS
    if args.only:
        # Strip whitespace so "a, b" selects the same trials as "a,b".
        names = {n.strip() for n in args.only.split(",") if n.strip()}
        trials = [t for t in TRIALS if t[0] in names]
        unknown = names - {t[0] for t in TRIALS}
        if unknown:
            print(f"WARNING: unknown trial name(s) ignored: {', '.join(sorted(unknown))}",
                  file=sys.stderr, flush=True)

    results = []
    res_path = out_root / "sweep_results.json"
    if res_path.exists():
        with open(res_path, encoding="utf-8") as fh:
            results = json.load(fh)
        done = {r["name"] for r in results}
        trials = [t for t in trials if t[0] not in done]

    for name, ov in trials:
        res = run_trial(name, ov, args.out_root, args.seed)
        results.append(res)
        # Persist after every trial so a crash loses at most the running one.
        with open(res_path, "w", encoding="utf-8") as fh:
            json.dump(results, fh, indent=2)
        print(f">>> {name}: val_auroc={_as_score(res['best_val_auroc']):.4f} rc={res['returncode']}",
              flush=True)

    results_sorted = sorted(
        results, key=lambda r: _as_score(r.get("best_val_auroc")), reverse=True
    )
    print("\n===== SWEEP LEADERBOARD =====")
    for r in results_sorted:
        print(f"{_as_score(r.get('best_val_auroc')):.4f} {r['name']:28s} "
              f"(epoch {r['best_epoch']}, rc {r['returncode']})")
    with open(out_root / "sweep_leaderboard.json", "w", encoding="utf-8") as fh:
        json.dump(results_sorted, fh, indent=2)
    if results_sorted:
        print("\nBEST:", results_sorted[0]["name"], results_sorted[0]["best_val_auroc"])


if __name__ == "__main__":
    main()