| """Train the DPYD baseline classifier — RandomForest, XGBoost, LightGBM. |
| |
| BASELINE: AF + CPIC/ClinVar categorical features only (no SIFT/PolyPhen/CADD). |
Stratified k-fold CV (5 folds by default; see --n-splits). Reports accuracy,
per-class F1, and macro one-vs-rest AUC-ROC. Persists fitted models (refit on
full data) + cv_metrics.json.
| |
| HONESTY GUARD: the CPIC ground-truth set for DPYD is tiny (~13 alleles). With |
| 3 classes and AF+categorical features only, CV metrics on this set are |
indicative, not validating. If any class has fewer than MIN_PER_CLASS
examples, the guard in this script refuses to emit metrics silently: it
prints a loud caveat, caps n_splits at the smallest class count (never below
2), and exits with status 2 unless --force is given. This mirrors the
platform invariant: "refuse honestly when evidence is thin."
| |
| Run: python -m src.train --train data/training_data.csv --outdir . --upload <bucket> |
| DO NOT run as part of scaffolding — invoked explicitly at training time. |
| """ |
| from __future__ import annotations |
| import argparse, json, sys |
| from collections import Counter |
|
|
| import numpy as np |
| import pandas as pd |
|
|
# Label classes used for training; rows whose label_class is anything else
# are dropped when the design matrix is built.
CLASSES = [
    "normal_function",
    "decreased_function",
    "no_function",
]
# Evidence guard: any class with fewer rows than this triggers the small-N
# warning (and a refusal unless --force is given).
MIN_PER_CLASS = 5

# Numeric feature columns (coerced to numbers; unparseable/missing -> 0.0).
NUMERIC = [
    "gnomad_global_af",
    "gnomad_sas_af",
    "log10_gnomad_global_af",
    "log10_gnomad_sas_af",
    "in_gnomad",
    "sas_enriched",
    "is_indel",
]
# Categorical feature columns (one-hot encoded, prefixed by column name).
CATEGORICAL = [
    "consequence",
    "clnsig_norm",
]
|
|
|
|
| def _xy(df: pd.DataFrame): |
| df = df[df["label_class"].isin(CLASSES)].copy() |
| X_num = df[NUMERIC].apply(pd.to_numeric, errors="coerce").fillna(0.0) |
| X_cat = pd.get_dummies(df[CATEGORICAL].astype(str), prefix=CATEGORICAL) |
| X = pd.concat([X_num, X_cat], axis=1) |
| y = df["label_class"].astype("category") |
| return X, y, list(X.columns) |
|
|
|
|
def _models():
    """Return a fresh ``{name: unfitted estimator}`` dict of baseline models.

    "rf" (RandomForest) is always present. "xgb" (XGBoost) and "lgbm"
    (LightGBM) are added only if their packages import; otherwise a WARN line
    is printed and that model is skipped. All models use random_state=42.
    """
    from sklearn.ensemble import RandomForestClassifier

    models = {
        "rf": RandomForestClassifier(n_estimators=400, random_state=42,
                                     class_weight="balanced", n_jobs=-1),
    }

    try:
        from xgboost import XGBClassifier
    except ImportError:
        print("WARN: xgboost not installed — skipping xgb")
    else:
        # NOTE(review): unlike rf/lgbm, no class balancing is applied to xgb
        # here (XGBClassifier has no class_weight; balancing would need
        # sample_weight at fit time), so on skewed labels it trains unweighted.
        models["xgb"] = XGBClassifier(n_estimators=400, max_depth=4, learning_rate=0.05,
                                      subsample=0.9, colsample_bytree=0.9,
                                      objective="multi:softprob", random_state=42,
                                      eval_metric="mlogloss", tree_method="hist")

    try:
        from lightgbm import LGBMClassifier
    except ImportError:
        print("WARN: lightgbm not installed — skipping lgbm")
    else:
        # LightGBM only applies row subsampling (bagging) when subsample_freq
        # is > 0. Without it, subsample=0.9 was silently a no-op.
        # NOTE(review): LightGBM's minimum leaf size (min_child_samples) keeps
        # its library default here. On very small training sets that may allow
        # few or no splits. Confirm against the data size actually used.
        models["lgbm"] = LGBMClassifier(n_estimators=400, max_depth=-1, learning_rate=0.05,
                                        subsample=0.9, subsample_freq=1,
                                        colsample_bytree=0.9,
                                        class_weight="balanced", random_state=42,
                                        verbose=-1)
    return models
|
|
|
|
def _fold_auc(ytrue, proba, n_classes):
    """Return one fold's ROC AUC, or None when AUC is undefined for the fold.

    For 2 classes, the positive-class (code 1) probability column is scored
    directly. In that case the macro one-vs-rest AUC equals the binary AUC.
    ``label_binarize`` yields a single column for 2 classes, which cannot be
    paired with a 2-column ``proba``. For 3 or more classes, the macro
    one-vs-rest AUC is computed over the binarized labels.
    """
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import label_binarize

    try:
        if n_classes == 2:
            return float(roc_auc_score(ytrue, proba[:, 1]))
        yb = label_binarize(ytrue, classes=list(range(n_classes)))
        return float(roc_auc_score(yb, proba, average="macro", multi_class="ovr"))
    except (ValueError, IndexError):
        # Examples: a class absent from this test fold ("only one class
        # present"), or proba having fewer columns than n_classes because a
        # class was absent from the training fold.
        return None


def cross_validate(X, y, n_splits=5):
    """Stratified k-fold cross-validation of every model from ``_models()``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix, row-aligned (positionally) with ``y``.
    y : categorical pandas.Series
        Labels. Models are trained on ``y.cat.codes``.
    n_splits : int
        Number of shuffled StratifiedKFold folds (random_state=42).

    Returns
    -------
    (out, classes_present)
        ``out`` maps model name to accuracy mean/std, mean per-class F1,
        mean AUC (``auc_roc_macro_ovr``, or None if no fold was scorable),
        ``auc_folds_scored`` and ``n_splits``. ``classes_present`` is
        ``list(y.cat.categories)``; index i corresponds to code i.

    Folds where AUC is undefined are skipped, counted in
    ``auc_folds_scored`` and reported with a WARN line, rather than being
    dropped silently.
    """
    from sklearn.model_selection import StratifiedKFold
    from sklearn.metrics import f1_score, accuracy_score

    classes_present = list(y.cat.categories)
    label_ids = list(range(len(classes_present)))
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    # The int random_state makes every split() call identical, so compute the
    # folds once and reuse them for all models.
    folds = list(skf.split(X, y))

    out = {}
    for name, model in _models().items():
        accs, f1s, aucs = [], {c: [] for c in classes_present}, []
        for tr, te in folds:
            m = _clone(model)
            m.fit(X.iloc[tr], y.iloc[tr].cat.codes)
            X_te = X.iloc[te]
            pred = m.predict(X_te)
            ytrue = y.iloc[te].cat.codes.to_numpy()
            accs.append(accuracy_score(ytrue, pred))
            f1c = f1_score(ytrue, pred, average=None,
                           labels=label_ids, zero_division=0)
            for c, score in zip(classes_present, f1c):
                f1s[c].append(float(score))
            auc = _fold_auc(ytrue, m.predict_proba(X_te), len(classes_present))
            if auc is not None:
                aucs.append(auc)

        skipped = len(folds) - len(aucs)
        if skipped:
            print(f"WARN: {name}: AUC undefined on {skipped}/{len(folds)} folds "
                  f"(skipped from the mean)")
        out[name] = {
            "accuracy_mean": float(np.mean(accs)), "accuracy_std": float(np.std(accs)),
            "f1_per_class": {c: float(np.mean(v)) for c, v in f1s.items()},
            "auc_roc_macro_ovr": float(np.mean(aucs)) if aucs else None,
            "auc_folds_scored": len(aucs),
            "n_splits": n_splits,
        }
    return out, classes_present
|
|
|
|
def _clone(m):
    """Return a fresh, unfitted copy of estimator ``m`` with the same
    parameters (thin wrapper around ``sklearn.base.clone``)."""
    import sklearn.base

    return sklearn.base.clone(m)
|
|
|
|
def _apply_evidence_guard(counts, n_splits, force):
    """Enforce the small-N honesty guard and return ``(too_small, n_splits)``.

    ``counts`` maps each present label to its row count. ``too_small`` lists
    the CLASSES entries with fewer than MIN_PER_CLASS rows (absent classes
    count as 0). When it is non-empty, the caveat is printed and the process
    exits with status 2 unless ``force`` is true. With ``force``, ``n_splits``
    is capped at the smallest present class count (never below 2).

    The process always exits with status 2, even with ``force``, when fewer
    than two classes are present. Classification and stratified CV are
    meaningless then, and an empty training set would otherwise crash on
    ``min()`` of an empty sequence.
    """
    if len(counts) < 2:
        print(f"Refusing to train: need at least 2 label classes, "
              f"found {sorted(counts)}.")
        sys.exit(2)

    too_small = [c for c in CLASSES if counts.get(c, 0) < MIN_PER_CLASS]
    if not too_small:
        return too_small, n_splits

    print(f"\n*** EVIDENCE-THIN WARNING ***\n"
          f"Classes below MIN_PER_CLASS={MIN_PER_CLASS}: {too_small}.\n"
          f"{n_splits}-fold CV on this set produces indicative, NOT validating, metrics.\n"
          f"Report all numbers as 'baseline classifier, AF+categorical features, "
          f"small-N (illustrative)'.\n")
    if not force:
        print("Refusing to silently emit CV metrics. Re-run with --force to proceed "
              "(metrics will carry the small-N caveat in cv_metrics.json).")
        sys.exit(2)

    smallest = min(counts.values())
    capped = max(2, min(n_splits, smallest))
    if capped != n_splits:
        print(f"n_splits reduced from {n_splits} to {capped} "
              f"(smallest class has {smallest} rows).")
    return too_small, capped


def main():
    """Command-line entry point: cross-validate and persist the baseline models.

    Steps:
    1. Load the ``--train`` CSV and build features/labels with ``_xy``.
    2. Apply the evidence guard (may exit with status 2).
    3. Run stratified CV and write ``<outdir>/results/cv_metrics.json``.
    4. Refit every model on all rows and save
       ``<outdir>/models/<name>_model.pkl``.
    5. If ``--upload`` is given, upload both directories under that prefix.
    """
    import os

    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="data/training_data.csv")
    ap.add_argument("--outdir", default=".")
    ap.add_argument("--n-splits", type=int, default=5)
    ap.add_argument("--upload", default="", help="gs:// bucket prefix to upload to (optional)")
    ap.add_argument("--force", action="store_true", help="train even if class counts are below the guard")
    args = ap.parse_args()
    if args.n_splits < 2:
        # StratifiedKFold rejects n_splits < 2. Fail here with a CLI error
        # rather than with a traceback after the data has been loaded.
        ap.error("--n-splits must be at least 2")

    df = pd.read_csv(args.train)
    X, y, feats = _xy(df)
    counts = Counter(y)
    print(f"training rows: {len(X)} | features: {len(feats)} | class counts: {dict(counts)}")

    too_small, n_splits = _apply_evidence_guard(counts, args.n_splits, args.force)

    metrics, classes_present = cross_validate(X, y, n_splits=n_splits)
    payload = {
        "model_label": "baseline classifier, AF+categorical features (no SIFT/PolyPhen/CADD)",
        "feature_columns": feats,
        "classes": classes_present,
        "class_counts": dict(counts),
        "evidence_caveat": ("small-N (illustrative)" if too_small else "ok"),
        "cv": metrics,
    }
    results_dir = os.path.join(args.outdir, "results")
    models_dir = os.path.join(args.outdir, "models")
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)
    with open(os.path.join(results_dir, "cv_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print("wrote results/cv_metrics.json")

    # The persisted models are refit on every row. The CV above is the only
    # held-out estimate of their performance.
    import joblib
    for name, model in _models().items():
        m = _clone(model)
        m.fit(X, y.cat.codes)
        # joblib files are pickles: load them only from trusted locations.
        joblib.dump({"model": m, "features": feats, "classes": classes_present},
                    os.path.join(models_dir, f"{name}_model.pkl"))
        print(f"saved models/{name}_model.pkl")

    if args.upload:
        from src.gcs_io import upload_dir
        upload_dir(models_dir, f"{args.upload}/models")
        upload_dir(results_dir, f"{args.upload}/results")
|
|
|
|
# Script entry point. Per the module docstring, training is invoked
# explicitly (e.g. `python -m src.train ...`), not during scaffolding.
if __name__ == "__main__":
    main()
|
|