#!/usr/bin/env python
"""
Reproducible training for the ResNet-18 thyroid ultrasound malignancy classifier.

Trains on Train only; selects the best checkpoint by **validation AUROC**.
Logs everything to Trackio (losses, val AUROC/sens/spec/PPV/NPV/ECE/Brier, LR, epoch,
hyperparameters, env info) and emits trackio.alert() at decision points.

Single-command reproduction:
    python train.py --config configs/final_config.yaml

All CLI args override config values. The exact command line, resolved config, seed,
package versions and hardware info are saved to the output dir.

Dataset is loaded directly from the Train/Valid/Test FOLDER structure of the Hub repo
(NOT the flattened datasets-viewer 'train' split), so the predefined splits are respected.
"""
import argparse
import os
import shlex
import sys
import time
from pathlib import Path

import numpy as np
import yaml

import thyroid_lib as L

# CLI switches that write to a differently named destination. merge_config()
# needs this mapping so that e.g. ``--no_amp`` counts as an explicit override
# of the ``amp`` config key.
_NEGATED_FLAGS = {
    "no_amp": "amp",
    "no_strict_determinism": "strict_determinism",
}


def parse_args(argv=None):
    """Parse the command line into an ``argparse.Namespace``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The parsed namespace, one attribute per option defined below.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default=None, help="YAML config path")
    ap.add_argument("--dataset_id", default="Johnyquest7/TN5000-thyroid-nodule-classification")
    ap.add_argument("--data_dir", default=None, help="Local TN5000 dir; else downloaded from Hub")
    ap.add_argument("--output_dir", default="run_out")
    ap.add_argument("--backbone", default="timm:resnet18.a1_in1k")
    ap.add_argument("--freeze_stage", type=int, default=0)
    ap.add_argument("--dropout", type=float, default=0.0)
    ap.add_argument("--aug_policy", default="medical_default")
    ap.add_argument("--loss", default="bce", choices=["bce", "focal"])
    ap.add_argument("--focal_gamma", type=float, default=2.0)
    ap.add_argument("--focal_alpha", type=float, default=0.5)
    ap.add_argument("--imbalance", default="pos_weight", choices=["pos_weight", "none", "sampler"])
    ap.add_argument("--optimizer", default="adamw", choices=["adamw", "sgd"])
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--scheduler", default="cosine", choices=["cosine", "plateau", "none"])
    ap.add_argument("--warmup_epochs", type=int, default=2)
    ap.add_argument("--early_stop_patience", type=int, default=8)
    ap.add_argument("--amp", action="store_true", default=True)
    ap.add_argument("--no_amp", dest="amp", action="store_false")
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--strict_determinism", action="store_true", default=True)
    ap.add_argument("--no_strict_determinism", dest="strict_determinism", action="store_false")
    ap.add_argument("--trackio_project", default="agentic_thyroid_resnet18")
    ap.add_argument("--trackio_space_id", default="Johnyquest7/Trakio_agentic_thyroid")
    ap.add_argument("--trackio_dataset_id", default="Johnyquest7/Trakio_agentic_thyroid_dataset")
    ap.add_argument("--run_name", default=None)
    ap.add_argument("--no_trackio", action="store_true")
    return ap.parse_args(argv)


def _explicit_cli_keys(argv, known_keys):
    """Return the namespace keys that were set explicitly on the command line.

    Flag names are normalised the same way for ``--flag value`` and
    ``--flag=value``. Two argparse behaviours are mirrored so that an explicit
    flag is never mistaken for "not passed":

    * negated switches (``--no_amp``) are mapped to the key they write (``amp``);
    * unique-prefix abbreviations (``--epoch``) are resolved to the full name.
    """
    candidates = set(known_keys) | set(_NEGATED_FLAGS)
    explicit = set()
    for token in argv:
        if not token.startswith("--"):
            continue
        name = token.split("=")[0].lstrip("-").replace("-", "_")
        if name not in candidates:
            # argparse has already rejected ambiguous prefixes, so a single
            # match here is the option argparse actually selected.
            matches = [c for c in candidates if c.startswith(name)]
            if len(matches) == 1:
                name = matches[0]
        explicit.add(_NEGATED_FLAGS.get(name, name))
    return explicit


def _coerce_like(current, value):
    """Coerce a config ``value`` to the numeric type of ``current``.

    YAML 1.1 loaders such as PyYAML read exponent literals without a dot
    (e.g. ``2e-4``) as strings, and integers may be given for float options.
    Only those lossless conversions are performed; booleans, ``None`` and every
    other combination are returned unchanged.
    """
    if current is None or value is None:
        return value
    if isinstance(current, bool) or isinstance(value, bool):
        return value
    if isinstance(current, float) and isinstance(value, (int, str)):
        return float(value)
    if isinstance(current, int) and isinstance(value, str):
        return int(value)
    return value


def merge_config(args, argv=None):
    """Overlay values from the YAML file at ``args.config`` onto ``args``.

    A config value is applied only when the key is an existing attribute of
    ``args`` and the corresponding option was not given explicitly on the
    command line, so CLI arguments always win. ``args`` is modified in place.

    Args:
        args: Namespace produced by :func:`parse_args`.
        argv: The argument list that produced ``args``; defaults to
            ``sys.argv[1:]``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If the YAML document is not a mapping, or a value cannot be
            converted to the numeric type of the option it targets.
    """
    if not args.config:
        return args
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        # Kept non-fatal (the run proceeds with CLI values/defaults) but no
        # longer silent, since a mistyped path would change the experiment.
        print(f"[config] WARNING: config file not found, ignoring: {cfg_path}",
              file=sys.stderr, flush=True)
        return args

    with open(cfg_path) as f:
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"config {cfg_path} must contain a YAML mapping, got {type(cfg).__name__}")

    argv = sys.argv[1:] if argv is None else argv
    passed = _explicit_cli_keys(argv, vars(args))

    ignored = []
    for key, value in cfg.items():
        if not hasattr(args, key):
            ignored.append(str(key))
            continue
        if key in passed:
            continue
        current = getattr(args, key)
        try:
            value = _coerce_like(current, value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"config key {key!r}: cannot interpret {value!r} as "
                f"{type(current).__name__}") from err
        setattr(args, key, value)
    if ignored:
        print(f"[config] WARNING: ignoring unknown config keys: {', '.join(sorted(ignored))}",
              file=sys.stderr, flush=True)
    return args


def main(argv=None):
    """Train the classifier and return the best validation AUROC.

    Resolves the configuration, builds data loaders / model / optimizer through
    the ``thyroid_lib`` helpers, runs the epoch loop with per-step LR warmup,
    optional AMP and early stopping, and writes these artifacts to
    ``args.output_dir``: ``resolved_config.json``, ``env_info.json``,
    ``command_line.txt``, ``config_used.yaml``, ``best_model.pt``,
    ``best_val_summary.json``, ``history.json`` and ``final_summary.json``.

    Args:
        argv: Optional argument list; ``None`` uses ``sys.argv[1:]``.

    Returns:
        The best validation AUROC observed (``-1.0`` if no epoch improved).
    """
    args = parse_args(argv)
    args = merge_config(args, argv)
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Heavy imports are deferred so argument errors surface quickly.
    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler
    from sklearn.metrics import roc_auc_score
    import torch.nn.functional as F

    L.set_determinism(args.seed, strict=args.strict_determinism)
    env = L.collect_env_info()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- data location -----------------------------------------------------
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                          local_dir=str(out / "_data"),
                                          allow_patterns=["Train/**", "Valid/**", "Test/**"]))

    # ---- model and datasets ------------------------------------------------
    model, pp = L.build_model(args.backbone, freeze_stage=args.freeze_stage,
                              dropout=args.dropout)
    model = model.to(device)
    train_tf = L.build_train_transform(pp, args.aug_policy)
    eval_tf = L.build_eval_transform(pp)
    train_ds = L.ThyroidImageFolder(data_dir / "Train", train_tf)
    valid_ds = L.ThyroidImageFolder(data_dir / "Valid", eval_tf)

    n_neg, n_pos = L.class_counts(train_ds.targets)
    pos_weight = (n_neg / n_pos) if (args.imbalance == "pos_weight" and n_pos) else None

    # One seeded generator drives both the sampler and the loader shuffling.
    g = torch.Generator()
    g.manual_seed(args.seed)
    pin = (device == "cuda")
    if args.imbalance == "sampler":
        # Inverse-frequency per-sample weights, drawn with replacement.
        cw = np.array([1.0 / n_neg, 1.0 / n_pos])
        sw = np.array([cw[t] for t in train_ds.targets])
        sampler = WeightedRandomSampler(torch.tensor(sw, dtype=torch.double),
                                        num_samples=len(sw), replacement=True, generator=g)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers,
                                  worker_init_fn=L.seed_worker, generator=g,
                                  pin_memory=pin, drop_last=False)
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers,
                                  worker_init_fn=L.seed_worker, generator=g,
                                  pin_memory=pin, drop_last=False)
    valid_loader = DataLoader(valid_ds, batch_size=64, shuffle=False,
                              num_workers=args.num_workers, pin_memory=pin)

    # ---- loss / optimizer / scheduler --------------------------------------
    # pos_weight is only forwarded for BCE; the focal loss receives None.
    criterion = L.build_loss(args.loss, pos_weight if args.loss == "bce" else None,
                             args.focal_gamma, args.focal_alpha).to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay, nesterov=True)

    if args.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif args.scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                                               factor=0.5, patience=3)
    else:
        scheduler = None
    use_amp = (args.amp and device == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # ---- provenance artifacts ----------------------------------------------
    run_name = args.run_name or (
        f"{args.backbone.replace('timm:','tm_').replace('.','_')}_lr{args.lr}_wd{args.weight_decay}"
        f"_bs{args.batch_size}_{args.aug_policy}_{args.loss}_{args.imbalance}_fz{args.freeze_stage}")

    resolved = vars(args).copy()
    resolved.update({"pos_weight": pos_weight, "n_train_neg": n_neg, "n_train_pos": n_pos,
                     "preprocess": pp.to_dict(), "device": device})
    L.save_json(resolved, out / "resolved_config.json")
    L.save_json(env, out / "env_info.json")
    with open(out / "command_line.txt", "w") as f:
        # shlex.join keeps arguments containing spaces/quotes copy-pasteable.
        f.write("python " + shlex.join(sys.argv) + "\n")
    with open(out / "config_used.yaml", "w") as f:
        yaml.safe_dump({k: v for k, v in resolved.items()
                        if not isinstance(v, dict) or k == "preprocess"}, f)

    # ---- experiment tracking -----------------------------------------------
    use_trackio = not args.no_trackio
    if use_trackio:
        import trackio
        try:
            from trackio.alerts import AlertLevel
            _LV = {"info": AlertLevel.INFO, "warn": AlertLevel.WARN, "error": AlertLevel.ERROR}
        except Exception:
            # Fall back to plain strings when the AlertLevel enum is unavailable.
            _LV = {"info": "info", "warn": "warn", "error": "error"}

        def _alert(title, text, level="info"):
            """Send a Trackio alert; on failure print it instead of raising."""
            try:
                trackio.alert(title, text, level=_LV.get(level, level))
            except Exception as e:
                print(f"[alert-failed] {title}: {text} ({e})", flush=True)

        trackio.init(project=args.trackio_project, name=run_name,
                     space_id=args.trackio_space_id, dataset_id=args.trackio_dataset_id,
                     config={k: v for k, v in resolved.items() if k != "preprocess"})
        _alert("Run started",
               f"{run_name} | backbone={args.backbone} loss={args.loss} "
               f"imb={args.imbalance} lr={args.lr} wd={args.weight_decay} "
               f"bs={args.batch_size} aug={args.aug_policy} fz={args.freeze_stage} "
               f"pos_weight={pos_weight} device={env.get('gpu_name')}", "info")

    # ---- training loop -----------------------------------------------------
    best_auroc = -1.0
    best_epoch = -1
    epochs_no_improve = 0
    history = []
    global_step = 0
    n_warmup_steps = args.warmup_epochs * max(1, len(train_loader))
    base_lr = args.lr

    for epoch in range(args.epochs):
        model.train()
        t0 = time.time()
        running = 0.0
        for x, y, _ in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).float()
            # Linear per-step warmup from base_lr / n_warmup_steps up to base_lr.
            if global_step < n_warmup_steps and args.warmup_epochs > 0:
                for pg in optimizer.param_groups:
                    pg["lr"] = base_lr * (global_step + 1) / n_warmup_steps
            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                out_logits = model(x).view(-1)
                loss = criterion(out_logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running += loss.item() * x.size(0)
            global_step += 1
        train_loss = running / len(train_ds)

        # Divergence is checked BEFORE validation: with a NaN model the AUROC
        # computation raises on non-finite scores, which previously made the
        # alert below unreachable in exactly the case it was written for.
        diverged = not np.isfinite(train_loss)
        if diverged or train_loss > 1e3:
            if use_trackio:
                _alert("Training diverged",
                       f"train_loss={train_loss} at epoch {epoch} — lr likely too high, try x0.1",
                       "error")
            if diverged:
                print(f"Stopping at epoch {epoch}: non-finite train_loss={train_loss}.",
                      flush=True)
                break

        # ---- validation ----------------------------------------------------
        val_logits, val_labels, _ = L.collect_logits(model, valid_loader, device, amp=args.amp)
        val_probs = L.sigmoid(val_logits)
        val_auroc = float(roc_auc_score(val_labels, val_probs))
        val_loss = float(F.binary_cross_entropy_with_logits(
            torch.tensor(val_logits), torch.tensor(val_labels, dtype=torch.float32)).item())
        m = L.point_metrics(val_labels, val_probs, 0.5)

        # LR that was in effect for this epoch (read before the scheduler steps).
        cur_lr = optimizer.param_groups[0]["lr"]
        if scheduler is not None:
            if args.scheduler == "plateau":
                scheduler.step(val_auroc)
            elif global_step >= n_warmup_steps:
                # Cosine decay only starts once warmup has finished.
                scheduler.step()

        row = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss,
               "val_auroc": val_auroc, "val_sens@0.5": m["sensitivity"],
               "val_spec@0.5": m["specificity"], "val_ppv@0.5": m["ppv"],
               "val_npv@0.5": m["npv"], "val_ece@0.5": m["ece"], "val_brier": m["brier"],
               "lr": cur_lr, "epoch_time_s": round(time.time() - t0, 1)}
        history.append(row)
        print(f"[epoch {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
              f"val_auroc={val_auroc:.4f} lr={cur_lr:.2e} ({row['epoch_time_s']}s)", flush=True)
        if use_trackio:
            trackio.log({"train_loss": train_loss, "val_loss": val_loss, "val_auroc": val_auroc,
                         "val_sensitivity": m["sensitivity"], "val_specificity": m["specificity"],
                         "val_ppv": m["ppv"], "val_npv": m["npv"], "val_ece": m["ece"],
                         "val_brier": m["brier"], "lr": cur_lr, "epoch": epoch})

        # ---- checkpointing / early stopping --------------------------------
        improved = val_auroc > best_auroc + 1e-5
        if improved:
            best_auroc = val_auroc
            best_epoch = epoch
            epochs_no_improve = 0
            torch.save({"model_state": model.state_dict(), "backbone": args.backbone,
                        "freeze_stage": args.freeze_stage, "dropout": args.dropout,
                        "preprocess": pp.to_dict(), "epoch": epoch, "val_auroc": val_auroc},
                       out / "best_model.pt")
            L.save_json({"best_epoch": best_epoch, "best_val_auroc": best_auroc,
                         "val_metrics_at_0.5": m}, out / "best_val_summary.json")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= args.early_stop_patience:
            print(f"Early stopping at epoch {epoch} (no val AUROC improvement for "
                  f"{args.early_stop_patience} epochs).", flush=True)
            if use_trackio:
                _alert("Early stopping",
                       f"No val AUROC gain for {args.early_stop_patience} epochs; "
                       f"best={best_auroc:.4f} @ epoch {best_epoch}. Consider lr x0.5.", "warn")
            break

    # ---- final artifacts ---------------------------------------------------
    L.save_json(history, out / "history.json")
    L.save_json({"best_val_auroc": best_auroc, "best_epoch": best_epoch,
                 "run_name": run_name, "backbone": args.backbone},
                out / "final_summary.json")
    if use_trackio:
        trackio.log({"best_val_auroc": best_auroc, "best_epoch": best_epoch})
        _alert("Run complete",
               f"{run_name}: best val AUROC={best_auroc:.4f} @ epoch {best_epoch}", "info")
        trackio.finish()
    print(f"DONE best_val_auroc={best_auroc:.4f} best_epoch={best_epoch}", flush=True)
    return best_auroc


if __name__ == "__main__":
    main()