simplexuq-code / scripts /run_softmax.py
anonymous0523ly's picture
Initial anonymous code release
fc329a3 verified
raw
history blame
17.3 kB
"""Exp 2.2 — Classification softmax calibration on CIFAR-10/100.
Softmax output ∈ Δ^{K-1}, one-hot label ∈ Δ^{K-1}.
Tests whether global conformal creates disparity across easy vs hard classes.
Usage:
python scripts/run_softmax.py --dataset cifar10
python scripts/run_softmax.py --dataset cifar100 --n_strata 10
"""
import argparse
import json
import logging
import numpy as np
from pathlib import Path
import time
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.utils.simplex import aitchison_dist
from src.utils.strata import (
precompute_fixed_strata,
stratify_by_boundary,
stratify_by_entropy,
)
from src.utils.seed import get_rng
from src.methods import (
full_conformal,
global_split_conformal,
jackknife_plus_conformal,
oneshot_conformal,
partition_conformal,
trainres_conformal,
twostage_conformal,
weighted_conformal,
)
from src.methods._knn_sigma import knn_sigma_hat, knn_sigma_leave_one_out
from src.metrics.coverage import (
coverage_variance,
marginal_coverage,
max_disparity,
stratified_coverage,
worst_stratum_coverage,
)
from src.metrics.sscv import size_stratified_coverage_violation
from src.metrics.setsize import mean_radius, mean_volume_ratio, volume_ratio_by_strata
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Conformal methods run when --methods is not given.  "fullcp" is also
# accepted by the CLI but must be requested explicitly.
DEFAULT_METHODS = [
    "global", "partition", "twostage", "jackknife_plus",
    "weighted", "oneshot", "trainres",
]
def get_softmax_predictions(dataset: str, model_name: str = "resnet50",
                            device: str = "cuda"):
    """Fine-tune (or load cached) a CIFAR classifier and return test-set softmax outputs.

    Results are cached at ``data/processed/{dataset}_{model_name}_softmax.npz``.
    On a cache hit, the arrays are returned directly and torch is never imported.
    On a cache miss:
      - an ImageNet-pretrained torchvision ResNet gets a fresh K-way head;
      - the whole network is fine-tuned on the training split for 5 epochs
        (Adam, lr=1e-4);
      - the model is evaluated on the test split.

    Args:
        dataset: ``"cifar10"`` or ``"cifar100"``.
        model_name: ``"resnet50"`` or ``"resnet18"``.
        device: torch device string used for training and inference.

    Returns:
        Y: one-hot labels (n, K), smoothed by 1e-8 and renormalised so no
            entry is exactly zero.
        U: softmax predictions (n, K).
        class_names: list of K class-name strings.

    Raises:
        ValueError: unknown ``dataset`` or ``model_name``.  These are checked
            on a cache miss, before any download or training.
    """
    cache_path = Path(f"data/processed/{dataset}_{model_name}_softmax.npz")
    if cache_path.exists():
        log.info(f"Loading cached predictions from {cache_path}")
        # NpzFile keeps the archive open; the context manager releases it.
        with np.load(cache_path) as data:
            return data["Y"], data["U"], [str(c) for c in data["class_names"]]

    n_classes_by_dataset = {"cifar10": 10, "cifar100": 100}
    if dataset not in n_classes_by_dataset:
        raise ValueError(f"Unknown dataset: {dataset}")
    if model_name not in ("resnet50", "resnet18"):
        raise ValueError(f"Unknown model: {model_name}")
    K = n_classes_by_dataset[dataset]
    n_epochs = 5

    # Heavy imports are deferred so the cached path does not need torch.
    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as T

    dataset_cls = {
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
    }[dataset]
    # Upsample to 224 and use ImageNet mean/std to match the pretrained weights.
    transform = T.Compose([T.Resize(224), T.ToTensor(),
                           T.Normalize([0.485, 0.456, 0.406],
                                       [0.229, 0.224, 0.225])])
    testset = dataset_cls(root="data/raw", train=False, download=True, transform=transform)
    trainset = dataset_cls(root="data/raw", train=True, download=True, transform=transform)
    class_names = [str(c) for c in testset.classes]

    log.info(f"Training/loading {model_name} on {dataset}...")
    # Pretrained backbone with a new K-way head.  The optimizer below receives
    # all parameters, so the whole network is fine-tuned, not just the head.
    model_fn = {
        "resnet50": torchvision.models.resnet50,
        "resnet18": torchvision.models.resnet18,
    }[model_name]
    model = model_fn(weights="IMAGENET1K_V1")
    model.fc = nn.Linear(model.fc.in_features, K)
    model = model.to(device)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0.0
        for images, y_batch in trainloader:
            images, y_batch = images.to(device), y_batch.to(device)
            loss = criterion(model(images), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        log.info(f" Epoch {epoch+1}/{n_epochs}, loss={total_loss/len(trainloader):.4f}")

    # Test-set softmax predictions.
    model.eval()
    testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                             shuffle=False, num_workers=4)
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for images, y_batch in testloader:
            logits = model(images.to(device))
            all_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            all_labels.append(y_batch.numpy())
    U = np.concatenate(all_probs)  # (n, K) softmax predictions
    labels = np.concatenate(all_labels)  # (n,) integer labels

    # One-hot labels are simplex vertices; a tiny smoothing avoids log(0) if an
    # Aitchison-type distance is ever applied to them.
    Y = np.zeros((len(labels), K))
    Y[np.arange(len(labels)), labels] = 1.0
    Y = Y + 1e-8
    Y = Y / Y.sum(axis=1, keepdims=True)

    acc = (np.argmax(U, axis=1) == labels).mean()
    log.info(f"Test accuracy: {acc:.4f}")

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(cache_path, Y=Y, U=U, class_names=np.array(class_names))
    log.info(f"Cached predictions to {cache_path}")
    return Y, U, class_names
def compute_weight_vectors(R_cal, U_cal, U_test, k=20):
    """Return (weights_cal, weights_test) for weighted conformal.

    Each weight is the reciprocal of a kNN sigma estimate, clipped below at
    1e-8 to avoid division by zero:
      - calibration points use ``knn_sigma_leave_one_out``;
      - test points use ``knn_sigma_hat`` fitted on the calibration set.

    Each vector is then rescaled to have mean 1.

    NOTE(review): the two vectors are normalised independently, so their
    relative scale is not preserved.  Confirm that ``weighted_conformal`` is
    insensitive to that.
    """
    def _unit_mean_inverse(sigma):
        inverse = 1.0 / np.maximum(sigma, 1e-8)
        return inverse / np.mean(inverse)

    sigma_cal = knn_sigma_leave_one_out(U_cal, R_cal, k=k)
    sigma_test = knn_sigma_hat(U_cal, R_cal, U_test, k=k)
    return _unit_mean_inverse(sigma_cal), _unit_mean_inverse(sigma_test)
def evaluate_result(
    res,
    U_test,
    strata_test,
    alpha,
    runtime_sec,
    compute_volume=False,
    volume_score="tv",
    volume_n_mc=50000,
    volume_max_points=None,
    rep=0,
):
    """Summarise one conformal run as a flat, JSON-serialisable dict.

    ``res`` must expose ``covered`` and ``radius`` for the test points.
    Every metric is cast to a Python ``float``, and per-stratum results are
    keyed by ``str(label)`` so the dict can be passed to ``json.dump``.

    When ``compute_volume`` is set, volume ratios are added.  Each volume call
    gets its own generator seeded with ``rep``, and ``score`` / ``n_mc`` /
    ``max_points`` are forwarded from the ``volume_*`` arguments.
    """
    hits = res.covered
    radii = res.radius

    summary = {}
    summary["marginal_coverage"] = float(marginal_coverage(hits))
    summary["max_disparity"] = float(max_disparity(hits, strata_test, alpha))
    summary["worst_stratum_coverage"] = float(worst_stratum_coverage(hits, strata_test))
    summary["mean_radius"] = float(mean_radius(radii))
    summary["sscv"] = float(size_stratified_coverage_violation(hits, radii, alpha))
    summary["coverage_variance"] = float(coverage_variance(hits, strata_test))
    summary["runtime_sec"] = float(runtime_sec)
    per_stratum = stratified_coverage(hits, strata_test)
    summary["stratified_coverage"] = {
        str(label): float(cov) for label, cov in per_stratum.items()
    }

    if not compute_volume:
        return summary

    mc_options = {
        "score": volume_score,
        "n_mc": volume_n_mc,
        "max_points": volume_max_points,
    }
    summary["mean_volume_ratio"] = float(
        mean_volume_ratio(U_test, radii, rng=np.random.default_rng(rep), **mc_options)
    )
    by_stratum = volume_ratio_by_strata(
        U_test, radii, strata_test, rng=np.random.default_rng(rep), **mc_options
    )
    summary["volume_ratio_by_strata"] = {
        str(label): float(vol) for label, vol in by_stratum.items()
    }
    return summary
def run_experiment(
    Y,
    U,
    alpha,
    n_rep,
    cal_frac,
    n_strata,
    rng,
    methods,
    compute_volume=False,
    volume_score="tv",
    volume_n_mc=50000,
    volume_max_points=None,
    strata_method="entropy",
    fixed_strata=True,
    strata_seed=2026,
):
    """Run each conformal method over ``n_rep`` random calibration/test splits.

    The score is the total-variation distance between the label ``Y`` and the
    softmax prediction ``U``.  It is computed once for all n points and
    re-split every repetition.

    Args:
        Y, U: (n, K) label and softmax-prediction arrays.
        alpha: miscoverage level passed to every method and metric.
        n_rep: number of random splits.
        cal_frac: fraction of points used for calibration; must be in (0, 1).
        n_strata: number of strata for ``partition`` and the metrics.
        rng: generator driving the splits (and the ``trainres`` draws).
        methods: method names to run.  Unknown names are skipped with a
            warning and keep an empty result list.
        compute_volume, volume_score, volume_n_mc, volume_max_points:
            forwarded to ``evaluate_result``.
        strata_method: stratification rule.  When ``fixed_strata`` is False,
            only ``"boundary"`` and ``"entropy"`` are supported.
        fixed_strata: if True, assign strata once on all of ``U``; otherwise
            fit them separately on each calibration/test half.
        strata_seed: seed forwarded to ``precompute_fixed_strata``.

    Returns:
        dict mapping each name in ``methods`` to a list with one
        ``evaluate_result`` dict per repetition.

    Raises:
        ValueError: if ``cal_frac`` is outside (0, 1), or if the non-fixed
            ``strata_method`` is unsupported.
    """
    known_methods = {
        "global", "partition", "twostage", "jackknife_plus",
        "weighted", "oneshot", "trainres", "fullcp",
    }
    unknown = [m for m in methods if m not in known_methods]
    if unknown:
        log.warning(f"Skipping unknown methods: {unknown}")
    if not 0.0 < cal_frac < 1.0:
        raise ValueError(f"cal_frac must be in (0, 1), got {cal_frac}")

    # L1/TV distance instead of Aitchison: one-hot labels sit at simplex
    # vertices, where Aitchison is ill-defined.
    R = np.sum(np.abs(Y - U), axis=1) / 2.0  # total variation distance
    n = len(R)
    n_cal = int(n * cal_frac)
    all_results = {m: [] for m in methods}

    fixed_labels = None
    if fixed_strata:
        fixed_labels = precompute_fixed_strata(U, strata_method, n_strata, seed=strata_seed)
    elif strata_method not in {"boundary", "entropy"}:
        raise ValueError("Non-fixed softmax strata must be 'boundary' or 'entropy'.")

    # kNN weights are only consumed by "weighted"; skip that per-rep work otherwise.
    need_weights = "weighted" in methods

    for rep in range(n_rep):
        perm = rng.permutation(n)
        idx_cal, idx_test = perm[:n_cal], perm[n_cal:]
        R_cal, R_test = R[idx_cal], R[idx_test]
        U_cal, U_test = U[idx_cal], U[idx_test]

        if fixed_labels is not None:
            strata_cal = fixed_labels[idx_cal]
            strata_test = fixed_labels[idx_test]
        else:
            # Diagnostic mode: strata are fitted separately on each half.
            strata_fn = stratify_by_boundary if strata_method == "boundary" else stratify_by_entropy
            strata_cal = strata_fn(U_cal, n_strata)
            strata_test = strata_fn(U_test, n_strata)

        weights_cal = weights_test = None
        if need_weights:
            weights_cal, weights_test = compute_weight_vectors(R_cal, U_cal, U_test)

        for m in methods:
            start = time.perf_counter()
            if m == "global":
                res = global_split_conformal(R_cal, R_test, alpha)
            elif m == "partition":
                res = partition_conformal(R_cal, R_test, alpha,
                                          strata_cal, strata_test)
            elif m == "twostage":
                res = twostage_conformal(R_cal, R_test, alpha,
                                         U_cal, U_test)
            elif m == "jackknife_plus":
                res = jackknife_plus_conformal(R_cal, R_test, alpha, U_cal=U_cal, U_test=U_test)
            elif m == "weighted":
                res = weighted_conformal(R_cal, R_test, alpha, weights_cal, weights_test)
            elif m == "oneshot":
                res = oneshot_conformal(R_cal, R_test, alpha, U_cal, U_test)
            elif m == "trainres":
                # NOTE(review): the "training" residuals are drawn from all n
                # points, so they overlap this rep's calibration and test
                # halves.  The draw also advances `rng`, so including or
                # excluding "trainres" changes every later split.  This is kept
                # as-is so existing results stay reproducible; confirm intent.
                train_perm = rng.permutation(n)
                idx_train = train_perm[:n_cal]
                res = trainres_conformal(
                    R_cal, R_test, alpha, U_cal, U_test, R[idx_train], U[idx_train]
                )
            elif m == "fullcp":
                res = full_conformal(R_cal, R_test, alpha, U_cal, U_test)
            else:
                continue  # unknown name, already warned about above
            runtime_sec = time.perf_counter() - start
            all_results[m].append(
                evaluate_result(
                    res,
                    U_test,
                    strata_test,
                    alpha,
                    runtime_sec,
                    compute_volume=compute_volume,
                    volume_score=volume_score,
                    volume_n_mc=volume_n_mc,
                    volume_max_points=volume_max_points,
                    rep=rep,
                )
            )

        if (rep + 1) % 50 == 0:
            log.info(f" Rep {rep + 1}/{n_rep}")
    return all_results
def maybe_subsample(Y, U, max_samples, rng):
    """Optionally keep a random subset of rows of the paired arrays ``Y`` and ``U``.

    If ``max_samples`` is None or not smaller than ``len(Y)``, the inputs are
    returned unchanged (the same objects).  Otherwise ``max_samples`` distinct
    row indices are drawn with ``rng.choice`` (without replacement), and the
    same rows are taken from both arrays.
    """
    if max_samples is not None and max_samples < len(Y):
        picked = rng.choice(len(Y), size=max_samples, replace=False)
        Y, U = Y[picked], U[picked]
    return Y, U
# Per-repetition scalar metrics aggregated into mean/std in the summary.
# "mean_volume_ratio" is only present when --compute-volume is set.
_SCALAR_KEYS = (
    "marginal_coverage",
    "max_disparity",
    "worst_stratum_coverage",
    "mean_radius",
    "sscv",
    "coverage_variance",
    "runtime_sec",
    "mean_volume_ratio",
)


def _build_arg_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", choices=["cifar10", "cifar100"])
    parser.add_argument("--model", default="resnet18")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--n_rep", type=int, default=200)
    parser.add_argument("--cal_frac", type=float, default=0.4)
    parser.add_argument("--n_strata", type=int, default=5)
    parser.add_argument(
        "--strata",
        choices=["entropy", "boundary", "dominant", "kmeans", "random"],
        default="entropy",
    )
    parser.add_argument("--fixed-strata", dest="fixed_strata", action="store_true")
    parser.add_argument(
        "--separate-strata",
        dest="fixed_strata",
        action="store_false",
        help="Diagnostic only: fit calibration/test strata separately.",
    )
    parser.set_defaults(fixed_strata=True)
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--compute-volume", action="store_true")
    parser.add_argument("--volume-score", choices=["tv", "aitchison"], default="tv")
    parser.add_argument("--volume-n-mc", type=int, default=50000)
    parser.add_argument("--volume-max-points", type=int, default=None)
    parser.add_argument(
        "--methods",
        nargs="+",
        default=DEFAULT_METHODS,
        choices=DEFAULT_METHODS + ["fullcp"],
    )
    parser.add_argument("--tag", default=None)
    parser.add_argument("--seed", type=int, default=2026)
    parser.add_argument("--output-dir", default="results")
    return parser


def _log_per_class_residuals(R, Y, class_names, max_classes=10):
    """Log residual mean/std for the first ``max_classes`` true classes."""
    true_labels = np.argmax(Y, axis=1)
    for k in range(min(Y.shape[1], max_classes)):
        mask = true_labels == k
        if not mask.any():
            # Can happen after --max_samples subsampling; avoids mean-of-empty-slice.
            log.info(f" {class_names[k]:12s}: n=0")
            continue
        log.info(f" {class_names[k]:12s}: n={mask.sum()}, "
                 f"R_mean={R[mask].mean():.4f}, R_std={R[mask].std():.4f}")


def _aggregate_per_stratum(reps, field):
    """Mean/std/count of ``rep[field][key]`` over reps, for every stratum key seen."""
    keys = set()
    for r in reps:
        keys.update(r[field].keys())
    out = {}
    for key in sorted(keys, key=int):
        # Some strata may be absent in a given split; average only where present.
        vals = [r[field][key] for r in reps if key in r[field]]
        out[key] = {
            "mean": float(np.mean(vals)),
            "std": float(np.std(vals)),
            "n_reps": len(vals),
        }
    return out


def _summarize_method(reps):
    """Aggregate the per-repetition result dicts of one method (``reps`` is non-empty)."""
    s = {}
    for key in _SCALAR_KEYS:
        if key in reps[0]:
            vals = [r[key] for r in reps]
            s[key] = {"mean": float(np.mean(vals)), "std": float(np.std(vals))}
    s["stratified_coverage"] = _aggregate_per_stratum(reps, "stratified_coverage")
    if "volume_ratio_by_strata" in reps[0]:
        s["volume_ratio_by_strata"] = _aggregate_per_stratum(reps, "volume_ratio_by_strata")
    return s


def main():
    """CLI entry point.

    Loads predictions, runs all repetitions, logs a summary, and writes a
    JSON file to ``<output-dir>/tables/``.
    """
    args = _build_arg_parser().parse_args()
    rng = get_rng(args.seed)
    # Drop duplicate --methods entries (order kept).  Otherwise a duplicated
    # method would run twice per rep and both runs would land in one list.
    methods = list(dict.fromkeys(args.methods))

    Y, U, class_names = get_softmax_predictions(args.dataset, args.model, args.device)
    Y, U = maybe_subsample(Y, U, args.max_samples, rng)
    K = Y.shape[1]
    log.info(f"Dataset: {args.dataset}, K={K}, n={len(Y)}")

    # Residual diagnostics (same TV residual as run_experiment).
    R = np.sum(np.abs(Y - U), axis=1) / 2.0
    log.info(f"Residuals: mean={R.mean():.4f}, std={R.std():.4f}")
    _log_per_class_residuals(R, Y, class_names)

    all_results = run_experiment(
        Y,
        U,
        args.alpha,
        args.n_rep,
        args.cal_frac,
        args.n_strata,
        rng,
        methods,
        compute_volume=args.compute_volume,
        volume_score=args.volume_score,
        volume_n_mc=args.volume_n_mc,
        volume_max_points=args.volume_max_points,
        strata_method=args.strata,
        fixed_strata=args.fixed_strata,
        strata_seed=args.seed,
    )

    log.info("\n" + "=" * 60)
    log.info(f"RESULTS — Softmax calibration ({args.dataset})")
    log.info("=" * 60)
    summary = {}
    for m in methods:
        reps = all_results[m]
        if not reps:
            continue
        s = _summarize_method(reps)
        summary[m] = s
        log.info(
            f" {m:12s} cov={s['marginal_coverage']['mean']:.3f}±{s['marginal_coverage']['std']:.3f} "
            f"disp={s['max_disparity']['mean']:.3f}±{s['max_disparity']['std']:.3f} "
            f"worst={s['worst_stratum_coverage']['mean']:.3f} "
            f"sscv={s['sscv']['mean']:.3f}"
        )

    out_dir = Path(args.output_dir) / "tables"
    out_dir.mkdir(parents=True, exist_ok=True)
    suffix = f"_{args.tag}" if args.tag else ""
    out_file = out_dir / f"exp2_2_softmax_{args.dataset}{suffix}.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(dict(summary=summary, dataset=args.dataset, K=K,
                       class_names=[str(c) for c in class_names],
                       config=vars(args), raw=all_results),
                  f, indent=2)
    log.info(f"Saved to {out_file}")


if __name__ == "__main__":
    main()