"""Inference + ranking for a candidate DPYD variant list (e.g. Scaria et al. 2025). Loads the 3 fitted models, scores each candidate variant, and emits a ranked table: variant | rsID | gnomad_sas_af | predicted_class | confidence | model_agreement - confidence = mean of the per-model max-class probability. - model_agreement = True iff all 3 models predict the same class. - Variants with no gnomAD SAS data are flagged as coverage gaps (kept, not dropped). IMPORTANT: the Scaria et al. (2025) variant list is NOT bundled here — it was not resolvable from the task spec (the ClinVar query "DPYD[gene] AND Indian [title]" returns nothing meaningful, and the rsIDs were "to be resolved"). Provide a real list via --variants . Until then this module is wired and tested in shape only. Run: python -m src.infer --variants data/scaria_variants.csv --models models --out results/scaria_variant_rankings.csv """ from __future__ import annotations import argparse, csv, os import numpy as np import pandas as pd from src.features import build as build_features # reuse feature transforms from src.fetch_gnomad import query_variant, _pop_af def _load_models(models_dir: str): import joblib out = {} for name in ("rf", "xgb", "lgbm"): p = os.path.join(models_dir, f"{name}_model.pkl") if os.path.exists(p): out[name] = joblib.load(p) if not out: raise SystemExit(f"no models found in {models_dir} — train first") return out def _featurize_candidates(variants_csv: str) -> pd.DataFrame: """variants_csv must have at least 'rsid'; optional 'variant_id','gnomad_sas_af'. Missing gnomAD AF is fetched live per-variant.""" df = pd.read_csv(variants_csv, dtype=str).fillna("") rows = [] for _, r in df.iterrows(): vid = r.get("variant_id", "") sas = r.get("gnomad_sas_af", "") gaf = r.get("gnomad_global_af", "") in_g = 1 if vid and (sas == "" or gaf == ""): v = query_variant(vid) if v in (None, "ERR"): in_g = 0 else: g, e = v.get("genome"), v.get("exome") gaf = (g or {}).get("af") or (e or {}).get("af") if (g or e) else "" sas = _pop_af(g, "sas") if sas is None: sas = _pop_af(e, "sas") rows.append({ "rsid": r.get("rsid", ""), "variant_id": vid, "gnomad_global_af": gaf, "gnomad_sas_af": sas, "in_gnomad": in_g, "consequence": r.get("consequence", "other"), "clnsig_norm": r.get("clnsig_norm", ""), "ref": r.get("ref", ""), "alt": r.get("alt", ""), }) out = pd.DataFrame(rows) out["coverage_gap"] = ((out["gnomad_sas_af"].isna()) | (out["gnomad_sas_af"].astype(str) == "") | (out["in_gnomad"] == 0)) return out def _align(X_raw: pd.DataFrame, feats: list[str]) -> pd.DataFrame: import numpy as np num = ["gnomad_global_af", "gnomad_sas_af", "in_gnomad"] X = X_raw.copy() for c in num: X[c] = pd.to_numeric(X[c], errors="coerce").fillna(0.0) X["log10_gnomad_global_af"] = np.log10(X["gnomad_global_af"].clip(lower=0) + 1e-7) X["log10_gnomad_sas_af"] = np.log10(X["gnomad_sas_af"].clip(lower=0) + 1e-7) X["sas_enriched"] = ((X["gnomad_sas_af"] > X["gnomad_global_af"]) & (X["gnomad_sas_af"] > 0)).astype(int) X["is_indel"] = ((X["ref"].str.len() != 1) | (X["alt"].str.len() != 1)).astype(int) X = pd.concat([X, pd.get_dummies(X[["consequence", "clnsig_norm"]].astype(str), prefix=["consequence", "clnsig_norm"])], axis=1) for f in feats: if f not in X.columns: X[f] = 0 return X[feats] def rank(variants_csv: str, models_dir="models", out="results/scaria_variant_rankings.csv"): models = _load_models(models_dir) cand = _featurize_candidates(variants_csv) feats = next(iter(models.values()))["features"] classes = next(iter(models.values()))["classes"] X = _align(cand, feats) per_model_pred, per_model_conf = {}, {} for name, bundle in models.items(): m = bundle["model"] proba = m.predict_proba(X) idx = proba.argmax(axis=1) per_model_pred[name] = [classes[i] for i in idx] per_model_conf[name] = proba.max(axis=1) n = len(cand) results = [] for i in range(n): preds = {k: per_model_pred[k][i] for k in models} confs = [per_model_conf[k][i] for k in models] agree = len(set(preds.values())) == 1 # consensus = majority vote; confidence = mean max-prob vote = max(set(preds.values()), key=list(preds.values()).count) results.append({ "variant": cand.iloc[i]["variant_id"] or cand.iloc[i]["rsid"], "rsID": cand.iloc[i]["rsid"], "gnomad_sas_af": cand.iloc[i]["gnomad_sas_af"], "predicted_class": vote, "confidence": round(float(np.mean(confs)), 4), "model_agreement": bool(agree), "per_model": ";".join(f"{k}={v}" for k, v in preds.items()), "coverage_gap": bool(cand.iloc[i]["coverage_gap"]), }) res = pd.DataFrame(results).sort_values("confidence", ascending=False) os.makedirs(os.path.dirname(out), exist_ok=True) res.to_csv(out, index=False) gaps = res[res["coverage_gap"]]["rsID"].tolist() print(f"ranked {len(res)} variants -> {out}") print(f"all-3-agree: {int(res['model_agreement'].sum())} | coverage gaps (no gnomAD SAS): {len(gaps)}") if gaps: print("coverage-gap rsIDs:", gaps) return res def main(): ap = argparse.ArgumentParser() ap.add_argument("--variants", required=True, help="CSV with at least an 'rsid' column") ap.add_argument("--models", default="models") ap.add_argument("--out", default="results/scaria_variant_rankings.csv") args = ap.parse_args() rank(args.variants, args.models, args.out) if __name__ == "__main__": main()