"""Pull a held-out evaluation set of ClinVar 4-star expert-panel variants. Used by `backend/tests/test_known_variants.py` to measure end-to-end classification concordance vs. the ClinVar gold standard. NCBI rate limit: 3 req/s without key, 10 req/s with NCBI_API_KEY in `.env`. Usage ----- python -m scripts.seed_eval_set --n 100 --gene BRCA1 --gene TSC2 --gene MLH1 By default writes to `backend/tests/fixtures/clinvar_validation_set.json`. """ from __future__ import annotations import argparse import asyncio import gzip import json import logging import sys from pathlib import Path from typing import Any import httpx from backend.app.config import get_settings logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("seed_eval_set") logging.getLogger("httpx").setLevel(logging.WARNING) EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" DEFAULT_GENES = [ "BRCA1", "BRCA2", "TSC2", "MLH1", "MSH2", "MSH6", "PMS2", "PALB2", "ATM", "CHEK2", "TP53", "PTEN", "CDH1", "RB1", "VHL", ] EXPERT_REVIEW_STATUSES = {"reviewed_by_expert_panel", "practice_guideline"} VALID_CLASSIFICATIONS = { "Pathogenic", "Likely_pathogenic", "Uncertain_significance", "Likely_benign", "Benign", } def _params(settings: Any, **extra: Any) -> dict[str, Any]: p = {"tool": "VariantLens-eval", "email": settings.ncbi_email} if settings.ncbi_api_key: p["api_key"] = settings.ncbi_api_key return {**p, **extra} async def _search_expert_panel( client: httpx.AsyncClient, settings: Any, gene: str, retmax: int, ) -> list[str]: term = f"{gene}[Gene Name]" r = await client.get( f"{EUTILS}/esearch.fcgi", params=_params(settings, db="clinvar", term=term, retmax=retmax, retmode="json"), ) r.raise_for_status() ids = r.json().get("esearchresult", {}).get("idlist", []) return [str(i) for i in ids] async def _fetch_summary(client: httpx.AsyncClient, settings: Any, ids: list[str]) -> list[dict[str, Any]]: if not ids: return [] r = await client.get( f"{EUTILS}/esummary.fcgi", params=_params(settings, db="clinvar", id=",".join(ids), retmode="json"), ) r.raise_for_status() payload = r.json().get("result", {}) out: list[dict[str, Any]] = [] for vid in ids: item = payload.get(vid) if not item: continue cls = (item.get("germline_classification", {}) or {}).get("description") or item.get("clinical_significance", {}).get("description") title = item.get("title") or "" review = (item.get("germline_classification", {}) or {}).get("review_status", "") or item.get("clinical_significance", {}).get("review_status", "") if not cls: continue # Prefer entries with explicit canonical SPDI / HGVS in the title. out.append({ "variation_id": vid, "title": title, "expected_classification": cls, "review_status": review, "gene": item.get("genes", [{}])[0].get("symbol", "") if item.get("genes") else "", }) return out def _parse_info(info: str) -> dict[str, str]: parsed: dict[str, str] = {} for item in info.split(";"): if "=" not in item: continue key, value = item.split("=", 1) parsed[key] = value return parsed def _iter_local_expert_panel_ids(path: Path, genes: set[str] | None, limit: int) -> list[str]: out: list[str] = [] if not path.exists(): logger.warning("ClinVar VCF not found at %s; falling back to NCBI search", path) return out with gzip.open(path, "rt") as handle: for line in handle: if line.startswith("#"): continue fields = line.rstrip("\n").split("\t") if len(fields) < 8: continue info = _parse_info(fields[7]) review_status = info.get("CLNREVSTAT", "") classification = info.get("CLNSIG", "") geneinfo = info.get("GENEINFO", "") gene = geneinfo.split(":", 1)[0] if geneinfo else "" if genes and gene not in genes: continue if review_status not in EXPERT_REVIEW_STATUSES: continue if classification not in VALID_CLASSIFICATIONS: continue variation_id = fields[2] if variation_id == ".": continue out.append(variation_id) if len(out) >= limit: break return out async def collect_from_local_vcf(genes: list[str], n_total: int) -> list[dict[str, Any]]: settings = get_settings() ids = _iter_local_expert_panel_ids(settings.clinvar_vcf_path, set(genes) if genes else None, n_total) if not ids: return [] logger.info("local ClinVar VCF: found %d expert-panel/practice-guideline IDs", len(ids)) out: list[dict[str, Any]] = [] async with httpx.AsyncClient(timeout=30.0) as client: for i in range(0, len(ids), 200): rows = await _fetch_summary(client, settings, ids[i: i + 200]) out.extend(rows) return out[:n_total] async def collect( genes: list[str], n_total: int, per_gene: int, restrict_local_to_genes: bool = False, ) -> list[dict[str, Any]]: settings = get_settings() local_rows = await collect_from_local_vcf(genes if restrict_local_to_genes else [], n_total) if local_rows: return local_rows out: list[dict[str, Any]] = [] async with httpx.AsyncClient(timeout=30.0) as client: for gene in genes: try: ids = await _search_expert_panel(client, settings, gene, retmax=per_gene * 3) except httpx.HTTPError as e: logger.warning("search failed for %s: %s", gene, e) continue if not ids: continue try: rows = await _fetch_summary(client, settings, ids[: per_gene * 3]) except httpx.HTTPError as e: logger.warning("summary failed for %s: %s", gene, e) continue kept = [ row for row in rows if row.get("review_status", "").replace(" ", "_") in EXPERT_REVIEW_STATUSES ][:per_gene] logger.info("%s: kept %d/%d", gene, len(kept), len(rows)) out.extend(kept) if len(out) >= n_total: break return out[:n_total] def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--n", type=int, default=100, help="Total variants to collect") parser.add_argument("--per-gene", type=int, default=8, help="Cap per gene") parser.add_argument("--gene", action="append", help="Override default gene list (repeatable)") parser.add_argument( "--out", type=Path, default=Path("backend/tests/fixtures/clinvar_validation_set.json"), ) args = parser.parse_args() genes = args.gene or DEFAULT_GENES rows = asyncio.run(collect(genes, args.n, args.per_gene, restrict_local_to_genes=bool(args.gene))) args.out.parent.mkdir(parents=True, exist_ok=True) args.out.write_text(json.dumps(rows, indent=2) + "\n") logger.info("wrote %d entries to %s", len(rows), args.out) return 0 if __name__ == "__main__": sys.exit(main())