Spaces:
Sleeping
Sleeping
"""Pull a held-out evaluation set of expert-reviewed ClinVar variants.

Variants are kept when their ClinVar review status is "reviewed by expert
panel" (3 stars) or "practice guideline" (4 stars).

Used by `backend/tests/test_known_variants.py` to measure end-to-end
classification concordance vs. the ClinVar gold standard.

NCBI rate limit: 3 req/s without key, 10 req/s with NCBI_API_KEY in `.env`.

Usage
-----
    python -m scripts.seed_eval_set --n 100 --gene BRCA1 --gene TSC2 --gene MLH1

By default writes to `backend/tests/fixtures/clinvar_validation_set.json`.
"""
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import gzip | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
| import httpx | |
| from backend.app.config import get_settings | |
# Configure root logging at INFO and raise the "httpx" logger to WARNING.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("seed_eval_set")
logging.getLogger("httpx").setLevel(logging.WARNING)

# Base URL shared by the esearch / esummary requests below.
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

# Genes searched when the command line gives no --gene option.
DEFAULT_GENES = [
    "BRCA1",
    "BRCA2",
    "TSC2",
    "MLH1",
    "MSH2",
    "MSH6",
    "PMS2",
    "PALB2",
    "ATM",
    "CHEK2",
    "TP53",
    "PTEN",
    "CDH1",
    "RB1",
    "VHL",
]

# Review statuses (underscore-separated form) accepted for the evaluation set.
EXPERT_REVIEW_STATUSES = {
    "reviewed_by_expert_panel",
    "practice_guideline",
}

# CLNSIG values accepted when scanning the local ClinVar VCF.
VALID_CLASSIFICATIONS = {
    "Pathogenic",
    "Likely_pathogenic",
    "Uncertain_significance",
    "Likely_benign",
    "Benign",
}
| def _params(settings: Any, **extra: Any) -> dict[str, Any]: | |
| p = {"tool": "VariantLens-eval", "email": settings.ncbi_email} | |
| if settings.ncbi_api_key: | |
| p["api_key"] = settings.ncbi_api_key | |
| return {**p, **extra} | |
async def _search_expert_panel(
    client: httpx.AsyncClient,
    settings: Any,
    gene: str,
    retmax: int,
) -> list[str]:
    """Run an esearch against the ``clinvar`` database for one gene.

    The search term is ``GENE[Gene Name]`` only; the query itself carries no
    review-status filter.

    Args:
        client: HTTP client used for the GET request.
        settings: Passed through to ``_params`` for the common parameters.
        gene: Gene symbol placed in the search term.
        retmax: Forwarded as the ``retmax`` request parameter.

    Returns:
        The entries of ``esearchresult.idlist`` from the JSON response, each
        converted to ``str``; an empty list when either key is absent.

    Raises:
        Whatever ``raise_for_status`` raises for an error response.
    """
    query = _params(
        settings,
        db="clinvar",
        term=f"{gene}[Gene Name]",
        retmax=retmax,
        retmode="json",
    )
    response = await client.get(f"{EUTILS}/esearch.fcgi", params=query)
    response.raise_for_status()
    search_result = response.json().get("esearchresult", {})
    return list(map(str, search_result.get("idlist", [])))
| async def _fetch_summary(client: httpx.AsyncClient, settings: Any, ids: list[str]) -> list[dict[str, Any]]: | |
| if not ids: | |
| return [] | |
| r = await client.get( | |
| f"{EUTILS}/esummary.fcgi", | |
| params=_params(settings, db="clinvar", id=",".join(ids), retmode="json"), | |
| ) | |
| r.raise_for_status() | |
| payload = r.json().get("result", {}) | |
| out: list[dict[str, Any]] = [] | |
| for vid in ids: | |
| item = payload.get(vid) | |
| if not item: | |
| continue | |
| cls = (item.get("germline_classification", {}) or {}).get("description") or item.get("clinical_significance", {}).get("description") | |
| title = item.get("title") or "" | |
| review = (item.get("germline_classification", {}) or {}).get("review_status", "") or item.get("clinical_significance", {}).get("review_status", "") | |
| if not cls: | |
| continue | |
| # Prefer entries with explicit canonical SPDI / HGVS in the title. | |
| out.append({ | |
| "variation_id": vid, | |
| "title": title, | |
| "expected_classification": cls, | |
| "review_status": review, | |
| "gene": item.get("genes", [{}])[0].get("symbol", "") if item.get("genes") else "", | |
| }) | |
| return out | |
| def _parse_info(info: str) -> dict[str, str]: | |
| parsed: dict[str, str] = {} | |
| for item in info.split(";"): | |
| if "=" not in item: | |
| continue | |
| key, value = item.split("=", 1) | |
| parsed[key] = value | |
| return parsed | |
| def _iter_local_expert_panel_ids(path: Path, genes: set[str] | None, limit: int) -> list[str]: | |
| out: list[str] = [] | |
| if not path.exists(): | |
| logger.warning("ClinVar VCF not found at %s; falling back to NCBI search", path) | |
| return out | |
| with gzip.open(path, "rt") as handle: | |
| for line in handle: | |
| if line.startswith("#"): | |
| continue | |
| fields = line.rstrip("\n").split("\t") | |
| if len(fields) < 8: | |
| continue | |
| info = _parse_info(fields[7]) | |
| review_status = info.get("CLNREVSTAT", "") | |
| classification = info.get("CLNSIG", "") | |
| geneinfo = info.get("GENEINFO", "") | |
| gene = geneinfo.split(":", 1)[0] if geneinfo else "" | |
| if genes and gene not in genes: | |
| continue | |
| if review_status not in EXPERT_REVIEW_STATUSES: | |
| continue | |
| if classification not in VALID_CLASSIFICATIONS: | |
| continue | |
| variation_id = fields[2] | |
| if variation_id == ".": | |
| continue | |
| out.append(variation_id) | |
| if len(out) >= limit: | |
| break | |
| return out | |
async def collect_from_local_vcf(genes: list[str], n_total: int) -> list[dict[str, Any]]:
    """Build evaluation rows from IDs found in the local ClinVar VCF.

    IDs come from ``_iter_local_expert_panel_ids`` using
    ``settings.clinvar_vcf_path``; their summaries are then fetched through
    ``_fetch_summary`` in batches of 200 IDs.

    Args:
        genes: Gene symbols to restrict the VCF scan to; an empty list
            disables the gene filter.
        n_total: Maximum number of IDs to scan for and rows to return.

    Returns:
        At most ``n_total`` summary rows, or ``[]`` when the scan yields no IDs.
    """
    settings = get_settings()
    gene_filter = set(genes) if genes else None
    ids = _iter_local_expert_panel_ids(settings.clinvar_vcf_path, gene_filter, n_total)
    if not ids:
        return []
    logger.info("local ClinVar VCF: found %d expert-panel/practice-guideline IDs", len(ids))

    batch_size = 200
    batches = [ids[start:start + batch_size] for start in range(0, len(ids), batch_size)]

    rows: list[dict[str, Any]] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for batch in batches:
            rows.extend(await _fetch_summary(client, settings, batch))
    return rows[:n_total]
async def collect(
    genes: list[str],
    n_total: int,
    per_gene: int,
    restrict_local_to_genes: bool = False,
) -> list[dict[str, Any]]:
    """Collect evaluation rows, preferring the local VCF over NCBI search.

    The local VCF is tried first; any rows it yields are returned as-is.
    Otherwise each gene is searched through E-utilities, and rows whose
    review status (spaces replaced by underscores) is in
    ``EXPERT_REVIEW_STATUSES`` are kept, capped at ``per_gene`` per gene.

    Args:
        genes: Gene symbols to search, in order.
        n_total: Maximum number of rows to return.
        per_gene: Maximum rows kept per gene; three times this many IDs are
            requested per gene before filtering.
        restrict_local_to_genes: When true, the local VCF scan is limited to
            ``genes``; otherwise it runs without a gene filter.

    Returns:
        At most ``n_total`` rows. A gene whose search or summary request
        raises ``httpx.HTTPError`` is logged and skipped.
    """
    settings = get_settings()
    local_genes = genes if restrict_local_to_genes else []
    local_rows = await collect_from_local_vcf(local_genes, n_total)
    if local_rows:
        return local_rows

    candidates_per_gene = per_gene * 3
    collected: list[dict[str, Any]] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for gene in genes:
            try:
                ids = await _search_expert_panel(
                    client, settings, gene, retmax=candidates_per_gene
                )
            except httpx.HTTPError as exc:
                logger.warning("search failed for %s: %s", gene, exc)
                continue
            if not ids:
                continue

            try:
                rows = await _fetch_summary(client, settings, ids[:candidates_per_gene])
            except httpx.HTTPError as exc:
                logger.warning("summary failed for %s: %s", gene, exc)
                continue

            expert_rows = []
            for row in rows:
                status = row.get("review_status", "").replace(" ", "_")
                if status in EXPERT_REVIEW_STATUSES:
                    expert_rows.append(row)
            kept = expert_rows[:per_gene]

            logger.info("%s: kept %d/%d", gene, len(kept), len(rows))
            collected.extend(kept)
            if len(collected) >= n_total:
                break
    return collected[:n_total]
def main() -> int:
    """Parse the command line, collect the rows and write them as JSON.

    Options: ``--n`` (total rows, default 100), ``--per-gene`` (cap per gene,
    default 8), ``--gene`` (repeatable; replaces ``DEFAULT_GENES``) and
    ``--out`` (output path). Giving ``--gene`` also restricts the local VCF
    scan to those genes. The output's parent directory is created if needed.

    Returns:
        ``0`` once the file has been written.
    """
    option_specs = (
        ("--n", {"type": int, "default": 100, "help": "Total variants to collect"}),
        ("--per-gene", {"type": int, "default": 8, "help": "Cap per gene"}),
        ("--gene", {"action": "append", "help": "Override default gene list (repeatable)"}),
        (
            "--out",
            {
                "type": Path,
                "default": Path("backend/tests/fixtures/clinvar_validation_set.json"),
            },
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    requested_genes = args.gene
    genes = requested_genes or DEFAULT_GENES
    rows = asyncio.run(
        collect(
            genes,
            args.n,
            args.per_gene,
            restrict_local_to_genes=bool(requested_genes),
        )
    )

    destination = args.out
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(rows, indent=2) + "\n")
    logger.info("wrote %d entries to %s", len(rows), destination)
    return 0
# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())