Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified | #!/usr/bin/env python | |
| """ | |
Reproducible training for the ResNet-18 thyroid ultrasound malignancy classifier.

Trains on Train only; selects the best checkpoint by **validation AUROC**.
Logs everything to Trackio (losses, val AUROC/sens/spec/PPV/NPV/ECE/Brier, LR,
epoch, hyperparameters, env info) and emits trackio.alert() at decision points.

Single-command reproduction:

    python train.py --config configs/final_config.yaml

CLI arguments that are passed explicitly override config values; config values
override the built-in defaults. The exact command line, resolved config, seed,
package versions and hardware info are saved to the output dir.

Dataset is loaded directly from the Train/Valid/Test FOLDER structure of the
Hub repo (NOT the flattened datasets-viewer 'train' split), so the predefined
splits are respected.
| """ | |
| import argparse | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import yaml | |
| import thyroid_lib as L | |
def parse_args():
    """Declare every command-line option of the training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option. ``amp`` and
        ``strict_determinism`` are ``True`` unless the matching ``--no_*``
        switch is given; ``no_trackio`` is ``False`` unless ``--no_trackio``
        is given.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    def add_switch_pair(name):
        # ``--<name>`` is on by default; ``--no_<name>`` writes False into the same dest.
        add(f"--{name}", action="store_true", default=True)
        add(f"--no_{name}", dest=name, action="store_false")

    # Config file, data source and output location.
    add("--config", default=None, help="YAML config path")
    add("--dataset_id", default="Johnyquest7/TN5000-thyroid-nodule-classification")
    add("--data_dir", default=None, help="Local TN5000 dir; else downloaded from Hub")
    add("--output_dir", default="run_out")

    # Model and augmentation.
    add("--backbone", default="timm:resnet18.a1_in1k")
    add("--freeze_stage", type=int, default=0)
    add("--dropout", type=float, default=0.0)
    add("--aug_policy", default="medical_default")

    # Loss and class-imbalance handling.
    add("--loss", default="bce", choices=["bce", "focal"])
    add("--focal_gamma", type=float, default=2.0)
    add("--focal_alpha", type=float, default=0.5)
    add("--imbalance", default="pos_weight", choices=["pos_weight", "none", "sampler"])

    # Optimisation schedule.
    add("--optimizer", default="adamw", choices=["adamw", "sgd"])
    add("--lr", type=float, default=2e-4)
    add("--weight_decay", type=float, default=1e-4)
    add("--batch_size", type=int, default=32)
    add("--epochs", type=int, default=40)
    add("--scheduler", default="cosine", choices=["cosine", "plateau", "none"])
    add("--warmup_epochs", type=int, default=2)
    add("--early_stop_patience", type=int, default=8)

    # Runtime behaviour and reproducibility.
    add_switch_pair("amp")
    add("--num_workers", type=int, default=4)
    add("--seed", type=int, default=42)
    add_switch_pair("strict_determinism")

    # Experiment tracking.
    add("--trackio_project", default="agentic_thyroid_resnet18")
    add("--trackio_space_id", default="Johnyquest7/Trakio_agentic_thyroid")
    add("--trackio_dataset_id", default="Johnyquest7/Trakio_agentic_thyroid_dataset")
    add("--run_name", default=None)
    add("--no_trackio", action="store_true")

    return parser.parse_args()
def merge_config(args):
    """Overlay values from the YAML file named by ``args.config`` onto ``args``.

    Precedence is: options given explicitly on the command line, then the YAML
    file, then the argparse defaults already present on ``args``. Config keys
    that do not correspond to an existing attribute of ``args`` are ignored.

    Args:
        args: the ``argparse.Namespace`` produced by ``parse_args``; it is
            modified in place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if a config value for a float-typed option is a string
            that cannot be converted to ``float``.
    """
    if not args.config:
        return args

    cfg_path = Path(args.config)
    if not cfg_path.exists():
        # Keep the original best-effort behaviour (continue with defaults), but
        # do not do so silently: a mistyped path would otherwise train with the
        # wrong hyperparameters without any indication.
        print(f"[config] WARNING: config file not found, using defaults/CLI only: {cfg_path}",
              file=sys.stderr, flush=True)
        return args

    with open(cfg_path) as f:
        cfg = yaml.safe_load(f) or {}

    # Collect the destinations the user set explicitly on the command line.
    passed = set()
    for token in sys.argv[1:]:
        if not token.startswith("--"):
            continue
        flag = token.split("=")[0].lstrip("-").replace("-", "_")
        passed.add(flag)
        # A negative switch such as ``--no_amp`` writes into ``amp``. Without
        # this, a config containing ``amp: true`` would undo the CLI switch.
        if flag.startswith("no_") and hasattr(args, flag[3:]):
            passed.add(flag[3:])

    for key, value in cfg.items():
        if key in passed or not hasattr(args, key):
            continue
        current = getattr(args, key)
        # PyYAML resolves exponent literals without a dot (e.g. ``2e-4``) as
        # strings; convert them when the option is float-valued.
        if isinstance(current, float) and isinstance(value, str):
            try:
                value = float(value)
            except ValueError:
                raise ValueError(
                    f"config key {key!r}: expected a number, got {value!r}") from None
        setattr(args, key, value)
    return args
def main():
    """Run one training job end to end and return the best validation AUROC.

    Steps, in order: parse CLI/config, fix seeds, locate or download the
    dataset, build model / transforms / datasets / loaders / loss / optimizer /
    scheduler, write the run description to ``output_dir``, then train with a
    validation pass after every epoch. The checkpoint with the highest
    validation AUROC is kept, and training stops early after
    ``early_stop_patience`` epochs without improvement.

    Files written to ``output_dir``: ``resolved_config.json``,
    ``env_info.json``, ``command_line.txt``, ``config_used.yaml``,
    ``best_model.pt``, ``best_val_summary.json``, ``history.json`` and
    ``final_summary.json``.

    Returns:
        float: the best validation AUROC observed (``-1.0`` if no epoch ran).
    """
    args = parse_args()
    args = merge_config(args)
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Heavy / optional dependencies are imported here rather than at module top.
    import shlex
    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler
    from sklearn.metrics import roc_auc_score
    import torch.nn.functional as F

    L.set_determinism(args.seed, strict=args.strict_determinism)
    env = L.collect_env_info()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use a local copy when given; otherwise fetch only the split folders from the Hub.
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                          local_dir=str(out / "_data"),
                                          allow_patterns=["Train/**", "Valid/**", "Test/**"]))

    model, pp = L.build_model(args.backbone, freeze_stage=args.freeze_stage, dropout=args.dropout)
    model = model.to(device)
    train_tf = L.build_train_transform(pp, args.aug_policy)
    eval_tf = L.build_eval_transform(pp)
    train_ds = L.ThyroidImageFolder(data_dir / "Train", train_tf)
    valid_ds = L.ThyroidImageFolder(data_dir / "Valid", eval_tf)

    # Class balance of the training split; pos_weight is only computed for the
    # "pos_weight" strategy and only when there is at least one positive.
    n_neg, n_pos = L.class_counts(train_ds.targets)
    pos_weight = (n_neg / n_pos) if (args.imbalance == "pos_weight" and n_pos) else None

    # One seeded generator is shared by the sampler and the training DataLoader.
    g = torch.Generator(); g.manual_seed(args.seed)
    if args.imbalance == "sampler":
        # Inverse-frequency weight per sample, indexed by its 0/1 target.
        cw = np.array([1.0 / n_neg, 1.0 / n_pos])
        sw = np.array([cw[t] for t in train_ds.targets])
        sampler = WeightedRandomSampler(torch.tensor(sw, dtype=torch.double),
                                        num_samples=len(sw), replacement=True, generator=g)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers, worker_init_fn=L.seed_worker,
                                  generator=g, pin_memory=(device == "cuda"), drop_last=False)
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, worker_init_fn=L.seed_worker,
                                  generator=g, pin_memory=(device == "cuda"), drop_last=False)
    valid_loader = DataLoader(valid_ds, batch_size=64, shuffle=False,
                              num_workers=args.num_workers, pin_memory=(device == "cuda"))

    # pos_weight is forwarded to the loss only for BCE; focal loss receives None.
    criterion = L.build_loss(args.loss, pos_weight if args.loss == "bce" else None,
                             args.focal_gamma, args.focal_alpha).to(device)

    # Only parameters that still require gradients are optimised.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay, nesterov=True)

    if args.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif args.scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                                               factor=0.5, patience=3)
    else:
        scheduler = None

    # Gradient scaling is active only when AMP is requested and CUDA is in use.
    scaler = torch.amp.GradScaler("cuda", enabled=(args.amp and device == "cuda"))

    run_name = args.run_name or (
        f"{args.backbone.replace('timm:','tm_').replace('.','_')}_lr{args.lr}_wd{args.weight_decay}"
        f"_bs{args.batch_size}_{args.aug_policy}_{args.loss}_{args.imbalance}_fz{args.freeze_stage}")

    # Persist everything needed to describe and re-run this job.
    resolved = vars(args).copy()
    resolved.update({"pos_weight": pos_weight, "n_train_neg": n_neg, "n_train_pos": n_pos,
                     "preprocess": pp.to_dict(), "device": device})
    L.save_json(resolved, out / "resolved_config.json")
    L.save_json(env, out / "env_info.json")
    with open(out / "command_line.txt", "w") as f:
        # shlex.join quotes arguments containing spaces or shell metacharacters,
        # so the recorded command can be pasted back into a shell unchanged.
        f.write("python " + shlex.join(sys.argv) + "\n")
    with open(out / "config_used.yaml", "w") as f:
        # Dict-valued entries are dropped from the YAML, except "preprocess".
        yaml.safe_dump({k: v for k, v in resolved.items() if not isinstance(v, dict) or k == "preprocess"}, f)

    use_trackio = not args.no_trackio
    if use_trackio:
        import trackio
        # Use the AlertLevel enum when this trackio version provides it,
        # otherwise fall back to plain strings.
        try:
            from trackio.alerts import AlertLevel
            _LV = {"info": AlertLevel.INFO, "warn": AlertLevel.WARN, "error": AlertLevel.ERROR}
        except Exception:
            _LV = {"info": "info", "warn": "warn", "error": "error"}

        def _alert(title, text, level="info"):
            # Alerts are best-effort: a failure is printed and never interrupts training.
            try:
                trackio.alert(title, text, level=_LV.get(level, level))
            except Exception as e:
                print(f"[alert-failed] {title}: {text} ({e})", flush=True)

        trackio.init(project=args.trackio_project, name=run_name,
                     space_id=args.trackio_space_id,
                     dataset_id=args.trackio_dataset_id,
                     config={k: v for k, v in resolved.items() if k != "preprocess"})
        _alert("Run started",
               f"{run_name} | backbone={args.backbone} loss={args.loss} "
               f"imb={args.imbalance} lr={args.lr} wd={args.weight_decay} "
               f"bs={args.batch_size} aug={args.aug_policy} fz={args.freeze_stage} "
               f"pos_weight={pos_weight} device={env.get('gpu_name')}",
               "info")

    best_auroc = -1.0
    best_epoch = -1
    epochs_no_improve = 0
    history = []
    global_step = 0
    # Warmup is counted in optimizer steps, not epochs.
    n_warmup_steps = args.warmup_epochs * max(1, len(train_loader))
    base_lr = args.lr

    for epoch in range(args.epochs):
        model.train()
        t0 = time.time()
        running = 0.0
        for x, y, _ in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).float()
            # Linear warmup: LR ramps per step up to base_lr over n_warmup_steps.
            if global_step < n_warmup_steps and args.warmup_epochs > 0:
                for pg in optimizer.param_groups:
                    pg["lr"] = base_lr * (global_step + 1) / n_warmup_steps
            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", dtype=torch.float16,
                                enabled=(args.amp and device == "cuda")):
                out_logits = model(x).view(-1)
                loss = criterion(out_logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Accumulate the sample-weighted loss; it is divided by len(train_ds) below.
            running += loss.item() * x.size(0)
            global_step += 1
        train_loss = running / len(train_ds)

        # Validation pass. NOTE(review): L.collect_logits / L.sigmoid /
        # L.point_metrics are project helpers; their return types are not
        # visible here and are presumably array-likes / a metrics dict.
        val_logits, val_labels, _ = L.collect_logits(model, valid_loader, device, amp=args.amp)
        val_probs = L.sigmoid(val_logits)
        val_auroc = float(roc_auc_score(val_labels, val_probs))
        val_loss = float(F.binary_cross_entropy_with_logits(
            torch.tensor(val_logits), torch.tensor(val_labels, dtype=torch.float32)).item())
        m = L.point_metrics(val_labels, val_probs, 0.5)

        # Read the LR before the scheduler steps, so the logged value is the
        # one in effect at the end of this epoch.
        cur_lr = optimizer.param_groups[0]["lr"]
        if scheduler is not None:
            if args.scheduler == "plateau":
                scheduler.step(val_auroc)
            elif global_step >= n_warmup_steps:
                # The non-plateau scheduler only advances once warmup is over.
                scheduler.step()

        row = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss,
               "val_auroc": val_auroc, "val_sens@0.5": m["sensitivity"],
               "val_spec@0.5": m["specificity"], "val_ppv@0.5": m["ppv"],
               "val_npv@0.5": m["npv"], "val_ece@0.5": m["ece"], "val_brier": m["brier"],
               "lr": cur_lr, "epoch_time_s": round(time.time() - t0, 1)}
        history.append(row)
        print(f"[epoch {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
              f"val_auroc={val_auroc:.4f} lr={cur_lr:.2e} ({row['epoch_time_s']}s)", flush=True)
        if use_trackio:
            trackio.log({"train_loss": train_loss, "val_loss": val_loss, "val_auroc": val_auroc,
                         "val_sensitivity": m["sensitivity"], "val_specificity": m["specificity"],
                         "val_ppv": m["ppv"], "val_npv": m["npv"], "val_ece": m["ece"],
                         "val_brier": m["brier"], "lr": cur_lr, "epoch": epoch})

        # An epoch counts as an improvement only if AUROC rises by more than 1e-5.
        improved = val_auroc > best_auroc + 1e-5
        if improved:
            best_auroc = val_auroc
            best_epoch = epoch
            epochs_no_improve = 0
            torch.save({"model_state": model.state_dict(),
                        "backbone": args.backbone, "freeze_stage": args.freeze_stage,
                        "dropout": args.dropout, "preprocess": pp.to_dict(),
                        "epoch": epoch, "val_auroc": val_auroc},
                       out / "best_model.pt")
            L.save_json({"best_epoch": best_epoch, "best_val_auroc": best_auroc,
                         "val_metrics_at_0.5": m}, out / "best_val_summary.json")
        else:
            epochs_no_improve += 1

        # Report a non-finite or exploding loss on stdout as well, so it is
        # visible when Trackio is disabled. Training is not stopped here.
        if not np.isfinite(train_loss) or train_loss > 1e3:
            print(f"[warn] train_loss={train_loss} at epoch {epoch}: training may have diverged",
                  flush=True)
            if use_trackio:
                _alert("Training diverged",
                       f"train_loss={train_loss} at epoch {epoch} — lr likely too high, try x0.1",
                       "error")

        if epochs_no_improve >= args.early_stop_patience:
            print(f"Early stopping at epoch {epoch} (no val AUROC improvement for "
                  f"{args.early_stop_patience} epochs).", flush=True)
            if use_trackio:
                _alert("Early stopping",
                       f"No val AUROC gain for {args.early_stop_patience} epochs; "
                       f"best={best_auroc:.4f} @ epoch {best_epoch}. Consider lr x0.5.",
                       "warn")
            break

    # Final artifacts: full per-epoch history and a short summary.
    L.save_json(history, out / "history.json")
    L.save_json({"best_val_auroc": best_auroc, "best_epoch": best_epoch,
                 "run_name": run_name, "backbone": args.backbone},
                out / "final_summary.json")
    if use_trackio:
        trackio.log({"best_val_auroc": best_auroc, "best_epoch": best_epoch})
        _alert("Run complete",
               f"{run_name}: best val AUROC={best_auroc:.4f} @ epoch {best_epoch}",
               "info")
        trackio.finish()
    print(f"DONE best_val_auroc={best_auroc:.4f} best_epoch={best_epoch}", flush=True)
    return best_auroc
# Script entry point: run training only when the file is executed directly,
# not when it is imported. The value returned by main() is discarded here.
if __name__ == "__main__":
    main()