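# dpyd-classifier / src/infer.py
# Hugging Face Hub page residue (kept as comments so the module stays importable):
#   uploaded by abhimanyu12 via huggingface_hub ("Upload folder using huggingface_hub"),
#   commit 2ea06dc (verified), file size 6.1 kB.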
"""Inference + ranking for a candidate DPYD variant list (e.g. Scaria et al. 2025).

Loads the fitted models (up to 3: rf, xgb, lgbm; missing ones are skipped),
scores each candidate variant, and emits a table ranked by descending
confidence:

    variant | rsID | gnomad_sas_af | predicted_class | confidence | model_agreement | per_model | coverage_gap

- confidence = mean of the per-model max-class probability.
- model_agreement = True iff every loaded model predicts the same class.
- Variants with no gnomAD SAS data are flagged as coverage gaps (kept, not dropped).

IMPORTANT: the Scaria et al. (2025) variant list is NOT bundled here — it was
not resolvable from the task spec (the ClinVar query "DPYD[gene] AND Indian
[title]" returns nothing meaningful, and the rsIDs were "to be resolved").
Provide a real list via --variants <csv with columns: rsid[,variant_id]>.
Until then this module is wired and tested in shape only.
Run: python -m src.infer --variants data/scaria_variants.csv --models models --out results/scaria_variant_rankings.csv
"""
from __future__ import annotations
import argparse, csv, os
import numpy as np
import pandas as pd
from src.features import build as build_features # reuse feature transforms
from src.fetch_gnomad import query_variant, _pop_af
def _load_models(models_dir: str):
import joblib
out = {}
for name in ("rf", "xgb", "lgbm"):
p = os.path.join(models_dir, f"{name}_model.pkl")
if os.path.exists(p):
out[name] = joblib.load(p)
if not out:
raise SystemExit(f"no models found in {models_dir} — train first")
return out
def _featurize_candidates(variants_csv: str) -> pd.DataFrame:
"""variants_csv must have at least 'rsid'; optional 'variant_id','gnomad_sas_af'.
Missing gnomAD AF is fetched live per-variant."""
df = pd.read_csv(variants_csv, dtype=str).fillna("")
rows = []
for _, r in df.iterrows():
vid = r.get("variant_id", "")
sas = r.get("gnomad_sas_af", "")
gaf = r.get("gnomad_global_af", "")
in_g = 1
if vid and (sas == "" or gaf == ""):
v = query_variant(vid)
if v in (None, "ERR"):
in_g = 0
else:
g, e = v.get("genome"), v.get("exome")
gaf = (g or {}).get("af") or (e or {}).get("af") if (g or e) else ""
sas = _pop_af(g, "sas")
if sas is None:
sas = _pop_af(e, "sas")
rows.append({
"rsid": r.get("rsid", ""), "variant_id": vid,
"gnomad_global_af": gaf, "gnomad_sas_af": sas, "in_gnomad": in_g,
"consequence": r.get("consequence", "other"),
"clnsig_norm": r.get("clnsig_norm", ""),
"ref": r.get("ref", ""), "alt": r.get("alt", ""),
})
out = pd.DataFrame(rows)
out["coverage_gap"] = ((out["gnomad_sas_af"].isna()) |
(out["gnomad_sas_af"].astype(str) == "") |
(out["in_gnomad"] == 0))
return out
def _align(X_raw: pd.DataFrame, feats: list[str]) -> pd.DataFrame:
import numpy as np
num = ["gnomad_global_af", "gnomad_sas_af", "in_gnomad"]
X = X_raw.copy()
for c in num:
X[c] = pd.to_numeric(X[c], errors="coerce").fillna(0.0)
X["log10_gnomad_global_af"] = np.log10(X["gnomad_global_af"].clip(lower=0) + 1e-7)
X["log10_gnomad_sas_af"] = np.log10(X["gnomad_sas_af"].clip(lower=0) + 1e-7)
X["sas_enriched"] = ((X["gnomad_sas_af"] > X["gnomad_global_af"]) & (X["gnomad_sas_af"] > 0)).astype(int)
X["is_indel"] = ((X["ref"].str.len() != 1) | (X["alt"].str.len() != 1)).astype(int)
X = pd.concat([X, pd.get_dummies(X[["consequence", "clnsig_norm"]].astype(str),
prefix=["consequence", "clnsig_norm"])], axis=1)
for f in feats:
if f not in X.columns:
X[f] = 0
return X[feats]
def rank(variants_csv: str, models_dir="models", out="results/scaria_variant_rankings.csv"):
    """Score candidate variants with every loaded model and write a ranked CSV.

    For each variant the consensus class is the majority vote across models.
    Ties go to the class predicted by the earliest model in load order, so the
    result is reproducible between runs. ``confidence`` is the mean of each
    model's max-class probability and ``model_agreement`` is True when every
    loaded model predicts the same class. Rows are sorted by descending
    confidence; equal confidences keep input order.

    Args:
        variants_csv: Candidate CSV, see ``_featurize_candidates``.
        models_dir: Directory holding the fitted model bundles.
        out: Destination CSV path; its parent directory is created if needed.

    Returns:
        The ranked DataFrame that was written to ``out``.

    Raises:
        SystemExit: Propagated from ``_load_models`` when no model is found.
    """
    from collections import Counter

    models = _load_models(models_dir)
    cand = _featurize_candidates(variants_csv)
    # NOTE(review): the feature list and class labels come from the first
    # loaded bundle and are applied to all models — confirm every bundle was
    # trained with the same features and label encoding.
    first_bundle = next(iter(models.values()))
    feats = first_bundle["features"]
    classes = first_bundle["classes"]

    per_model_pred, per_model_conf = {}, {}
    if len(cand):
        # Skipped for an empty candidate list so predict_proba is never
        # called on a zero-row matrix.
        X = _align(cand, feats)
        for name, bundle in models.items():
            proba = bundle["model"].predict_proba(X)
            per_model_pred[name] = [classes[i] for i in proba.argmax(axis=1)]
            per_model_conf[name] = proba.max(axis=1)

    columns = [
        "variant", "rsID", "gnomad_sas_af", "predicted_class", "confidence",
        "model_agreement", "per_model", "coverage_gap",
    ]
    results = []
    for i, row in enumerate(cand.to_dict("records")):
        preds = {k: per_model_pred[k][i] for k in models}
        confs = [per_model_conf[k][i] for k in models]
        # Counter.most_common orders equal counts by first occurrence, which
        # makes the tie-break deterministic (unlike iterating a set of str).
        vote = Counter(preds.values()).most_common(1)[0][0]
        results.append({
            "variant": row["variant_id"] or row["rsid"],
            "rsID": row["rsid"],
            "gnomad_sas_af": row["gnomad_sas_af"],
            "predicted_class": vote,
            "confidence": round(float(np.mean(confs)), 4),
            "model_agreement": len(set(preds.values())) == 1,
            "per_model": ";".join(f"{k}={v}" for k, v in preds.items()),
            "coverage_gap": bool(row["coverage_gap"]),
        })

    # Explicit columns keep the table well-formed when there are no rows;
    # mergesort is stable, so equal confidences keep their input order.
    res = pd.DataFrame(results, columns=columns).sort_values(
        "confidence", ascending=False, kind="mergesort"
    )

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    res.to_csv(out, index=False)

    gaps = res.loc[res["coverage_gap"].astype(bool), "rsID"].tolist()
    print(f"ranked {len(res)} variants -> {out}")
    print(f"all-3-agree: {int(res['model_agreement'].sum())} | coverage gaps (no gnomAD SAS): {len(gaps)}")
    if gaps:
        print("coverage-gap rsIDs:", gaps)
    return res
def main():
    """Command-line entry point: parse the flags and hand them to ``rank``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--variants", required=True, help="CSV with at least an 'rsid' column")
    # The two optional flags differ only in name and default value.
    for flag, fallback in (
        ("--models", "models"),
        ("--out", "results/scaria_variant_rankings.csv"),
    ):
        cli.add_argument(flag, default=fallback)
    opts = cli.parse_args()
    rank(opts.variants, opts.models, opts.out)


# Run the CLI only when executed as a script/module, not on import.
if __name__ == "__main__":
    main()