# microbe-model/scripts/modal_kofam.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""KOfam scan on Modal CPU — fast parallel HMMER scan against KEGG-relevant KOs.
Same architecture as scripts/28_kofam_scan.py but each Modal container has its
own NCBI fetch IP (bypasses the rate limit that bottlenecks local parallelism)
and its own CPU resources. The 700 MB KOfam relevant library is bundled into
the Modal image so containers don't redownload.
Usage:
modal setup # one-time
modal run scripts/modal_kofam.py --limit 10 # smoke test
modal run scripts/modal_kofam.py # full corpus
Cost: ~$2-4 for 22K genomes at 4-CPU containers, 16-way parallel.
Resumes from data/kofam_hits.jsonl — already-processed accessions are skipped.
"""
from __future__ import annotations
import json
from pathlib import Path
import modal
# Container image used by the Modal app below: Debian slim with Python 3.11,
# the pip packages listed here, and the two local KOfam data files placed under
# /root (the HMM library and the per-KO threshold table read in
# KofamScanner.setup).
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install([
        "pyrodigal>=3.5",
        "pyhmmer>=0.12",
        "biopython>=1.83",
        "requests>=2.32",
    ])
    # NOTE(review): copy=True presumably bakes the files into an image layer
    # so containers do not fetch them at start-up — confirm against Modal docs.
    .add_local_file("data/kofam/kofam_relevant.hmm", "/root/kofam.hmm", copy=True)
    .add_local_file("data/kofam/ko_thresholds.tsv", "/root/ko_thresholds.tsv", copy=True)
)
# Modal application object; `image` is passed as the app-level image.
app = modal.App("microbe-kofam", image=image)
# URL template for the NCBI Datasets v2 genome download endpoint; `{acc}` is
# filled with one candidate accession per request.
DATASETS_URL = "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"
# Version suffixes tried, in this order, when an accession has no ".N" suffix.
VERSION_FALLBACKS = (".1", ".2", ".3", ".4")
# Responses with fewer bytes than this are treated as carrying no genome data.
EMPTY_ZIP_BYTES = 2_000
# E-value cutoff passed to hmmsearch and re-checked on every hit.
DEFAULT_EVALUE = 1e-5
def _has_version(accession: str) -> bool:
if "." not in accession:
return False
return accession.rsplit(".", 1)[-1].isdigit()
def _candidate_accessions(accession: str) -> list[str]:
    """List the accession strings to try when downloading *accession*.

    A versioned accession is returned unchanged as a single-element list.
    An unversioned one is expanded with every suffix in ``VERSION_FALLBACKS``,
    in that order.
    """
    if not _has_version(accession):
        return [f"{accession}{suffix}" for suffix in VERSION_FALLBACKS]
    return [accession]
def _download_genome_zip(
    candidate: str,
    headers: dict[str, str],
    params: dict[str, str],
    rate: float,
    max_attempts: int = 3,
) -> bytes | None:
    """Download the Datasets zip for one candidate accession, with retries.

    Returns the response body, or None when the candidate should be given up
    on: HTTP 404, a body shorter than ``EMPTY_ZIP_BYTES``, or every attempt
    failing (429/502/503 or a ``requests`` exception).
    """
    import time

    import requests

    url = DATASETS_URL.format(acc=candidate)
    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        # Client-side throttle applied before every request, retries included.
        time.sleep(rate)
        try:
            resp = requests.get(url, params=params, headers=headers, timeout=120)
            if resp.status_code == 404:
                return None
            if resp.status_code in (429, 502, 503):
                # Exponential backoff (1 s, 2 s, ...); pointless after the
                # final attempt, so it is skipped there.
                if not is_last_attempt:
                    time.sleep(2 ** attempt)
                continue
            resp.raise_for_status()
        except requests.RequestException:
            if is_last_attempt:
                return None
            time.sleep(2 ** attempt)
            continue
        if len(resp.content) < EMPTY_ZIP_BYTES:
            return None
        return resp.content
    return None


def _read_first_fasta(zip_bytes: bytes) -> bytes | None:
    """Return the bytes of the first ``.fna`` member of a zip archive.

    Returns None when the archive is not a valid zip or has no ``.fna`` member.
    """
    import io
    import zipfile

    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
            fasta_names = [n for n in zf.namelist() if n.endswith(".fna")]
            if not fasta_names:
                return None
            return zf.read(fasta_names[0])
    except zipfile.BadZipFile:
        return None


def _fetch_fasta_bytes(accession: str, ncbi_key: str | None) -> list[tuple[str, str]] | None:
    """Fetch a genome from NCBI Datasets and return its parsed contigs.

    Despite the name, the return value is the parsed FASTA — a list of
    ``(contig_id, sequence)`` tuples from ``_parse_fasta`` — not raw bytes.

    Args:
        accession: Genome accession, with or without a ``.N`` version suffix.
            Unversioned accessions are tried with each fallback suffix in turn.
        ncbi_key: Optional NCBI API key. When set it is sent in the
            ``api-key`` header and the per-request delay drops from 0.34 s
            to 0.1 s.

    Returns:
        The contigs of the first candidate that yields a readable ``.fna``
        file, or None when no candidate does.
    """
    # NOTE(review): delays presumably chosen to stay under NCBI's per-second
    # request limits with and without an API key — confirm.
    rate = 0.1 if ncbi_key else 0.34
    headers = {"Accept": "application/zip"}
    if ncbi_key:
        headers["api-key"] = ncbi_key
    params = {"include_annotation_type": "GENOME_FASTA"}

    for candidate in _candidate_accessions(accession):
        zip_bytes = _download_genome_zip(candidate, headers, params, rate)
        if zip_bytes is None:
            continue
        raw = _read_first_fasta(zip_bytes)
        if raw is None:
            continue
        return _parse_fasta(raw)
    return None
def _parse_fasta(raw: bytes) -> list[tuple[str, str]]:
contigs: list[tuple[str, str]] = []
current_id: str | None = None
chunks: list[str] = []
for line in raw.splitlines():
if not line:
continue
if line.startswith(b">"):
if current_id is not None:
contigs.append((current_id, "".join(chunks).upper()))
current_id = line[1:].decode("ascii", errors="replace").split()[0]
chunks = []
else:
chunks.append(line.decode("ascii", errors="replace"))
if current_id is not None:
contigs.append((current_id, "".join(chunks).upper()))
return contigs
def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Predict genes on *contigs* with pyrodigal and return protein strings.

    When the contigs total at least 20,000 nt, a single-genome finder is
    trained on all contigs joined by the spacer ``TTAATTAATTAA``. If training
    raises, or the input is shorter than that, a metagenomic-mode finder is
    used instead.

    Args:
        contigs: ``(contig_id, sequence)`` tuples; sequences must be ASCII.

    Returns:
        One translated protein per predicted gene, in contig order, with
        trailing ``*`` characters removed.
    """
    import pyrodigal

    sequences = [dna.encode("ascii") for _, dna in contigs]

    finder = None
    if sum(map(len, sequences)) >= 20_000:
        trained = pyrodigal.GeneFinder(meta=False)
        try:
            # The spacer has a stop codon in every reading frame, presumably
            # so no gene is predicted across a contig join.
            trained.train(b"TTAATTAATTAA".join(sequences))
        except Exception:
            finder = None  # training failed: fall through to meta mode
        else:
            finder = trained
    if finder is None:
        finder = pyrodigal.GeneFinder(meta=True)

    return [
        gene.translate().rstrip("*")
        for dna in sequences
        for gene in finder.find_genes(dna)
    ]
@app.cls(
    cpu=4,  # four vCPUs per container; pyhmmer is multi-threaded
    memory=4096,  # 4 GB RAM, enough for the 700 MB HMM file plus working set
    timeout=3600 * 4,
    secrets=[modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"])],
    max_containers=16,
    scaledown_window=120,
)
class KofamScanner:
    """Modal class that fetches a genome, predicts proteins and scans them
    against the KOfam HMMs bundled into the image."""

    @modal.enter()
    def setup(self):
        """Load the HMMs, the per-KO score thresholds and the NCBI API key.

        Registered as the Modal ``enter`` hook. Populates ``self.pyhmmer``,
        ``self.alphabet``, ``self.hmms``, ``self.thresholds`` and
        ``self.ncbi_key``.
        """
        import os

        import pyhmmer
        import pyhmmer.easel
        import pyhmmer.plan7

        self.pyhmmer = pyhmmer
        self.alphabet = pyhmmer.easel.Alphabet.amino()
        with pyhmmer.plan7.HMMFile("/root/kofam.hmm") as hmm_file:
            self.hmms = list(hmm_file)

        # Per-KO bitscore thresholds: column 0 is the KO id, column 1 the
        # threshold; a value that is not a number is stored as 0.0.
        cutoffs: dict[str, float] = {}
        with open("/root/ko_thresholds.tsv") as table:
            next(table)  # skip the header row
            for row in table:
                fields = row.rstrip("\n").split("\t")
                if len(fields) >= 2:
                    try:
                        cutoff = float(fields[1])
                    except (TypeError, ValueError):
                        cutoff = 0.0
                    cutoffs[fields[0]] = cutoff
        self.thresholds = cutoffs

        self.ncbi_key = os.environ.get("NCBI_API_KEY")
        print(f"[setup] loaded {len(self.hmms):,} KOfam HMMs, "
              f"{len(self.thresholds):,} thresholds", flush=True)

    def _scan(self, proteins: list[str]) -> set[str]:
        """Return the KO ids whose HMM has at least one qualifying hit.

        Empty protein strings are skipped. A hit qualifies when its score is
        at least the KO's threshold (0.0 when the KO has none) and its
        E-value is at most ``DEFAULT_EVALUE``.
        """
        easel = self.pyhmmer.easel
        targets = [
            easel.TextSequence(name=f"p{idx}".encode(), sequence=protein).digitize(self.alphabet)
            for idx, protein in enumerate(proteins)
            if protein
        ]
        if not targets:
            return set()

        matched: set[str] = set()
        for top_hits in self.pyhmmer.hmmer.hmmsearch(self.hmms, targets, E=DEFAULT_EVALUE):
            # The query name is used as the KO id; it may be bytes or str.
            query_name = top_hits.query.name
            ko = query_name.decode() if isinstance(query_name, bytes) else query_name
            cutoff = self.thresholds.get(ko, 0.0)
            if any(hit.score >= cutoff and hit.evalue <= DEFAULT_EVALUE for hit in top_hits):
                matched.add(ko)
        return matched

    @modal.method()
    def scan_genome(self, accession: str) -> dict | None:
        """Fetch, gene-call and scan one genome.

        Returns:
            ``{"genome_accession": ..., "ko_hits": [sorted KO ids]}``, or None
            when no contigs or proteins were obtained or any exception was
            raised (the exception is printed, not propagated).
        """
        try:
            contigs = _fetch_fasta_bytes(accession, self.ncbi_key)
            proteins = _predict_proteins(contigs) if contigs else None
            if not proteins:
                return None
            return {"genome_accession": accession, "ko_hits": sorted(self._scan(proteins))}
        except Exception as exc:
            print(f"  skip {accession}: {type(exc).__name__}: {exc}", flush=True)
            return None
@app.local_entrypoint()
def main(out_path: str = "data/kofam_hits.jsonl", limit: int = 0):
    """Scan every pending genome accession on Modal and append hits as JSONL.

    Accessions come from the ``genome_accession`` column of
    ``data/features.parquet``. Any accession already present in *out_path*
    (under the key ``genome_accession`` or ``accession``) is skipped. Failed
    genomes are not written, so a later run picks them up again.

    Args:
        out_path: JSONL file to resume from and append to.
        limit: When non-zero, only the first *limit* pending accessions run.
    """
    import pandas as pd

    features = pd.read_parquet("data/features.parquet")
    unique_accs = features["genome_accession"].dropna().astype(str).unique().tolist()

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    # Collect accessions recorded by earlier runs; unparseable lines are ignored.
    done: set[str] = set()
    if out.exists():
        with open(out) as previous:
            for raw_line in previous:
                try:
                    record = json.loads(raw_line)
                except Exception:
                    continue
                key = record.get("genome_accession") or record.get("accession")
                if key:
                    done.add(str(key))

    todo = [acc for acc in unique_accs if acc not in done]
    pending = todo[:limit] if limit else todo
    print(f"KOfam scan on Modal: {len(pending):,} pending ({len(done):,} cached)")
    if not pending:
        return

    scanner = KofamScanner()
    n_ok = n_fail = 0
    with open(out, "a") as log:
        for result in scanner.scan_genome.map(pending, return_exceptions=True):
            if result is not None and not isinstance(result, Exception):
                # Flush after each record so an interrupted run keeps its progress.
                log.write(f"{json.dumps(result)}\n")
                log.flush()
                n_ok += 1
                if n_ok % 200 == 0:
                    print(f"  {n_ok:,} ok / {n_fail:,} fail")
            else:
                n_fail += 1
    print(f"\nFinished. {n_ok:,} succeeded, {n_fail:,} failed.")