# new-latern-space / eval / train_classifier.py
# Wothmag07's picture
# Drop LLM tier; deterministic CV; parser unit tests; schema-drift guard
# b91324f
"""Train a LightGBM classifier on the public-split labels and report honest
out-of-sample accuracy via GroupKFold (groups = unique (curr, prior) text pair).
Usage:
python -m eval.train_classifier # 5-fold CV, no model saved
python -m eval.train_classifier --save # 5-fold CV + train on full data + save
python -m eval.train_classifier --folds 10 # 10-fold instead of 5
The grouping by description pair is deliberate: it tells us how the model
generalises to unseen description pairs, which matches the private-split
inference setting (private split contains some description pairs we have
never trained on).
"""
from __future__ import annotations
import argparse
import json
import pickle
import sys
import time
from collections import Counter
from pathlib import Path
# ROOT is the directory one level above the one containing this file. It is
# put first on sys.path so the `app.*` import below resolves when this module
# is run directly.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
import numpy as np # noqa: E402
import lightgbm as lgb # noqa: E402
from sklearn.model_selection import GroupKFold # noqa: E402
from app.features import featurize, feature_names # noqa: E402
# Default value of the --json flag: the labelled public split under ROOT.
DEFAULT_JSON = ROOT / "relevant_priors_public.json"
# Default value of the --model-path flag: where --save pickles the final model.
DEFAULT_MODEL_PATH = ROOT / "app" / "classifier_model.pkl"
def load_pairs(json_path: Path):
    """Read the labelled split and flatten it to one record per labelled prior.

    The JSON document must contain a ``truth`` list (entries with ``case_id``,
    ``study_id`` and ``is_relevant_to_current``) and a ``cases`` list (each
    with ``case_id``, ``current_study`` and ``prior_studies``). Priors without
    a matching ``(case_id, study_id)`` truth entry are skipped.

    Returns:
        A list of dicts with keys ``case_id``, ``study_id``, ``curr_desc``,
        ``prior_desc``, ``curr_date``, ``prior_date`` and ``label``. The date
        fields are ``None`` when the study has no ``study_date``.
    """
    with json_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    # (case_id, study_id) -> relevance label
    labels_by_key = {
        (entry["case_id"], entry["study_id"]): entry["is_relevant_to_current"]
        for entry in payload["truth"]
    }

    records = []
    for case in payload["cases"]:
        current = case["current_study"]
        current_desc = current["study_description"]
        current_date = current.get("study_date")
        for prior in case["prior_studies"]:
            lookup = (case["case_id"], prior["study_id"])
            if lookup in labels_by_key:
                records.append(
                    {
                        "case_id": case["case_id"],
                        "study_id": prior["study_id"],
                        "curr_desc": current_desc,
                        "prior_desc": prior["study_description"],
                        "curr_date": current_date,
                        "prior_date": prior.get("study_date"),
                        "label": labels_by_key[lookup],
                    }
                )
    return records
def build_matrix(rows):
    """Turn the row dicts into a feature matrix and a label vector.

    Each row is passed through ``featurize`` (current/prior description and
    date); the ``values`` of the result fill one row of the matrix.

    Returns:
        ``(X, y, names)`` where ``X`` is a float32 array of shape
        ``(len(rows), len(names))``, ``y`` is an int8 array of labels, and
        ``names`` is whatever ``feature_names()`` returned.
    """
    columns = feature_names()
    n_rows = len(rows)
    features = np.zeros((n_rows, len(columns)), dtype=np.float32)
    labels = np.zeros(n_rows, dtype=np.int8)
    for idx, row in enumerate(rows):
        bundle = featurize(
            row["curr_desc"],
            row["prior_desc"],
            row["curr_date"],
            row["prior_date"],
        )
        features[idx, :] = bundle.values
        labels[idx] = int(row["label"])
    return features, labels, columns
# Seed used for every LightGBM random stream unless the caller overrides it.
DEFAULT_SEED = 42


def train_one(X_train, y_train, params=None, seed: int = DEFAULT_SEED):
    """Fit a binary LightGBM model for a fixed 300 boosting rounds.

    Args:
        X_train: training features, handed to ``lgb.Dataset``.
        y_train: training labels, handed to ``lgb.Dataset``.
        params: optional overrides applied on top of the defaults below.
        seed: value assigned to each seed parameter set here.

    Returns:
        The model object returned by ``lgb.train``.
    """
    config = {
        "objective": "binary",
        "metric": "binary_logloss",
        "learning_rate": 0.05,
        "num_leaves": 31,
        "min_data_in_leaf": 30,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.9,
        "bagging_freq": 5,
        "verbose": -1,
        # Single-threaded: avoid the small non-determinism histogram-build
        # has at high thread counts.
        "n_jobs": 1,
    }
    # Pin every seed parameter to the same value so reruns are reproducible.
    for seed_key in ("seed", "feature_fraction_seed", "bagging_seed", "data_random_seed"):
        config[seed_key] = seed
    config["deterministic"] = True

    if params:
        config.update(params)

    dataset = lgb.Dataset(X_train, label=y_train)
    return lgb.train(config, dataset, num_boost_round=300)
def main():
    """Run grouped cross-validation from the command line; optionally save a model.

    Steps: parse the CLI flags, load the labelled pairs, build the feature
    matrix, run GroupKFold CV grouped by (current, prior) description pair,
    then print out-of-fold accuracy, confusion counts, a threshold sweep and
    the top features by gain of the last fold's model. With ``--save`` the
    model is refit on all rows and pickled together with the feature names
    and the threshold.

    Raises:
        SystemExit: if the input file yields no labelled rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--json", default=str(DEFAULT_JSON))
    ap.add_argument("--folds", type=int, default=5)
    ap.add_argument("--save", action="store_true")
    ap.add_argument("--model-path", default=str(DEFAULT_MODEL_PATH))
    ap.add_argument("--threshold", type=float, default=0.5,
                    help="probability threshold for converting to bool")
    args = ap.parse_args()

    print(f"Loading {args.json} ...")
    rows = load_pairs(Path(args.json))
    print(f" rows: {len(rows)}")
    if not rows:
        # Without rows the positive-rate division below would raise a bare
        # ZeroDivisionError; stop with an actionable message instead.
        raise SystemExit(f"No labelled rows found in {args.json}; nothing to train on.")
    print(f" positive rate: {sum(r['label'] for r in rows) / len(rows):.4f}")

    print("Building feature matrix ...")
    t0 = time.perf_counter()
    X, y, names = build_matrix(rows)
    print(f" X shape: {X.shape} ({time.perf_counter() - t0:.1f}s)")

    # Group by unique (curr_desc, prior_desc) tuple — same text-pair never spans
    # folds. NOTE: Python's built-in hash() is randomized per process
    # (PYTHONHASHSEED), so it is NOT a stable group identifier across runs. Use
    # the joined string itself as the group label.
    groups = np.array([f"{r['curr_desc']}||{r['prior_desc']}" for r in rows])
    n_groups = len(set(groups))
    print(f" unique (curr,prior) text-pair groups: {n_groups}")

    print(f"\n=== {args.folds}-fold GroupKFold CV ===")
    gkf = GroupKFold(n_splits=args.folds)
    # float64 buffer: storing predictions in float32 could round a value across
    # the threshold, making the out-of-fold totals disagree with the per-fold
    # accuracies computed on the un-narrowed `proba` below.
    oof_proba = np.zeros(len(rows), dtype=np.float64)
    for fold, (tr_idx, te_idx) in enumerate(gkf.split(X, y, groups), start=1):
        t0 = time.perf_counter()
        model = train_one(X[tr_idx], y[tr_idx])
        proba = model.predict(X[te_idx])
        oof_proba[te_idx] = proba
        pred = (proba >= args.threshold).astype(int)
        acc = (pred == y[te_idx]).mean()
        print(f" fold {fold}/{args.folds} train={len(tr_idx):5d} test={len(te_idx):5d} "
              f"acc={acc:.4f} ({time.perf_counter() - t0:.1f}s)")

    # Pooled out-of-fold metrics at the requested threshold.
    pred = (oof_proba >= args.threshold).astype(int)
    correct = (pred == y).sum()
    total = len(y)
    acc = correct / total
    confusion = Counter()
    for t, p in zip(y, pred):
        confusion[(bool(t), bool(p))] += 1
    print(f"\n=== Out-of-fold totals (threshold={args.threshold}) ===")
    print(f"accuracy: {acc:.4f} ({correct}/{total})")
    print("confusion (true, pred): count")
    for (t, p), n in sorted(confusion.items()):
        print(f" ({t!s:5}, {p!s:5}) {n}")

    # Threshold sweep — useful to see if 0.5 is optimal
    print("\n=== Threshold sweep on OOF predictions ===")
    print(f"{'thr':>5} {'acc':>7} {'TP':>6} {'FP':>6} {'TN':>6} {'FN':>6}")
    for thr in [0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70]:
        p = (oof_proba >= thr).astype(int)
        a = (p == y).mean()
        tp = ((p == 1) & (y == 1)).sum()
        fp = ((p == 1) & (y == 0)).sum()
        tn = ((p == 0) & (y == 0)).sum()
        fn = ((p == 0) & (y == 1)).sum()
        print(f"{thr:>5.2f} {a:>7.4f} {tp:>6d} {fp:>6d} {tn:>6d} {fn:>6d}")

    # Top features by gain. `model` here is the one trained in the final CV
    # fold, not a model fit on the full data.
    print("\n=== Top 20 features by gain (last fold) ===")
    importance = sorted(zip(names, model.feature_importance(importance_type="gain")),
                        key=lambda x: -x[1])
    for n, imp in importance[:20]:
        print(f" {n:30s} {imp:>10.1f}")

    if args.save:
        print("\nTraining final model on full data and saving ...")
        model = train_one(X, y)
        model_path = Path(args.model_path)
        # Create the destination directory so a custom --model-path does not
        # fail only after the whole CV run has finished.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        with model_path.open("wb") as f:
            pickle.dump({"model": model, "feature_names": names, "threshold": args.threshold}, f)
        print(f" saved to {args.model_path}")
# Run the CLI only when this module is executed (e.g. `python -m
# eval.train_classifier`), not when it is imported.
if __name__ == "__main__":
    main()