"""Train a LightGBM classifier on the public-split labels and report honest out-of-sample accuracy via GroupKFold (groups = unique (curr, prior) text pair). Usage: python -m eval.train_classifier # 5-fold CV, no model saved python -m eval.train_classifier --save # 5-fold CV + train on full data + save python -m eval.train_classifier --folds 10 # 10-fold instead of 5 The grouping by description pair is deliberate: it tells us how the model generalises to unseen description pairs, which matches the private-split inference setting (private split contains some description pairs we have never trained on). """ from __future__ import annotations import argparse import json import pickle import sys import time from collections import Counter from pathlib import Path ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) import numpy as np # noqa: E402 import lightgbm as lgb # noqa: E402 from sklearn.model_selection import GroupKFold # noqa: E402 from app.features import featurize, feature_names # noqa: E402 DEFAULT_JSON = ROOT / "relevant_priors_public.json" DEFAULT_MODEL_PATH = ROOT / "app" / "classifier_model.pkl" def load_pairs(json_path: Path): with json_path.open("r", encoding="utf-8") as f: data = json.load(f) truth = {(t["case_id"], t["study_id"]): t["is_relevant_to_current"] for t in data["truth"]} rows = [] for c in data["cases"]: cd = c["current_study"]["study_description"] cdate = c["current_study"].get("study_date") for p in c["prior_studies"]: key = (c["case_id"], p["study_id"]) if key not in truth: continue pd = p["study_description"] pdate = p.get("study_date") label = truth[key] rows.append({ "case_id": c["case_id"], "study_id": p["study_id"], "curr_desc": cd, "prior_desc": pd, "curr_date": cdate, "prior_date": pdate, "label": label, }) return rows def build_matrix(rows): names = feature_names() X = np.zeros((len(rows), len(names)), dtype=np.float32) y = np.zeros(len(rows), dtype=np.int8) for i, r in enumerate(rows): fb = featurize(r["curr_desc"], r["prior_desc"], r["curr_date"], r["prior_date"]) X[i, :] = fb.values y[i] = int(r["label"]) return X, y, names DEFAULT_SEED = 42 def train_one(X_train, y_train, params=None, seed: int = DEFAULT_SEED): p = { "objective": "binary", "metric": "binary_logloss", "learning_rate": 0.05, "num_leaves": 31, "min_data_in_leaf": 30, "feature_fraction": 0.9, "bagging_fraction": 0.9, "bagging_freq": 5, "verbose": -1, "n_jobs": 1, # avoid the small non-determinism histogram-build # has at high thread counts "seed": seed, "feature_fraction_seed": seed, "bagging_seed": seed, "data_random_seed": seed, "deterministic": True, } if params: p.update(params) train_set = lgb.Dataset(X_train, label=y_train) model = lgb.train(p, train_set, num_boost_round=300) return model def main(): ap = argparse.ArgumentParser() ap.add_argument("--json", default=str(DEFAULT_JSON)) ap.add_argument("--folds", type=int, default=5) ap.add_argument("--save", action="store_true") ap.add_argument("--model-path", default=str(DEFAULT_MODEL_PATH)) ap.add_argument("--threshold", type=float, default=0.5, help="probability threshold for converting to bool") args = ap.parse_args() print(f"Loading {args.json} ...") rows = load_pairs(Path(args.json)) print(f" rows: {len(rows)}") print(f" positive rate: {sum(r['label'] for r in rows) / len(rows):.4f}") print("Building feature matrix ...") t0 = time.perf_counter() X, y, names = build_matrix(rows) print(f" X shape: {X.shape} ({time.perf_counter() - t0:.1f}s)") # Group by unique (curr_desc, prior_desc) tuple — same text-pair never spans # folds. NOTE: Python's built-in hash() is randomized per process # (PYTHONHASHSEED), so it is NOT a stable group identifier across runs. Use # the joined string itself; sklearn accepts arrays of arbitrary hashables. groups = np.array([f"{r['curr_desc']}||{r['prior_desc']}" for r in rows]) n_groups = len(set(groups)) print(f" unique (curr,prior) text-pair groups: {n_groups}") print(f"\n=== {args.folds}-fold GroupKFold CV ===") gkf = GroupKFold(n_splits=args.folds) oof_proba = np.zeros(len(rows), dtype=np.float32) for fold, (tr_idx, te_idx) in enumerate(gkf.split(X, y, groups), start=1): t0 = time.perf_counter() model = train_one(X[tr_idx], y[tr_idx]) proba = model.predict(X[te_idx]) oof_proba[te_idx] = proba pred = (proba >= args.threshold).astype(int) acc = (pred == y[te_idx]).mean() print(f" fold {fold}/{args.folds} train={len(tr_idx):5d} test={len(te_idx):5d} " f"acc={acc:.4f} ({time.perf_counter() - t0:.1f}s)") pred = (oof_proba >= args.threshold).astype(int) correct = (pred == y).sum() total = len(y) acc = correct / total confusion = Counter() for t, p in zip(y, pred): confusion[(bool(t), bool(p))] += 1 print(f"\n=== Out-of-fold totals (threshold={args.threshold}) ===") print(f"accuracy: {acc:.4f} ({correct}/{total})") print("confusion (true, pred): count") for (t, p), n in sorted(confusion.items()): print(f" ({t!s:5}, {p!s:5}) {n}") # Threshold sweep — useful to see if 0.5 is optimal print("\n=== Threshold sweep on OOF predictions ===") print(f"{'thr':>5} {'acc':>7} {'TP':>6} {'FP':>6} {'TN':>6} {'FN':>6}") for thr in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70]: p = (oof_proba >= thr).astype(int) a = (p == y).mean() tp = ((p == 1) & (y == 1)).sum() fp = ((p == 1) & (y == 0)).sum() tn = ((p == 0) & (y == 0)).sum() fn = ((p == 0) & (y == 1)).sum() print(f"{thr:>5.2f} {a:>7.4f} {tp:>6d} {fp:>6d} {tn:>6d} {fn:>6d}") # Top features by gain print("\n=== Top 20 features by gain (last fold) ===") importance = sorted(zip(names, model.feature_importance(importance_type="gain")), key=lambda x: -x[1]) for n, imp in importance[:20]: print(f" {n:30s} {imp:>10.1f}") if args.save: print(f"\nTraining final model on full data and saving ...") model = train_one(X, y) with open(args.model_path, "wb") as f: pickle.dump({"model": model, "feature_names": names, "threshold": args.threshold}, f) print(f" saved to {args.model_path}") if __name__ == "__main__": main()