dpyd-classifier / src /train.py
abhimanyu12's picture
Upload folder using huggingface_hub
2ea06dc verified
Raw
History Blame Contribute Delete
7.19 kB
"""Train the DPYD baseline classifier — RandomForest, XGBoost, LightGBM.
BASELINE: AF + CPIC/ClinVar categorical features only (no SIFT/PolyPhen/CADD).
5-fold stratified CV. Reports accuracy, per-class F1, and macro one-vs-rest
AUC-ROC. Persists fitted models (refit on full data) + cv_metrics.json.
HONESTY GUARD: the CPIC ground-truth set for DPYD is tiny (~13 alleles). With
3 classes and AF+categorical features only, CV metrics on this set are
indicative, not validating. The guard in main() refuses to emit metrics silently
if any class has fewer than MIN_PER_CLASS examples: it prints a loud caveat,
caps n_splits at the smallest class count (never below 2), and exits with
status 2 unless --force is given. This mirrors the platform invariant:
"refuse honestly when evidence is thin."
Run: python -m src.train --train data/training_data.csv --outdir . --upload <bucket>
DO NOT run as part of scaffolding — invoked explicitly at training time.
"""
from __future__ import annotations
import argparse, json, sys
from collections import Counter
import numpy as np
import pandas as pd
# Label vocabulary; rows whose label_class is not listed here are dropped
# before training.
CLASSES = [
    "normal_function",
    "decreased_function",
    "no_function",
]
MIN_PER_CLASS = 5  # below this, CV is noise; we say so loudly.
# Columns coerced to numbers (unparseable / missing -> 0.0).
NUMERIC = [
    "gnomad_global_af",
    "gnomad_sas_af",
    "log10_gnomad_global_af",
    "log10_gnomad_sas_af",
    "in_gnomad",
    "sas_enriched",
    "is_indel",
]
# Columns one-hot encoded with pd.get_dummies.
CATEGORICAL = [
    "consequence",
    "clnsig_norm",
]
def _xy(df: pd.DataFrame):
    """Turn the raw training frame into ``(X, y, feature_names)``.

    Only rows whose ``label_class`` is one of CLASSES are kept.

    * NUMERIC columns are coerced with ``pd.to_numeric(errors="coerce")``;
      anything unparseable or missing becomes 0.0.
    * CATEGORICAL columns are stringified and one-hot encoded with
      ``pd.get_dummies`` (column prefix = source column name), so a missing
      value becomes a string category of its own.
    * ``y`` is ``label_class`` as a pandas categorical. Its categories are
      inferred from the kept rows, so their order need not match CLASSES;
      models are trained on ``y.cat.codes``.
    """
    labelled = df.loc[df["label_class"].isin(CLASSES)].copy()
    numeric_block = labelled[NUMERIC].apply(pd.to_numeric, errors="coerce")
    numeric_block = numeric_block.fillna(0.0)
    onehot_block = pd.get_dummies(labelled[CATEGORICAL].astype(str),
                                  prefix=CATEGORICAL)
    features = pd.concat([numeric_block, onehot_block], axis=1)
    target = labelled["label_class"].astype("category")
    return features, target, features.columns.tolist()
def _models():
    """Return fresh, unfitted estimators keyed by short name.

    ``"rf"`` (RandomForest) is always present. ``"xgb"`` and ``"lgbm"`` are
    added only if their packages import; otherwise a WARN line is printed
    and that model is skipped.
    """
    from sklearn.ensemble import RandomForestClassifier
    models = {"rf": RandomForestClassifier(n_estimators=400, random_state=42,
                                           class_weight="balanced", n_jobs=-1)}
    try:
        from xgboost import XGBClassifier
        # NOTE(review): XGBClassifier takes no class_weight, so unlike rf/lgbm
        # this model is trained unweighted. Pass balanced sample_weight at fit
        # time if parity across models matters.
        models["xgb"] = XGBClassifier(n_estimators=400, max_depth=4, learning_rate=0.05,
                                      subsample=0.9, colsample_bytree=0.9,
                                      objective="multi:softprob", random_state=42,
                                      eval_metric="mlogloss", tree_method="hist")
    except ImportError:
        print("WARN: xgboost not installed — skipping xgb")
    try:
        from lightgbm import LGBMClassifier
        # LightGBM only bags rows when subsample_freq > 0; without it the
        # requested subsample=0.9 was silently ignored.
        # NOTE(review): LightGBM's default min_child_samples=20 means no split
        # is possible when a training fold has fewer than ~40 rows, so on a
        # very small set this model may degenerate to class priors — confirm
        # against the actual training-set size.
        models["lgbm"] = LGBMClassifier(n_estimators=400, max_depth=-1, learning_rate=0.05,
                                        subsample=0.9, subsample_freq=1,
                                        colsample_bytree=0.9,
                                        class_weight="balanced", random_state=42, verbose=-1)
    except ImportError:
        print("WARN: lightgbm not installed — skipping lgbm")
    return models
def _align_proba(raw, model_classes, n_classes):
    """Scatter ``predict_proba`` output into an ``(n_rows, n_classes)`` matrix.

    A model fitted on a fold that lacked some class returns fewer columns
    than there are classes; ``model_classes`` (the integer codes the model
    was fitted on, one per column of ``raw``) says where each column goes.
    Classes the model never saw get probability 0.
    """
    raw = np.asarray(raw, dtype=float)
    full = np.zeros((raw.shape[0], n_classes), dtype=float)
    full[:, np.asarray(model_classes, dtype=int)] = raw
    return full


def _fold_auc(model, X_te, ytrue, n_classes):
    """Return one fold's macro one-vs-rest AUC-ROC, or None if undefined.

    None is returned when fewer than two classes occur in ``ytrue``, when
    the model has no ``predict_proba``, or when sklearn rejects the score
    with ValueError (e.g. a class absent from this test fold makes its
    one-vs-rest AUC undefined).
    """
    if n_classes < 2 or np.unique(ytrue).size < 2:
        return None
    if not hasattr(model, "predict_proba"):
        return None
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import label_binarize

    raw = model.predict_proba(X_te)
    model_classes = getattr(model, "classes_", np.arange(raw.shape[1]))
    proba = _align_proba(raw, model_classes, n_classes)
    try:
        if n_classes == 2:
            # label_binarize yields a single column for two classes, which
            # would not match a 2-column proba; score the positive class.
            return float(roc_auc_score(ytrue, proba[:, 1]))
        yb = label_binarize(ytrue, classes=list(range(n_classes)))
        return float(roc_auc_score(yb, proba, average="macro"))
    except ValueError:
        return None


def cross_validate(X, y, n_splits=5):
    """Run shuffled stratified k-fold CV for every model from ``_models()``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix, positionally aligned with ``y``.
    y : pandas.Series (categorical)
        Labels; models are fitted on ``y.cat.codes``.
    n_splits : int
        Folds for ``StratifiedKFold(shuffle=True, random_state=42)``.

    Returns
    -------
    (dict, list)
        Metrics per model name (accuracy mean/std, mean per-class F1, mean
        macro one-vs-rest AUC-ROC over the folds where it was defined, the
        number of such folds, and n_splits), and the class names in code
        order (``y.cat.categories``).
    """
    from sklearn.model_selection import StratifiedKFold
    from sklearn.metrics import f1_score, accuracy_score

    classes_present = list(y.cat.categories)
    n_classes = len(classes_present)
    labels = list(range(n_classes))
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    out = {}
    for name, model in _models().items():
        accs, aucs = [], []
        f1s = {c: [] for c in classes_present}
        for tr, te in skf.split(X, y):
            m = _clone(model)
            m.fit(X.iloc[tr], y.iloc[tr].cat.codes)
            X_te = X.iloc[te]
            ytrue = y.iloc[te].cat.codes.to_numpy()
            pred = m.predict(X_te)
            accs.append(accuracy_score(ytrue, pred))
            f1c = f1_score(ytrue, pred, average=None, labels=labels,
                           zero_division=0)
            for c, score in zip(classes_present, f1c):
                f1s[c].append(float(score))
            auc = _fold_auc(m, X_te, ytrue, n_classes)
            if auc is not None:
                aucs.append(auc)
        if len(aucs) < n_splits:
            # Make a partial (or missing) AUC visible instead of silently
            # averaging over whichever folds happened to succeed.
            print(f"WARN: {name}: AUC-ROC undefined in {n_splits - len(aucs)}/{n_splits} "
                  f"folds; auc_roc_macro_ovr averages only the scored folds")
        out[name] = {
            "accuracy_mean": float(np.mean(accs)), "accuracy_std": float(np.std(accs)),
            "f1_per_class": {c: float(np.mean(v)) for c, v in f1s.items()},
            "auc_roc_macro_ovr": float(np.mean(aucs)) if aucs else None,
            "auc_folds_scored": len(aucs),
            "n_splits": n_splits,
        }
    return out, classes_present
def _clone(m):
    """Return an unfitted copy of estimator *m* with the same hyper-parameters.

    Thin wrapper over ``sklearn.base.clone`` so each CV fold and the final
    refit start from a fresh estimator.
    """
    import sklearn.base as skbase
    return skbase.clone(m)
def main():
    """Command-line entry point for training the DPYD baseline classifier.

    Steps: parse arguments, build features with ``_xy``, apply the honesty
    guard, run stratified CV, write ``<outdir>/results/cv_metrics.json``,
    refit every model on all rows and save ``<outdir>/models/<name>_model.pkl``
    with joblib, then optionally upload both directories via
    ``src.gcs_io.upload_dir``.

    Exits with status 2 when no rows carry a known label, or when any class
    is below MIN_PER_CLASS and ``--force`` was not given. ``--n-splits`` below
    2 is rejected through ``ArgumentParser.error`` (also status 2).
    """
    import os

    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="data/training_data.csv")
    ap.add_argument("--outdir", default=".")
    ap.add_argument("--n-splits", type=int, default=5)
    ap.add_argument("--upload", default="", help="gs:// bucket prefix to upload to (optional)")
    ap.add_argument("--force", action="store_true", help="train even if class counts are below the guard")
    args = ap.parse_args()
    if args.n_splits < 2:
        # StratifiedKFold rejects n_splits < 2; fail as a CLI error instead.
        ap.error("--n-splits must be at least 2")

    df = pd.read_csv(args.train)
    X, y, feats = _xy(df)
    counts = Counter(y)
    print(f"training rows: {len(X)} | features: {len(feats)} | class counts: {dict(counts)}")
    if not counts:
        # Nothing survived the label filter; min() below would otherwise crash
        # before the guard could explain anything.
        print(f"No rows with label_class in {CLASSES}; nothing to train on.")
        sys.exit(2)

    # --- honesty guard ---
    too_small = [c for c in CLASSES if counts.get(c, 0) < MIN_PER_CLASS]
    # Stratified CV cannot place every class in every test fold when n_splits
    # exceeds the smallest class count, so cap it there (never below 2).
    smallest = min(counts.values())
    n_splits = max(2, min(args.n_splits, smallest))
    if too_small:
        msg = (f"\n*** EVIDENCE-THIN WARNING ***\n"
               f"Classes below MIN_PER_CLASS={MIN_PER_CLASS}: {too_small}.\n"
               f"{n_splits}-fold CV on this set produces indicative, NOT validating, metrics.\n"
               f"Report all numbers as 'baseline classifier, AF+categorical features, "
               f"small-N (illustrative)'.\n")
        print(msg)
        if not args.force:
            print("Refusing to silently emit CV metrics. Re-run with --force to proceed "
                  "(metrics will carry the small-N caveat in cv_metrics.json).")
            sys.exit(2)
    if n_splits != args.n_splits:
        print(f"NOTE: n_splits reduced from {args.n_splits} to {n_splits} "
              f"(smallest class has {smallest} rows).")

    metrics, classes_present = cross_validate(X, y, n_splits=n_splits)
    payload = {
        "model_label": "baseline classifier, AF+categorical features (no SIFT/PolyPhen/CADD)",
        "feature_columns": feats,
        "classes": classes_present,
        "class_counts": dict(counts),
        "evidence_caveat": ("small-N (illustrative)" if too_small else "ok"),
        "cv": metrics,
    }
    os.makedirs(f"{args.outdir}/results", exist_ok=True)
    os.makedirs(f"{args.outdir}/models", exist_ok=True)
    with open(f"{args.outdir}/results/cv_metrics.json", "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print("wrote results/cv_metrics.json")

    # refit on full data + persist
    import joblib
    for name, model in _models().items():
        m = _clone(model)
        m.fit(X, y.cat.codes)
        joblib.dump({"model": m, "features": feats, "classes": classes_present},
                    f"{args.outdir}/models/{name}_model.pkl")
        print(f"saved models/{name}_model.pkl")

    if args.upload:
        from src.gcs_io import upload_dir
        upload_dir(f"{args.outdir}/models", f"{args.upload}/models")
        upload_dir(f"{args.outdir}/results", f"{args.upload}/results")


if __name__ == "__main__":
    main()