varientlens / scripts /seed_eval_set.py
Codex
Initial VariantLens clinical readiness scaffold
3e219fa
"""Pull a held-out evaluation set of ClinVar 4-star expert-panel variants.
Used by `backend/tests/test_known_variants.py` to measure end-to-end
classification concordance vs. the ClinVar gold standard.
NCBI rate limit: 3 req/s without key, 10 req/s with NCBI_API_KEY in `.env`.
Usage
-----
python -m scripts.seed_eval_set --n 100 --gene BRCA1 --gene TSC2 --gene MLH1
By default writes to `backend/tests/fixtures/clinvar_validation_set.json`.
"""
from __future__ import annotations
import argparse
import asyncio
import gzip
import json
import logging
import sys
from pathlib import Path
from typing import Any
import httpx
from backend.app.config import get_settings
# Script-wide logging: timestamped INFO records for this script, while the
# "httpx" logger is raised to WARNING.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("seed_eval_set")
logging.getLogger("httpx").setLevel(logging.WARNING)

# Base URL that the esearch / esummary endpoint names are appended to.
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

# Genes used when the command line does not supply any --gene option.
DEFAULT_GENES = [
    "BRCA1",
    "BRCA2",
    "TSC2",
    "MLH1",
    "MSH2",
    "MSH6",
    "PMS2",
    "PALB2",
    "ATM",
    "CHEK2",
    "TP53",
    "PTEN",
    "CDH1",
    "RB1",
    "VHL",
]

# Review statuses (underscore-separated spelling) a record must carry to be
# kept in the evaluation set.
EXPERT_REVIEW_STATUSES = {
    "practice_guideline",
    "reviewed_by_expert_panel",
}

# Classification labels (underscore-separated spelling) accepted when
# scanning the local VCF.
VALID_CLASSIFICATIONS = {
    "Benign",
    "Likely_benign",
    "Uncertain_significance",
    "Likely_pathogenic",
    "Pathogenic",
}
def _params(settings: Any, **extra: Any) -> dict[str, Any]:
p = {"tool": "VariantLens-eval", "email": settings.ncbi_email}
if settings.ncbi_api_key:
p["api_key"] = settings.ncbi_api_key
return {**p, **extra}
async def _search_expert_panel(
    client: httpx.AsyncClient,
    settings: Any,
    gene: str,
    retmax: int,
) -> list[str]:
    """Run an esearch against the ``clinvar`` database for one gene.

    Args:
        client: Open HTTP client used to issue the GET request.
        settings: Settings object handed to ``_params`` for the base query.
        gene: Gene symbol; searched as ``<gene>[Gene Name]``.
        retmax: Passed through unchanged as the ``retmax`` query parameter.

    Returns:
        The entries of ``esearchresult.idlist`` from the JSON response, each
        converted to ``str``; an empty list when either key is missing.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` on an error response.

    NOTE(review): despite the function name, the search term only constrains
    the gene name -- nothing here restricts results by review status, so the
    caller has to filter the summaries afterwards.
    """
    query = _params(
        settings,
        db="clinvar",
        term=f"{gene}[Gene Name]",
        retmax=retmax,
        retmode="json",
    )
    response = await client.get(f"{EUTILS}/esearch.fcgi", params=query)
    response.raise_for_status()
    search_result = response.json().get("esearchresult", {})
    return list(map(str, search_result.get("idlist", [])))
async def _fetch_summary(client: httpx.AsyncClient, settings: Any, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch esummary records for ``ids`` and flatten them into eval rows.

    Args:
        client: Open HTTP client used to issue the GET request.
        settings: Settings object handed to ``_params`` for the base query.
        ids: ClinVar variation IDs; sent as one comma-joined ``id`` parameter.

    Returns:
        One dict per ID that has a summary with a classification, in the
        order of ``ids``, with keys ``variation_id``, ``title``,
        ``expected_classification``, ``review_status`` and ``gene``.
        IDs with no entry in the response, or with no classification
        description, are skipped. No request is made for an empty ``ids``.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` on an error response.
    """
    if not ids:
        return []
    r = await client.get(
        f"{EUTILS}/esummary.fcgi",
        params=_params(settings, db="clinvar", id=",".join(ids), retmode="json"),
    )
    r.raise_for_status()
    payload = r.json().get("result") or {}
    out: list[dict[str, Any]] = []
    for vid in ids:
        item = payload.get(vid)
        if not item:
            continue
        # Both classification containers may be absent *or* explicitly null;
        # normalise each to a dict once so neither lookup can hit None.
        germline = item.get("germline_classification") or {}
        legacy = item.get("clinical_significance") or {}
        cls = germline.get("description") or legacy.get("description")
        if not cls:
            continue
        # Always a string, so callers can normalise it without a None check.
        review = germline.get("review_status") or legacy.get("review_status") or ""
        genes = item.get("genes") or []
        first_gene = (genes[0] or {}) if genes else {}
        out.append({
            "variation_id": vid,
            "title": item.get("title") or "",
            "expected_classification": cls,
            "review_status": review,
            "gene": first_gene.get("symbol") or "",
        })
    return out
def _parse_info(info: str) -> dict[str, str]:
parsed: dict[str, str] = {}
for item in info.split(";"):
if "=" not in item:
continue
key, value = item.split("=", 1)
parsed[key] = value
return parsed
def _iter_local_expert_panel_ids(path: Path, genes: set[str] | None, limit: int) -> list[str]:
    """Scan a gzipped ClinVar VCF and return IDs of expert-reviewed variants.

    A data line is kept when its ``CLNREVSTAT`` is in
    ``EXPERT_REVIEW_STATUSES``, its ``CLNSIG`` is in ``VALID_CLASSIFICATIONS``,
    its ID column is not ``.`` and -- if ``genes`` is non-empty -- at least one
    gene symbol listed in ``GENEINFO`` is in ``genes``.

    Args:
        path: Location of the gzip-compressed VCF.
        genes: Gene symbols to restrict to; ``None`` or empty means no
            restriction.
        limit: Maximum number of IDs to return; scanning stops once reached.

    Returns:
        Matching ID-column values in file order. Empty when the file does not
        exist (a warning is logged) or when ``limit`` is not positive.
    """
    found: list[str] = []
    if not path.exists():
        logger.warning("ClinVar VCF not found at %s; falling back to NCBI search", path)
        return found
    if limit <= 0:
        # Without this guard the first match would be appended before the
        # limit check, returning one ID for a limit of zero.
        return found
    # Explicit encoding so decoding does not depend on the platform locale.
    with gzip.open(path, "rt", encoding="utf-8") as handle:
        for line in handle:
            if line.startswith("#"):
                continue
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 8:
                continue
            variation_id = fields[2]
            if variation_id == ".":
                continue
            info = _parse_info(fields[7])
            if info.get("CLNREVSTAT", "") not in EXPERT_REVIEW_STATUSES:
                continue
            if info.get("CLNSIG", "") not in VALID_CLASSIFICATIONS:
                continue
            if genes:
                # GENEINFO lists "symbol:id" pairs separated by "|"; a variant
                # in overlapping genes must match on any of them, not only
                # the first.
                symbols = {
                    pair.split(":", 1)[0]
                    for pair in info.get("GENEINFO", "").split("|")
                }
                if genes.isdisjoint(symbols):
                    continue
            found.append(variation_id)
            if len(found) >= limit:
                break
    return found
async def collect_from_local_vcf(genes: list[str], n_total: int) -> list[dict[str, Any]]:
    """Build evaluation rows from the local ClinVar VCF plus esummary lookups.

    IDs are selected from the VCF at ``settings.clinvar_vcf_path`` and their
    summaries are then requested in batches of 200.

    Args:
        genes: Gene symbols to restrict the VCF scan to; an empty list means
            no gene restriction.
        n_total: Maximum number of IDs taken from the VCF and of rows returned.

    Returns:
        At most ``n_total`` rows. Empty when the VCF yields no IDs or when
        every summary batch fails, which lets the caller fall back to the
        remote search.
    """
    settings = get_settings()
    gene_filter = set(genes) if genes else None
    # Path() is a no-op for a Path and makes a plain-string setting work too.
    vcf_path = Path(settings.clinvar_vcf_path)
    ids = _iter_local_expert_panel_ids(vcf_path, gene_filter, n_total)
    if not ids:
        return []
    logger.info("local ClinVar VCF: found %d expert-panel/practice-guideline IDs", len(ids))

    batch_size = 200
    out: list[dict[str, Any]] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for start in range(0, len(ids), batch_size):
            batch = ids[start:start + batch_size]
            try:
                rows = await _fetch_summary(client, settings, batch)
            except httpx.HTTPError as e:
                # Same policy as the per-gene path in collect(): log and keep
                # going rather than aborting the whole run on one bad batch.
                logger.warning("summary failed for local batch at offset %d: %s", start, e)
                continue
            out.extend(rows)
    return out[:n_total]
async def collect(
    genes: list[str],
    n_total: int,
    per_gene: int,
    restrict_local_to_genes: bool = False,
) -> list[dict[str, Any]]:
    """Collect up to ``n_total`` evaluation rows, preferring the local VCF.

    The local VCF is tried first; if it produces any rows they are returned
    as-is. Otherwise each gene is searched remotely, its summaries fetched,
    and rows are kept only when both the review status and the classification
    (spaces replaced by underscores) are in the accepted sets -- the same two
    criteria the local VCF scan applies.

    Args:
        genes: Gene symbols to query, in order.
        n_total: Maximum number of rows to return.
        per_gene: Maximum rows kept per gene on the remote path; three times
            this many IDs are requested to leave room for filtering.
        restrict_local_to_genes: When true the local VCF scan is limited to
            ``genes``; otherwise it is unrestricted.

    Returns:
        At most ``n_total`` row dicts; a variation ID appears at most once on
        the remote path. A gene whose search or summary request fails is
        logged and skipped.
    """
    settings = get_settings()
    local_rows = await collect_from_local_vcf(genes if restrict_local_to_genes else [], n_total)
    if local_rows:
        return local_rows

    # Pace requests to the documented E-utilities limits (3 req/s without an
    # API key, 10 req/s with one). A throttled request would otherwise surface
    # as an HTTP error below and silently drop that gene.
    pause = 0.11 if getattr(settings, "ncbi_api_key", None) else 0.34

    def _normalized(row: dict[str, Any], key: str) -> str:
        # Summary text uses spaces; the accepted sets use underscores.
        return (row.get(key) or "").replace(" ", "_")

    out: list[dict[str, Any]] = []
    seen_ids: set[Any] = set()
    async with httpx.AsyncClient(timeout=30.0) as client:
        for gene in genes:
            await asyncio.sleep(pause)
            try:
                ids = await _search_expert_panel(client, settings, gene, retmax=per_gene * 3)
            except httpx.HTTPError as e:
                logger.warning("search failed for %s: %s", gene, e)
                continue
            if not ids:
                continue

            await asyncio.sleep(pause)
            try:
                rows = await _fetch_summary(client, settings, ids[: per_gene * 3])
            except httpx.HTTPError as e:
                logger.warning("summary failed for %s: %s", gene, e)
                continue

            kept: list[dict[str, Any]] = []
            for row in rows:
                if len(kept) >= per_gene:
                    break
                variation_id = row.get("variation_id")
                if variation_id in seen_ids:
                    # Already collected under an earlier gene.
                    continue
                if _normalized(row, "review_status") not in EXPERT_REVIEW_STATUSES:
                    continue
                if _normalized(row, "expected_classification") not in VALID_CLASSIFICATIONS:
                    continue
                seen_ids.add(variation_id)
                kept.append(row)

            logger.info("%s: kept %d/%d", gene, len(kept), len(rows))
            out.extend(kept)
            if len(out) >= n_total:
                break
    return out[:n_total]
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: collect the evaluation set and write it as JSON.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read the process command line.

    Returns:
        ``0`` after the output file has been written, ``1`` when nothing was
        collected (the output file is then left untouched).

    Raises:
        SystemExit: raised by argparse for invalid options, including a
            non-positive ``--n`` or ``--per-gene``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=100, help="Total variants to collect")
    parser.add_argument("--per-gene", type=int, default=8, help="Cap per gene")
    parser.add_argument("--gene", action="append", help="Override default gene list (repeatable)")
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("backend/tests/fixtures/clinvar_validation_set.json"),
    )
    args = parser.parse_args(argv)
    # Both values are used as slice bounds downstream; zero or negative
    # numbers would truncate results in surprising ways.
    if args.n < 1:
        parser.error("--n must be a positive integer")
    if args.per_gene < 1:
        parser.error("--per-gene must be a positive integer")

    genes = args.gene or DEFAULT_GENES
    rows = asyncio.run(collect(genes, args.n, args.per_gene, restrict_local_to_genes=bool(args.gene)))
    if not rows:
        # Do not replace an existing fixture with an empty list just because
        # this run found nothing (e.g. every request failed).
        logger.error("no variants collected; leaving %s untouched", args.out)
        return 1

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(rows, indent=2) + "\n", encoding="utf-8")
    logger.info("wrote %d entries to %s", len(rows), args.out)
    return 0


if __name__ == "__main__":
    sys.exit(main())