"""Cache-first, rate-limited gnomAD batch validator.

Goal
----
Verify that every variant in a batch has a confident gnomAD lookup before
classification — either from the on-disk SQLite cache (fast, deterministic)
or freshly from the gnomAD GraphQL API. Reports:

  * cache hit rate
  * variants that 404'd in gnomAD (legitimately absent → PM2)
  * variants where the API returned an error (suspicious — needs retry)
  * variants where gnomAD's coverage warning is set (low coverage region)

The script uses an asyncio Semaphore to cap the number of in-flight API
requests (default: 2, chosen to stay within gnomAD's polite-use guidance of
roughly 2 req/s). Note that the semaphore bounds concurrency, not requests
per second. Running an overnight pre-warm against the full week's variant
pull keeps the critical-path classification under 30s per variant in the
lab workflow.

Usage
-----
    python -m scripts.gnomad_batch_validate variants.tsv
    python -m scripts.gnomad_batch_validate --concurrency 4 --out report.json variants.tsv

Input format (one variant per line, # comments allowed):

    chr1-12345-A-G
    NM_007294.4:c.5266dupC
    17-43124017-C-A

Only VCF-style `chrom-pos-ref-alt` lines are looked up. Other notations
(such as the HGVS line above) are reported as errors asking for pipeline
normalization first.
"""
from __future__ import annotations

import argparse
import asyncio
import contextlib
import json
import logging
import re
import sqlite3
import time
from collections import Counter
from dataclasses import asdict, dataclass
from pathlib import Path
# Script-level logging: timestamped records at INFO and above on the root logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Only WARNING and above from the `httpx` logger are shown.
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger("gnomad_batch")

# VCF-style `chrom-pos-ref-alt` with `-` or `:` separators and an optional
# `chr` prefix, matched case-insensitively. The contig is restricted to
# 1-22, X, Y and M so that a mistyped chromosome (e.g. `chr0`, `chr99`, `XY`)
# is rejected as malformed input instead of being looked up and then reported
# as absent from gnomAD. Group 1 is still the optional `chr` prefix.
VCF_PATTERN = re.compile(
    r"^(chr)?(?:[1-9]|1\d|2[0-2]|X|Y|M)[-:]\d+[-:][ACGT]+[-:][ACGT]+$",
    re.IGNORECASE,
)
@dataclass
class GnomadCheckResult:
    """Outcome of checking a single input line against the cache / gnomAD API.

    Declared as a dataclass so it can be built with keyword arguments and
    serialized with ``dataclasses.asdict`` for the JSON report.
    """

    input_line: str  # the line exactly as it was passed in
    variant_id: str | None  # normalized gnomAD ID; None for non-VCF input
    cache_hit: bool  # True when the value came from the SQLite cache
    api_called: bool  # True when an API lookup was attempted
    overall_af: float | None
    homozygote_count: int | None
    coverage_warning: str | None
    elapsed_s: float  # seconds spent on this variant, rounded to 3 places
    error: str | None  # None on success, otherwise a description of the failure
def _to_variant_id(raw: str) -> str | None:
    """Normalize input to gnomAD ID `chr-pos-ref-alt` (no `chr` prefix).

    The chromosome and alleles are upper-cased so that inputs differing only
    in case (``chr1-5-a-g``, ``CHR1:5:A:G``) map to the same ID.

    Falls back to None for non-VCF inputs — caller can route those
    through the full pipeline normalization first.
    """
    s = raw.strip()
    if not VCF_PATTERN.match(s):
        return None
    # VCF_PATTERN matches case-insensitively, so normalize case *before*
    # stripping the prefix; a case-sensitive replace("chr", "") would leave
    # "CHR1" / "Chr1" untouched.
    chrom, pos, ref, alt = re.split(r"[-:]", s.upper())
    if chrom.startswith("CHR"):
        chrom = chrom[3:]
    return f"{chrom}-{pos}-{ref}-{alt}"
async def _check_one(
    raw: str,
    client,  # GnomADClient (lazy import below)
    sem: asyncio.Semaphore,
    cache_db: Path,
) -> GnomadCheckResult:
    """Resolve one input line from the SQLite cache or, failing that, the API.

    Args:
        raw: Input line as read from the variant file.
        client: Object exposing an awaitable ``lookup(variant_id)`` whose
            result has ``overall_af``, ``homozygote_count`` and
            ``coverage_warning`` attributes.
        sem: Semaphore bounding how many API lookups run at the same time.
        cache_db: Path of the SQLite cache; probed only if the file exists.

    Returns:
        A ``GnomadCheckResult``. Non-VCF input and exceptions raised by the
        API lookup are reported through its ``error`` field. SQLite errors
        from the cache probe are not caught here and propagate to the caller.
    """
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()

    def _elapsed() -> float:
        return round(time.perf_counter() - started, 3)

    variant_id = _to_variant_id(raw)
    if not variant_id:
        return GnomadCheckResult(
            input_line=raw, variant_id=None, cache_hit=False, api_called=False,
            overall_af=None, homozygote_count=None, coverage_warning=None,
            elapsed_s=_elapsed(),
            error="non-VCF input — route through pipeline normalize first",
        )

    # Cache probe up front
    if cache_db.exists():
        # sqlite3's own context manager only commits/rolls back; closing()
        # makes sure the connection is released after every probe.
        with contextlib.closing(sqlite3.connect(cache_db)) as conn:
            row = conn.execute(
                "SELECT af, homozygotes, coverage_warning FROM gnomad_cache WHERE variant_id = ?",
                (variant_id,),
            ).fetchone()
        if row:
            af, hom, cov = row
            return GnomadCheckResult(
                input_line=raw, variant_id=variant_id, cache_hit=True, api_called=False,
                overall_af=af, homozygote_count=hom, coverage_warning=cov,
                elapsed_s=_elapsed(), error=None,
            )

    # API call; the semaphore caps the number of concurrent lookups
    async with sem:
        try:
            freq = await client.lookup(variant_id)
            return GnomadCheckResult(
                input_line=raw, variant_id=variant_id, cache_hit=False, api_called=True,
                overall_af=freq.overall_af, homozygote_count=freq.homozygote_count,
                coverage_warning=freq.coverage_warning,
                elapsed_s=_elapsed(), error=None,
            )
        except Exception as e:
            # Some exceptions (e.g. timeouts) stringify to "", which would be
            # falsy and make the failure count as a success in the summary.
            # Fall back to the exception type name so `error` is never empty.
            return GnomadCheckResult(
                input_line=raw, variant_id=variant_id, cache_hit=False, api_called=True,
                overall_af=None, homozygote_count=None, coverage_warning=None,
                elapsed_s=_elapsed(), error=str(e) or type(e).__name__,
            )
async def run(inputs: list[str], concurrency: int) -> dict:
    """Check every input line and build the batch report.

    Args:
        inputs: Variant strings, one per entry.
        concurrency: Maximum number of API lookups in flight; must be >= 1.

    Returns:
        A dict with a ``"summary"`` of aggregate counts and ``"results"``,
        the per-variant results as plain dicts in input order.

    Raises:
        ValueError: If ``concurrency`` is less than 1.
    """
    # Semaphore(0) can never be acquired, so the batch would hang forever as
    # soon as one variant needs an API lookup. Fail fast instead.
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    # Project imports are deferred to call time.
    from backend.app.config import get_settings
    from backend.app.services.gnomad import GnomADClient

    settings = get_settings()
    client = GnomADClient()
    sem = asyncio.Semaphore(concurrency)
    started = time.time()
    results = await asyncio.gather(*(
        _check_one(raw, client, sem, settings.gnomad_cache_db) for raw in inputs
    ))

    total = len(results)
    cached = sum(1 for r in results if r.cache_hit)
    api = sum(1 for r in results if r.api_called)
    errors = sum(1 for r in results if r.error)
    absent = sum(1 for r in results if r.coverage_warning == "absent from gnomAD")
    low_cov = sum(
        1 for r in results
        if r.coverage_warning and "low coverage" in r.coverage_warning
    )
    return {
        "summary": {
            "total": total,
            "cache_hits": cached,
            # Guard against division by zero for an empty batch.
            "cache_hit_rate": round(cached / total, 3) if total else 0,
            "api_calls": api,
            "errors": errors,
            "absent_in_gnomad": absent,
            "low_coverage": low_cov,
            "elapsed_s": round(time.time() - started, 1),
            "concurrency": concurrency,
        },
        "results": [asdict(r) for r in results],
    }
def main() -> int:
    """Command-line entry point: read variants, run the batch, write the report.

    Returns:
        0 when no result carries an error, 1 when the input file does not
        exist, 2 when at least one result carries an error. An invalid
        ``--concurrency`` exits through ``argparse`` (status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=Path, help="One variant per line (chr-pos-ref-alt)")
    parser.add_argument("--concurrency", type=int, default=2, help="Max parallel requests (gnomAD polite rate ~2/s)")
    parser.add_argument("--out", type=Path, default=Path("docs/gnomad_batch_report.json"))
    args = parser.parse_args()

    # A semaphore of 0 would never be acquirable and the run would hang.
    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")

    if not args.input.exists():
        logger.error("input file not found: %s", args.input)
        return 1

    # utf-8-sig reads plain UTF-8 and also drops a leading BOM, which would
    # otherwise stick to the first variant and make it fail to parse.
    stripped = (ln.strip() for ln in args.input.read_text(encoding="utf-8-sig").splitlines())
    # Test the *stripped* line so that indented `# ...` comments are skipped too.
    raw_lines = [ln for ln in stripped if ln and not ln.startswith("#")]

    logger.info("validating %d variants with concurrency=%d", len(raw_lines), args.concurrency)
    report = asyncio.run(run(raw_lines, args.concurrency))

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(report, indent=2) + "\n", encoding="utf-8")

    s = report["summary"]
    logger.info(
        "done — %d cached / %d API / %d errors / %d absent in %.1fs (%.1f%% cache hit)",
        s["cache_hits"], s["api_calls"], s["errors"], s["absent_in_gnomad"],
        s["elapsed_s"], s["cache_hit_rate"] * 100,
    )

    # Group errors by their first 60 characters and show the five most common.
    counter: Counter[str] = Counter()
    for r in report["results"]:
        if r["error"]:
            counter[r["error"][:60]] += 1
    if counter:
        logger.warning("error breakdown:")
        for k, n in counter.most_common(5):
            logger.warning("  %d × %s", n, k)

    return 0 if report["summary"]["errors"] == 0 else 2
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())