| """Inference + ranking for a candidate DPYD variant list (e.g. Scaria et al. 2025). |
| |
Loads the fitted models present in --models (rf, xgb, lgbm — whichever exist),
scores each candidate variant, and emits a ranked table:
    variant | rsID | gnomad_sas_af | predicted_class | confidence | model_agreement
(plus per_model and coverage_gap columns).

- confidence = mean of the per-model max-class probability.
- model_agreement = True iff every loaded model predicts the same class.
- Variants with no gnomAD SAS data are flagged as coverage gaps (kept, not dropped).
| |
| IMPORTANT: the Scaria et al. (2025) variant list is NOT bundled here — it was |
| not resolvable from the task spec (the ClinVar query "DPYD[gene] AND Indian |
| [title]" returns nothing meaningful, and the rsIDs were "to be resolved"). |
| Provide a real list via --variants <csv with columns: rsid[,variant_id]>. |
| Until then this module is wired and tested in shape only. |
| |
| Run: python -m src.infer --variants data/scaria_variants.csv --models models --out results/scaria_variant_rankings.csv |
| """ |
| from __future__ import annotations |
| import argparse, csv, os |
| import numpy as np |
| import pandas as pd |
|
|
| from src.features import build as build_features |
| from src.fetch_gnomad import query_variant, _pop_af |
|
|
|
|
| def _load_models(models_dir: str): |
| import joblib |
| out = {} |
| for name in ("rf", "xgb", "lgbm"): |
| p = os.path.join(models_dir, f"{name}_model.pkl") |
| if os.path.exists(p): |
| out[name] = joblib.load(p) |
| if not out: |
| raise SystemExit(f"no models found in {models_dir} — train first") |
| return out |
|
|
|
|
| def _featurize_candidates(variants_csv: str) -> pd.DataFrame: |
| """variants_csv must have at least 'rsid'; optional 'variant_id','gnomad_sas_af'. |
| Missing gnomAD AF is fetched live per-variant.""" |
| df = pd.read_csv(variants_csv, dtype=str).fillna("") |
| rows = [] |
| for _, r in df.iterrows(): |
| vid = r.get("variant_id", "") |
| sas = r.get("gnomad_sas_af", "") |
| gaf = r.get("gnomad_global_af", "") |
| in_g = 1 |
| if vid and (sas == "" or gaf == ""): |
| v = query_variant(vid) |
| if v in (None, "ERR"): |
| in_g = 0 |
| else: |
| g, e = v.get("genome"), v.get("exome") |
| gaf = (g or {}).get("af") or (e or {}).get("af") if (g or e) else "" |
| sas = _pop_af(g, "sas") |
| if sas is None: |
| sas = _pop_af(e, "sas") |
| rows.append({ |
| "rsid": r.get("rsid", ""), "variant_id": vid, |
| "gnomad_global_af": gaf, "gnomad_sas_af": sas, "in_gnomad": in_g, |
| "consequence": r.get("consequence", "other"), |
| "clnsig_norm": r.get("clnsig_norm", ""), |
| "ref": r.get("ref", ""), "alt": r.get("alt", ""), |
| }) |
| out = pd.DataFrame(rows) |
| out["coverage_gap"] = ((out["gnomad_sas_af"].isna()) | |
| (out["gnomad_sas_af"].astype(str) == "") | |
| (out["in_gnomad"] == 0)) |
| return out |
|
|
|
|
| def _align(X_raw: pd.DataFrame, feats: list[str]) -> pd.DataFrame: |
| import numpy as np |
| num = ["gnomad_global_af", "gnomad_sas_af", "in_gnomad"] |
| X = X_raw.copy() |
| for c in num: |
| X[c] = pd.to_numeric(X[c], errors="coerce").fillna(0.0) |
| X["log10_gnomad_global_af"] = np.log10(X["gnomad_global_af"].clip(lower=0) + 1e-7) |
| X["log10_gnomad_sas_af"] = np.log10(X["gnomad_sas_af"].clip(lower=0) + 1e-7) |
| X["sas_enriched"] = ((X["gnomad_sas_af"] > X["gnomad_global_af"]) & (X["gnomad_sas_af"] > 0)).astype(int) |
| X["is_indel"] = ((X["ref"].str.len() != 1) | (X["alt"].str.len() != 1)).astype(int) |
| X = pd.concat([X, pd.get_dummies(X[["consequence", "clnsig_norm"]].astype(str), |
| prefix=["consequence", "clnsig_norm"])], axis=1) |
| for f in feats: |
| if f not in X.columns: |
| X[f] = 0 |
| return X[feats] |
|
|
|
|
_RESULT_COLUMNS = [
    "variant", "rsID", "gnomad_sas_af", "predicted_class", "confidence",
    "model_agreement", "per_model", "coverage_gap",
]


def _majority_vote(preds: dict, confs: dict):
    """Return the class most models predicted, breaking ties deterministically.

    A tie on vote count goes to the class whose voters have the larger summed
    max-class probability. An exact tie after that goes to the class first
    seen in model order (dicts keep insertion order and max() keeps the first
    maximum). This avoids picking from a set, whose order is hash-randomised.
    """
    tally = {}
    for name, label in preds.items():
        votes, weight = tally.get(label, (0, 0.0))
        tally[label] = (votes + 1, weight + float(confs[name]))
    return max(tally, key=tally.__getitem__)


def rank(variants_csv: str, models_dir="models", out="results/scaria_variant_rankings.csv"):
    """Score every candidate variant with each loaded model and write a ranking.

    Args:
        variants_csv: candidate CSV (see ``_featurize_candidates``).
        models_dir: directory holding ``{rf,xgb,lgbm}_model.pkl`` bundles. Each
            bundle is a dict with "model", "features" and "classes" keys. A
            bundle without "features"/"classes" falls back to the first
            bundle's values.
        out: output CSV path. Its parent directory is created if needed; a
            bare filename is written to the current directory.

    Returns:
        DataFrame with columns variant, rsID, gnomad_sas_af, predicted_class,
        confidence, model_agreement, per_model and coverage_gap. Rows are
        sorted by confidence, descending; equal-confidence rows keep input order.

    Raises:
        SystemExit: propagated from ``_load_models`` when no models exist.

    Side effects: writes *out* and prints a short summary to stdout.
    """
    models = _load_models(models_dir)
    cand = _featurize_candidates(variants_csv)
    first = next(iter(models.values()))
    default_feats, default_classes = first["features"], first["classes"]

    per_model_pred, per_model_conf = {}, {}
    if len(cand):  # skip prediction entirely for an empty candidate list
        for name, bundle in models.items():
            # Align and decode per bundle so models trained with a different
            # feature list or class order are still interpreted correctly.
            X = _align(cand, bundle.get("features", default_feats))
            classes = bundle.get("classes", default_classes)
            proba = bundle["model"].predict_proba(X)
            per_model_pred[name] = [classes[j] for j in proba.argmax(axis=1)]
            per_model_conf[name] = proba.max(axis=1)

    results = []
    for i, (_, row) in enumerate(cand.iterrows()):
        preds = {k: per_model_pred[k][i] for k in models}
        confs = {k: per_model_conf[k][i] for k in models}
        results.append({
            "variant": row["variant_id"] or row["rsid"],
            "rsID": row["rsid"],
            "gnomad_sas_af": row["gnomad_sas_af"],
            "predicted_class": _majority_vote(preds, confs),
            # Mean of each model's max-class probability, whichever class it chose.
            "confidence": round(float(np.mean(list(confs.values()))), 4),
            "model_agreement": len(set(preds.values())) == 1,
            "per_model": ";".join(f"{k}={v}" for k, v in preds.items()),
            "coverage_gap": bool(row["coverage_gap"]),
        })
    # Explicit columns keep the schema for empty input; stable sort keeps
    # equal-confidence rows in input order across runs.
    res = pd.DataFrame(results, columns=_RESULT_COLUMNS).sort_values(
        "confidence", ascending=False, kind="mergesort")

    out_dir = os.path.dirname(out)
    if out_dir:  # os.makedirs("") raises for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    res.to_csv(out, index=False)

    gaps = res.loc[res["coverage_gap"].astype(bool), "rsID"].tolist()
    print(f"ranked {len(res)} variants -> {out}")
    print(f"all-3-agree: {int(res['model_agreement'].astype(bool).sum())} | coverage gaps (no gnomAD SAS): {len(gaps)}")
    if gaps:
        print("coverage-gap rsIDs:", gaps)
    return res
|
|
|
|
def main():
    """Command-line entry point: parse arguments, then rank the candidates.

    Flags:
        --variants  (required) CSV with at least an 'rsid' column.
        --models    directory holding the fitted model bundles.
        --out       destination CSV for the ranked table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--variants", required=True, help="CSV with at least an 'rsid' column")
    parser.add_argument("--models", default="models")
    parser.add_argument("--out", default="results/scaria_variant_rankings.csv")
    ns = parser.parse_args()
    rank(ns.variants, models_dir=ns.models, out=ns.out)
|
|
|
|
if __name__ == "__main__":
    # Script entry point, e.g.
    # python -m src.infer --variants data/scaria_variants.csv --models models --out results/scaria_variant_rankings.csv
    main()
|
|