File size: 7,462 Bytes
3e219fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Pull a held-out evaluation set of ClinVar 4-star expert-panel variants.

Used by `backend/tests/test_known_variants.py` to measure end-to-end
classification concordance vs. the ClinVar gold standard.

NCBI rate limit: 3 req/s without key, 10 req/s with NCBI_API_KEY in `.env`.

Usage
-----
    python -m scripts.seed_eval_set --n 100 --gene BRCA1 --gene TSC2 --gene MLH1

By default writes to `backend/tests/fixtures/clinvar_validation_set.json`.
"""
from __future__ import annotations

import argparse
import asyncio
import gzip
import json
import logging
import sys
from pathlib import Path
from typing import Any

import httpx

from backend.app.config import get_settings

# Root logging: INFO and above, timestamped, for this one-shot script.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# httpx logs every request at INFO; only let its warnings/errors through.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger("seed_eval_set")

EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

DEFAULT_GENES = [
    "BRCA1", "BRCA2", "TSC2", "MLH1", "MSH2", "MSH6", "PMS2",
    "PALB2", "ATM", "CHEK2", "TP53", "PTEN", "CDH1", "RB1", "VHL",
]

EXPERT_REVIEW_STATUSES = {"reviewed_by_expert_panel", "practice_guideline"}
VALID_CLASSIFICATIONS = {
    "Pathogenic",
    "Likely_pathogenic",
    "Uncertain_significance",
    "Likely_benign",
    "Benign",
}


def _params(settings: Any, **extra: Any) -> dict[str, Any]:
    p = {"tool": "VariantLens-eval", "email": settings.ncbi_email}
    if settings.ncbi_api_key:
        p["api_key"] = settings.ncbi_api_key
    return {**p, **extra}


async def _search_expert_panel(
    client: httpx.AsyncClient,
    settings: Any,
    gene: str,
    retmax: int,
) -> list[str]:
    """Return up to *retmax* ClinVar variation IDs for *gene* via esearch.

    Despite the name, the search term is only ``<gene>[Gene Name]``. It does
    NOT restrict by review status, so callers must filter the fetched
    summaries themselves (``collect`` checks them against
    ``EXPERT_REVIEW_STATUSES``).

    Raises:
        httpx errors from the request; ``raise_for_status`` raises on a
        non-2xx response.
    """
    query = _params(
        settings,
        db="clinvar",
        term=f"{gene}[Gene Name]",
        retmax=retmax,
        retmode="json",
    )
    response = await client.get(f"{EUTILS}/esearch.fcgi", params=query)
    response.raise_for_status()
    search_result = response.json().get("esearchresult", {})
    return list(map(str, search_result.get("idlist", [])))


async def _fetch_summary(client: httpx.AsyncClient, settings: Any, ids: list[str]) -> list[dict[str, Any]]:
    if not ids:
        return []
    r = await client.get(
        f"{EUTILS}/esummary.fcgi",
        params=_params(settings, db="clinvar", id=",".join(ids), retmode="json"),
    )
    r.raise_for_status()
    payload = r.json().get("result", {})
    out: list[dict[str, Any]] = []
    for vid in ids:
        item = payload.get(vid)
        if not item:
            continue
        cls = (item.get("germline_classification", {}) or {}).get("description") or item.get("clinical_significance", {}).get("description")
        title = item.get("title") or ""
        review = (item.get("germline_classification", {}) or {}).get("review_status", "") or item.get("clinical_significance", {}).get("review_status", "")
        if not cls:
            continue
        # Prefer entries with explicit canonical SPDI / HGVS in the title.
        out.append({
            "variation_id": vid,
            "title": title,
            "expected_classification": cls,
            "review_status": review,
            "gene": item.get("genes", [{}])[0].get("symbol", "") if item.get("genes") else "",
        })
    return out


def _parse_info(info: str) -> dict[str, str]:
    parsed: dict[str, str] = {}
    for item in info.split(";"):
        if "=" not in item:
            continue
        key, value = item.split("=", 1)
        parsed[key] = value
    return parsed


def _iter_local_expert_panel_ids(path: Path, genes: set[str] | None, limit: int) -> list[str]:
    out: list[str] = []
    if not path.exists():
        logger.warning("ClinVar VCF not found at %s; falling back to NCBI search", path)
        return out

    with gzip.open(path, "rt") as handle:
        for line in handle:
            if line.startswith("#"):
                continue
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 8:
                continue
            info = _parse_info(fields[7])
            review_status = info.get("CLNREVSTAT", "")
            classification = info.get("CLNSIG", "")
            geneinfo = info.get("GENEINFO", "")
            gene = geneinfo.split(":", 1)[0] if geneinfo else ""
            if genes and gene not in genes:
                continue
            if review_status not in EXPERT_REVIEW_STATUSES:
                continue
            if classification not in VALID_CLASSIFICATIONS:
                continue
            variation_id = fields[2]
            if variation_id == ".":
                continue
            out.append(variation_id)
            if len(out) >= limit:
                break
    return out


async def collect_from_local_vcf(
    genes: list[str],
    n_total: int,
    batch_size: int = 200,
) -> list[dict[str, Any]]:
    """Build eval rows from expert-reviewed IDs in the local ClinVar VCF.

    The VCF path comes from ``get_settings().clinvar_vcf_path``. Matching IDs
    (restricted to *genes* when that list is non-empty) are resolved to
    summary rows via esummary in chunks of *batch_size*.

    A failing batch is logged and skipped rather than aborting the whole run.
    This matches how ``collect`` treats per-gene NCBI failures.

    Returns:
        At most *n_total* rows. Returns [] when the VCF is missing, has no
        matches, or every batch failed; ``collect`` then falls back to NCBI search.
    """
    settings = get_settings()
    ids = _iter_local_expert_panel_ids(settings.clinvar_vcf_path, set(genes) if genes else None, n_total)
    if not ids:
        return []
    logger.info("local ClinVar VCF: found %d expert-panel/practice-guideline IDs", len(ids))
    step = max(1, batch_size)
    out: list[dict[str, Any]] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for start in range(0, len(ids), step):
            batch = ids[start: start + step]
            try:
                rows = await _fetch_summary(client, settings, batch)
            except httpx.HTTPError as e:
                logger.warning(
                    "summary failed for local batch %d-%d: %s", start, start + len(batch), e
                )
                continue
            out.extend(rows)
    return out[:n_total]


async def collect(
    genes: list[str],
    n_total: int,
    per_gene: int,
    restrict_local_to_genes: bool = False,
) -> list[dict[str, Any]]:
    """Assemble up to *n_total* expert-reviewed ClinVar rows.

    The local ClinVar VCF is tried first via ``collect_from_local_vcf``. That
    scan is limited to *genes* only when *restrict_local_to_genes* is true. If
    it yields nothing, the function falls back to NCBI:

    * per gene, fetch up to ``per_gene * 3`` candidate IDs with esearch and
      summarize them with esummary;
    * keep at most *per_gene* rows whose review status is in
      ``EXPERT_REVIEW_STATUSES``;
    * stop once *n_total* rows are collected.

    Per-gene HTTP failures are logged and skipped.

    Returns:
        At most *n_total* rows; [] immediately when ``n_total <= 0``.
    """
    if n_total <= 0:
        return []
    settings = get_settings()
    local_rows = await collect_from_local_vcf(genes if restrict_local_to_genes else [], n_total)
    if local_rows:
        return local_rows

    # The E-utilities calls below run back-to-back. Pause before each one to
    # stay under NCBI's limit: 3 req/s anonymous, 10 req/s with an API key.
    pause = 0.1 if settings.ncbi_api_key else 0.34
    # Over-fetch candidates because the esearch term does not filter by
    # review status.
    n_candidates = per_gene * 3
    out: list[dict[str, Any]] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for gene in genes:
            await asyncio.sleep(pause)
            try:
                ids = await _search_expert_panel(client, settings, gene, retmax=n_candidates)
            except httpx.HTTPError as e:
                logger.warning("search failed for %s: %s", gene, e)
                continue
            if not ids:
                continue
            await asyncio.sleep(pause)
            try:
                rows = await _fetch_summary(client, settings, ids[:n_candidates])
            except httpx.HTTPError as e:
                logger.warning("summary failed for %s: %s", gene, e)
                continue
            # esummary spells statuses with spaces; the set uses underscores.
            kept = [
                row for row in rows
                if (row.get("review_status") or "").replace(" ", "_") in EXPERT_REVIEW_STATUSES
            ][:per_gene]
            logger.info("%s: kept %d/%d", gene, len(kept), len(rows))
            out.extend(kept)
            if len(out) >= n_total:
                break
    return out[:n_total]


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=100, help="Total variants to collect")
    parser.add_argument("--per-gene", type=int, default=8, help="Cap per gene")
    parser.add_argument("--gene", action="append", help="Override default gene list (repeatable)")
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("backend/tests/fixtures/clinvar_validation_set.json"),
    )
    return parser


def main() -> int:
    """CLI entry point: collect the eval set and write it to ``--out`` as JSON.

    Without ``--gene``, ``DEFAULT_GENES`` is used and the local-VCF scan is not
    gene-restricted. With ``--gene``, the local scan is restricted to those genes.
    Always returns exit status 0.
    """
    args = _build_parser().parse_args()

    explicit_genes = args.gene
    genes = explicit_genes or DEFAULT_GENES
    rows = asyncio.run(
        collect(genes, args.n, args.per_gene, restrict_local_to_genes=bool(explicit_genes))
    )

    destination: Path = args.out
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(rows, indent=2) + "\n")
    logger.info("wrote %d entries to %s", len(rows), destination)
    return 0


if __name__ == "__main__":
    sys.exit(main())