Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
14 kB
#!/usr/bin/env python
"""
Reproducible training for the ResNet-18 thyroid ultrasound malignancy classifier.
Trains on Train only; selects the best checkpoint by **validation AUROC**.
Logs everything to Trackio (losses, val AUROC/sens/spec/PPV/NPV/ECE/Brier, LR,
epoch, hyperparameters, env info) and emits trackio.alert() at decision points.
Single-command reproduction:
python train.py --config configs/final_config.yaml
All CLI args override config values. The exact command line, resolved config,
seed, package versions and hardware info are saved to the output dir.
Dataset is loaded directly from the Train/Valid/Test FOLDER structure of the
Hub repo (NOT the flattened datasets-viewer 'train' split), so the predefined
splits are respected.
"""
import argparse
import os
import sys
import time
from pathlib import Path
import numpy as np
import yaml
import thyroid_lib as L
def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Every option has a default, so the script runs with no flags at all;
    values from a YAML ``--config`` file are layered on afterwards by
    ``merge_config``. Two boolean switches come in on/off pairs that share a
    single destination defaulting to True: ``--amp``/``--no_amp`` and
    ``--strict_determinism``/``--no_strict_determinism``.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Config file, data source and output location.
    add("--config", default=None, help="YAML config path")
    add("--dataset_id", default="Johnyquest7/TN5000-thyroid-nodule-classification")
    add("--data_dir", default=None, help="Local TN5000 dir; else downloaded from Hub")
    add("--output_dir", default="run_out")

    # Model and augmentation.
    add("--backbone", default="timm:resnet18.a1_in1k")
    add("--freeze_stage", type=int, default=0)
    add("--dropout", type=float, default=0.0)
    add("--aug_policy", default="medical_default")

    # Loss function and class-imbalance strategy.
    add("--loss", default="bce", choices=["bce", "focal"])
    add("--focal_gamma", type=float, default=2.0)
    add("--focal_alpha", type=float, default=0.5)
    add("--imbalance", default="pos_weight", choices=["pos_weight", "none", "sampler"])

    # Optimizer, schedule and stopping.
    add("--optimizer", default="adamw", choices=["adamw", "sgd"])
    add("--lr", type=float, default=2e-4)
    add("--weight_decay", type=float, default=1e-4)
    add("--batch_size", type=int, default=32)
    add("--epochs", type=int, default=40)
    add("--scheduler", default="cosine", choices=["cosine", "plateau", "none"])
    add("--warmup_epochs", type=int, default=2)
    add("--early_stop_patience", type=int, default=8)
    add("--amp", action="store_true", default=True)
    add("--no_amp", dest="amp", action="store_false")

    # Runtime and reproducibility.
    add("--num_workers", type=int, default=4)
    add("--seed", type=int, default=42)
    add("--strict_determinism", action="store_true", default=True)
    add("--no_strict_determinism", dest="strict_determinism", action="store_false")

    # Experiment tracking.
    add("--trackio_project", default="agentic_thyroid_resnet18")
    add("--trackio_space_id", default="Johnyquest7/Trakio_agentic_thyroid")
    add("--trackio_dataset_id", default="Johnyquest7/Trakio_agentic_thyroid_dataset")
    add("--run_name", default=None)
    add("--no_trackio", action="store_true")

    return parser.parse_args()
def merge_config(args):
    """Overlay values from the YAML file named by ``args.config`` onto ``args``.

    Precedence: explicit command-line flag > config file > argparse default.
    Only keys that already exist as attributes of ``args`` are applied; any
    other key is reported on stderr and ignored.

    Details:
      * A negated flag such as ``--no_amp`` (dest ``amp``) also protects its
        positive destination, so the CLI still wins over the config file.
      * Config values bypass argparse's ``type=`` conversion. PyYAML (YAML 1.1)
        loads exponent floats without a dot, e.g. ``2e-4``, as *strings*, so a
        string value whose argparse default is numeric is converted to that
        default's type.
      * A ``--config`` path that does not exist is reported on stderr instead
        of being silently ignored; training then uses CLI values and defaults.

    Args:
        args: Namespace returned by ``parse_args()``; updated in place.

    Returns:
        The same namespace object.

    Raises:
        ValueError: if the file is not a top-level mapping, or a string value
            cannot be converted to the type of the corresponding default.
    """
    if not args.config:
        return args
    cfg_path = Path(args.config)
    if not cfg_path.exists():
        print(f"[config] WARNING: config file {cfg_path} not found; "
              "using command-line values and defaults only.",
              file=sys.stderr, flush=True)
        return args

    with open(cfg_path) as f:
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(f"config file {cfg_path} must contain a mapping at top level")

    # Destinations explicitly set on the command line (``--key value`` or
    # ``--key=value``). ``--no_x`` also marks ``x`` because the negated
    # switches store into the positive destination.
    passed = set()
    for token in sys.argv[1:]:
        if not token.startswith("--"):
            continue
        name = token.split("=")[0].lstrip("-").replace("-", "_")
        passed.add(name)
        if name.startswith("no_"):
            passed.add(name[3:])

    for k, v in cfg.items():
        if k in passed:
            continue
        if not hasattr(args, k):
            print(f"[config] WARNING: unknown key {k!r} in {cfg_path} ignored.",
                  file=sys.stderr, flush=True)
            continue
        current = getattr(args, k)
        # Re-apply numeric typing lost by bypassing argparse (bool is an int
        # subclass, so it is excluded explicitly).
        if isinstance(v, str) and isinstance(current, (int, float)) and not isinstance(current, bool):
            try:
                v = type(current)(v)
            except ValueError as err:
                raise ValueError(f"config key {k!r}: cannot convert {v!r} to "
                                 f"{type(current).__name__}") from err
        setattr(args, k, v)
    return args
def main():
    """Train the classifier end to end and return the best validation AUROC.

    Steps: parse CLI + YAML config, seed, locate the data (downloading the
    Train/Valid/Test folders from the Hub when ``--data_dir`` is not given),
    build model / loss / optimizer / scheduler, and write the resolved run
    configuration, environment info and command line to ``--output_dir``.
    It then trains for up to ``--epochs`` epochs.

    After every epoch the model is evaluated on Valid. Whenever val AUROC
    beats the best so far by more than 1e-5, the weights go to
    ``best_model.pt`` and the 0.5-threshold metrics to
    ``best_val_summary.json``. Training stops early after
    ``--early_stop_patience`` epochs without improvement. ``history.json``
    and ``final_summary.json`` are written at the end.

    Reproducibility note: keep the order of statements that may consume RNG
    state (e.g. model construction, DataLoader creation and iteration)
    unchanged relative to ``L.set_determinism``, or runs will not reproduce.

    Returns:
        The best validation AUROC observed (``-1.0`` if no epoch ran).
    """
    args = parse_args()
    args = merge_config(args)
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Heavy dependencies are imported only after argument parsing.
    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler
    from sklearn.metrics import roc_auc_score
    import torch.nn.functional as F

    L.set_determinism(args.seed, strict=args.strict_determinism)
    env = L.collect_env_info()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Data location ------------------------------------------------------
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        # Fetch only the split folders so the predefined Train/Valid/Test
        # split is respected.
        data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                          local_dir=str(out / "_data"),
                                          allow_patterns=["Train/**", "Valid/**", "Test/**"]))

    # ---- Model, transforms, datasets ----------------------------------------
    model, pp = L.build_model(args.backbone, freeze_stage=args.freeze_stage, dropout=args.dropout)
    model = model.to(device)
    train_tf = L.build_train_transform(pp, args.aug_policy)
    eval_tf = L.build_eval_transform(pp)
    train_ds = L.ThyroidImageFolder(data_dir / "Train", train_tf)
    valid_ds = L.ThyroidImageFolder(data_dir / "Valid", eval_tf)

    # ---- Class imbalance ----------------------------------------------------
    n_neg, n_pos = L.class_counts(train_ds.targets)
    # BCE positive-class weight #neg/#pos, only for --imbalance pos_weight
    # (None when there are no positives).
    pos_weight = (n_neg / n_pos) if (args.imbalance == "pos_weight" and n_pos) else None

    # Seeded generator so shuffling / weighted sampling order is reproducible.
    g = torch.Generator()
    g.manual_seed(args.seed)
    if args.imbalance == "sampler":
        # Per-sample weight 1/count(class): each class carries equal total
        # weight, so both classes are drawn about equally often.
        cw = np.array([1.0 / n_neg, 1.0 / n_pos])
        sw = np.array([cw[t] for t in train_ds.targets])
        sampler = WeightedRandomSampler(torch.tensor(sw, dtype=torch.double),
                                        num_samples=len(sw), replacement=True, generator=g)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers, worker_init_fn=L.seed_worker,
                                  generator=g, pin_memory=(device == "cuda"), drop_last=False)
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, worker_init_fn=L.seed_worker,
                                  generator=g, pin_memory=(device == "cuda"), drop_last=False)
    valid_loader = DataLoader(valid_ds, batch_size=64, shuffle=False,
                              num_workers=args.num_workers, pin_memory=(device == "cuda"))

    # ---- Loss / optimizer / scheduler ---------------------------------------
    # pos_weight is only forwarded for BCE; focal loss receives None.
    criterion = L.build_loss(args.loss, pos_weight if args.loss == "bce" else None,
                             args.focal_gamma, args.focal_alpha).to(device)
    # Optimize only parameters with requires_grad=True.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay, nesterov=True)
    # NOTE(review): the cosine schedule is stepped only once warmup has ended
    # (see the epoch loop). It therefore takes fewer than T_max steps and does
    # not reach its minimum when warmup_epochs > 1. It is left unchanged so
    # existing runs stay reproducible.
    if args.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif args.scheduler == "plateau":
        # mode="max" because the scheduler is fed validation AUROC.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                                               factor=0.5, patience=3)
    else:
        scheduler = None
    # Gradient scaling is disabled unless AMP is requested and running on CUDA.
    scaler = torch.amp.GradScaler("cuda", enabled=(args.amp and device == "cuda"))

    # ---- Run metadata / reproducibility artifacts ---------------------------
    run_name = args.run_name or (
        f"{args.backbone.replace('timm:','tm_').replace('.','_')}_lr{args.lr}_wd{args.weight_decay}"
        f"_bs{args.batch_size}_{args.aug_policy}_{args.loss}_{args.imbalance}_fz{args.freeze_stage}")
    resolved = vars(args).copy()
    resolved.update({"pos_weight": pos_weight, "n_train_neg": n_neg, "n_train_pos": n_pos,
                     "preprocess": pp.to_dict(), "device": device})
    L.save_json(resolved, out / "resolved_config.json")
    L.save_json(env, out / "env_info.json")
    with open(out / "command_line.txt", "w") as f:
        f.write("python " + " ".join(sys.argv) + "\n")
    with open(out / "config_used.yaml", "w") as f:
        # Nested dicts are dropped except the preprocessing description.
        yaml.safe_dump({k: v for k, v in resolved.items() if not isinstance(v, dict) or k == "preprocess"}, f)

    # ---- Experiment tracking ------------------------------------------------
    use_trackio = not args.no_trackio
    if use_trackio:
        import trackio
        try:
            from trackio.alerts import AlertLevel
            _LV = {"info": AlertLevel.INFO, "warn": AlertLevel.WARN, "error": AlertLevel.ERROR}
        except Exception:
            # Fall back to plain level strings when AlertLevel is unavailable.
            _LV = {"info": "info", "warn": "warn", "error": "error"}

        def _alert(title, text, level="info"):
            # Best effort: a failed alert is printed, never fatal.
            try:
                trackio.alert(title, text, level=_LV.get(level, level))
            except Exception as e:
                print(f"[alert-failed] {title}: {text} ({e})", flush=True)

        trackio.init(project=args.trackio_project, name=run_name,
                     space_id=args.trackio_space_id,
                     dataset_id=args.trackio_dataset_id,
                     config={k: v for k, v in resolved.items() if k != "preprocess"})
        _alert("Run started",
               f"{run_name} | backbone={args.backbone} loss={args.loss} "
               f"imb={args.imbalance} lr={args.lr} wd={args.weight_decay} "
               f"bs={args.batch_size} aug={args.aug_policy} fz={args.freeze_stage} "
               f"pos_weight={pos_weight} device={env.get('gpu_name')}",
               "info")

    # ---- Training loop ------------------------------------------------------
    best_auroc = -1.0
    best_epoch = -1
    epochs_no_improve = 0
    history = []
    global_step = 0
    # Linear per-step warmup from ~0 to base_lr over the first warmup_epochs.
    n_warmup_steps = args.warmup_epochs * max(1, len(train_loader))
    base_lr = args.lr
    try:
        for epoch in range(args.epochs):
            model.train()
            t0 = time.time()
            running = 0.0
            for x, y, _ in train_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True).float()
                if global_step < n_warmup_steps and args.warmup_epochs > 0:
                    for pg in optimizer.param_groups:
                        pg["lr"] = base_lr * (global_step + 1) / n_warmup_steps
                optimizer.zero_grad(set_to_none=True)
                with torch.autocast(device_type="cuda", dtype=torch.float16,
                                    enabled=(args.amp and device == "cuda")):
                    out_logits = model(x).view(-1)
                    loss = criterion(out_logits, y)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                running += loss.item() * x.size(0)
                global_step += 1
            train_loss = running / len(train_ds)

            # Report divergence *before* validation. If the logits have gone
            # NaN, roc_auc_score below raises on the non-finite scores, so a
            # check placed after validation would never fire.
            if not np.isfinite(train_loss) or train_loss > 1e3:
                diverged_msg = f"train_loss={train_loss} at epoch {epoch} — lr likely too high, try x0.1"
                print(f"[diverged] {diverged_msg}", flush=True)
                if use_trackio:
                    _alert("Training diverged", diverged_msg, "error")

            val_logits, val_labels, _ = L.collect_logits(model, valid_loader, device, amp=args.amp)
            val_probs = L.sigmoid(val_logits)
            val_auroc = float(roc_auc_score(val_labels, val_probs))
            # Unweighted BCE on validation (no pos_weight), for monitoring only.
            val_loss = float(F.binary_cross_entropy_with_logits(
                torch.tensor(val_logits), torch.tensor(val_labels, dtype=torch.float32)).item())
            m = L.point_metrics(val_labels, val_probs, 0.5)
            # LR as used during this epoch (read before the scheduler steps).
            cur_lr = optimizer.param_groups[0]["lr"]
            if scheduler is not None:
                if args.scheduler == "plateau":
                    scheduler.step(val_auroc)
                elif global_step >= n_warmup_steps:
                    scheduler.step()

            row = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss,
                   "val_auroc": val_auroc, "val_sens@0.5": m["sensitivity"],
                   "val_spec@0.5": m["specificity"], "val_ppv@0.5": m["ppv"],
                   "val_npv@0.5": m["npv"], "val_ece@0.5": m["ece"], "val_brier": m["brier"],
                   "lr": cur_lr, "epoch_time_s": round(time.time() - t0, 1)}
            history.append(row)
            print(f"[epoch {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
                  f"val_auroc={val_auroc:.4f} lr={cur_lr:.2e} ({row['epoch_time_s']}s)", flush=True)
            if use_trackio:
                trackio.log({"train_loss": train_loss, "val_loss": val_loss, "val_auroc": val_auroc,
                             "val_sensitivity": m["sensitivity"], "val_specificity": m["specificity"],
                             "val_ppv": m["ppv"], "val_npv": m["npv"], "val_ece": m["ece"],
                             "val_brier": m["brier"], "lr": cur_lr, "epoch": epoch})

            # Model selection: checkpoint only on a strict val-AUROC gain.
            improved = val_auroc > best_auroc + 1e-5
            if improved:
                best_auroc = val_auroc
                best_epoch = epoch
                epochs_no_improve = 0
                torch.save({"model_state": model.state_dict(),
                            "backbone": args.backbone, "freeze_stage": args.freeze_stage,
                            "dropout": args.dropout, "preprocess": pp.to_dict(),
                            "epoch": epoch, "val_auroc": val_auroc},
                           out / "best_model.pt")
                L.save_json({"best_epoch": best_epoch, "best_val_auroc": best_auroc,
                             "val_metrics_at_0.5": m}, out / "best_val_summary.json")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= args.early_stop_patience:
                print(f"Early stopping at epoch {epoch} (no val AUROC improvement for "
                      f"{args.early_stop_patience} epochs).", flush=True)
                if use_trackio:
                    _alert("Early stopping",
                           f"No val AUROC gain for {args.early_stop_patience} epochs; "
                           f"best={best_auroc:.4f} @ epoch {best_epoch}. Consider lr x0.5.",
                           "warn")
                break

        # ---- Final artifacts ------------------------------------------------
        L.save_json(history, out / "history.json")
        L.save_json({"best_val_auroc": best_auroc, "best_epoch": best_epoch,
                     "run_name": run_name, "backbone": args.backbone},
                    out / "final_summary.json")
        if use_trackio:
            trackio.log({"best_val_auroc": best_auroc, "best_epoch": best_epoch})
            _alert("Run complete",
                   f"{run_name}: best val AUROC={best_auroc:.4f} @ epoch {best_epoch}",
                   "info")
    finally:
        # Close the tracking run even if training or validation raised, so
        # metrics logged so far are not left in an unfinished run.
        if use_trackio:
            trackio.finish()

    print(f"DONE best_val_auroc={best_auroc:.4f} best_epoch={best_epoch}", flush=True)
    return best_auroc
if __name__ == "__main__":
    # Script entry point. main() returns the best validation AUROC for
    # programmatic callers; when run as a script the value is discarded and
    # the exit status is left untouched.
    _ = main()