microbe-model / scripts /modal_embed.py
Miyu Horiuchi
Deploy app from main@a3254bf (no paper/ binaries)
0ed74db
"""Modal app — extract ESM-2 embeddings for the full BacDive corpus.
Each Modal container loads ESM-2 once on its GPU, then processes a stream of
(bacdive_id, accession) tasks. The local entrypoint dispatches all training-ready
strains via Modal's parallel .starmap(), and streams results back to local
data/embeddings.jsonl as they complete (resumable: re-running skips finished IDs).
Usage:
# one-time:
modal setup # OAuth Modal account
modal secret create ncbi-key NCBI_API_KEY=... # paste your NCBI key
# run:
modal run scripts/modal_embed.py
# or with custom flags:
modal run scripts/modal_embed.py --gpu A10G --sample-n 50 --batch-size 16
# (parallelism is fixed by max_containers=16 on the Embedder class)
Cost (as of 2026, A10G at ~$1/hr):
22K genomes × ~1 sec/genome on A10G ÷ 16 parallel containers ≈ 25 min wall time
≈ $7–10 total
"""
from __future__ import annotations
import json
from pathlib import Path
import modal
# --- Modal image ------------------------------------------------------------
# Debian-slim image on a pinned Python, carrying the packages the in-container
# genome -> proteins -> ESM-2 pipeline needs.
_CONTAINER_PACKAGES = [
    "torch>=2.2",
    "transformers>=4.40",
    "accelerate>=0.30",
    "pyrodigal>=3.5",
    "biopython>=1.83",
    "requests>=2.32",
    "numpy>=1.26",
]
image = modal.Image.debian_slim(python_version="3.11").pip_install(_CONTAINER_PACKAGES)
app = modal.App("microbe-esm2", image=image)
# ESM-2 checkpoint loaded when no --model flag is given (Hugging Face model id).
DEFAULT_MODEL = "facebook/esm2_t30_150M_UR50D"
# NCBI Datasets v2 genome download endpoint; `{acc}` is filled per request.
DATASETS_URL = (
    "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"
)
# Version suffixes tried, in this order, for accessions that lack one.
VERSION_FALLBACKS = tuple(f".{n}" for n in range(1, 5))
# Download bodies shorter than this many bytes are treated as "no genome data".
EMPTY_ZIP_BYTES = 2000
# --- Self-contained helpers (run inside the container) ----------------------
def _has_version(accession: str) -> bool:
if "." not in accession:
return False
return accession.rsplit(".", 1)[-1].isdigit()
def _candidate_accessions(accession: str) -> list[str]:
    """Return the accession strings to try downloading, in order.

    A versioned accession (e.g. ``GCF_000005845.2``) is used as-is; an
    unversioned one is expanded with each suffix in ``VERSION_FALLBACKS``.
    """
    if not _has_version(accession):
        return [f"{accession}{suffix}" for suffix in VERSION_FALLBACKS]
    return [accession]
def _fetch_fasta_bytes(accession: str, ncbi_key: str | None) -> list[tuple[str, str]] | None:
    """Download a genome assembly from NCBI Datasets and return its contigs.

    Despite the name, the return value is already parsed: a list of
    ``(contig_id, sequence)`` pairs produced by ``_parse_fasta``. Each
    accession from ``_candidate_accessions`` is tried in order; the first one
    whose download is a readable zip containing a ``.fna`` member is used.

    Args:
        accession: Assembly accession, with or without a ``.N`` version suffix.
        ncbi_key: Optional NCBI API key, sent as the ``api-key`` header; when
            present a shorter delay is used before each request.

    Returns:
        Parsed contigs of the first ``.fna`` member found, or ``None`` when no
        candidate accession yielded one.
    """
    import io
    import zipfile

    # Delay (seconds) before each request; shorter when an API key is used.
    rate = 0.1 if ncbi_key else 0.34
    headers: dict[str, str] = {"Accept": "application/zip"}
    if ncbi_key:
        headers["api-key"] = ncbi_key
    for candidate in _candidate_accessions(accession):
        zip_bytes = _download_genome_zip(candidate, headers, rate)
        if zip_bytes is None:
            continue
        try:
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                fasta_names = [n for n in zf.namelist() if n.endswith(".fna")]
                if not fasta_names:
                    continue
                with zf.open(fasta_names[0]) as src:
                    raw = src.read()
        except zipfile.BadZipFile:
            continue
        return _parse_fasta(raw)
    return None


def _download_genome_zip(
    candidate: str,
    headers: dict[str, str],
    rate: float,
    attempts: int = 3,
) -> bytes | None:
    """Fetch the GENOME_FASTA zip for one exact accession, retrying transient errors.

    HTTP 429/502/503, other HTTP errors, and network errors are retried up to
    ``attempts`` times with exponential backoff (1 s, 2 s, ...). The backoff
    is applied only between attempts, never after the last one.

    Returns:
        The response body, or ``None`` on 404, when the body is shorter than
        ``EMPTY_ZIP_BYTES``, or when every attempt failed.
    """
    import time

    import requests

    params = {"include_annotation_type": "GENOME_FASTA"}
    for attempt in range(attempts):
        time.sleep(rate)
        try:
            resp = requests.get(
                DATASETS_URL.format(acc=candidate),
                params=params, headers=headers, timeout=120,
            )
            if resp.status_code == 404:
                return None  # this accession/version does not exist
            retry = resp.status_code in (429, 502, 503)
            if not retry:
                resp.raise_for_status()
        except requests.RequestException:
            retry = True  # network failure or other HTTP error
        if not retry:
            # Bodies this small are treated as an empty result, not a genome zip.
            return resp.content if len(resp.content) >= EMPTY_ZIP_BYTES else None
        if attempt + 1 < attempts:
            time.sleep(2 ** attempt)
    return None
def _parse_fasta(raw: bytes) -> list[tuple[str, str]]:
contigs: list[tuple[str, str]] = []
current_id: str | None = None
chunks: list[str] = []
for line in raw.splitlines():
if not line:
continue
if line.startswith(b">"):
if current_id is not None:
contigs.append((current_id, "".join(chunks).upper()))
current_id = line[1:].decode("ascii", errors="replace").split()[0]
chunks = []
else:
chunks.append(line.decode("ascii", errors="replace"))
if current_id is not None:
contigs.append((current_id, "".join(chunks).upper()))
return contigs
def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Call genes on every contig with pyrodigal and return their translations.

    When the assembly totals at least 20,000 nt, a single-genome finder is
    trained on all contigs joined by a ``TTAATTAATTAA`` linker. Smaller
    assemblies, or ones where training raises, use metagenomic mode instead.
    Trailing ``*`` characters are stripped from each translated protein.
    """
    import pyrodigal

    sequences = [seq.encode("ascii") for _name, seq in contigs]
    total_length = sum(len(s) for s in sequences)
    if total_length < 20_000:
        finder = pyrodigal.GeneFinder(meta=True)
    else:
        finder = pyrodigal.GeneFinder(meta=False)
        try:
            finder.train(b"TTAATTAATTAA".join(sequences))
        except Exception:
            # If training raises, fall back to metagenomic mode (no training step).
            finder = pyrodigal.GeneFinder(meta=True)
    return [
        gene.translate().rstrip("*")
        for sequence in sequences
        for gene in finder.find_genes(sequence)
    ]
# --- Modal class: loads ESM-2 once per container, batches embeddings --------
@app.cls(
    gpu="A10G",
    timeout=3600 * 4,
    secrets=[modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"])],
    max_containers=16,
    scaledown_window=60,
)
class Embedder:
    """Modal worker that turns one genome accession into one ESM-2 embedding.

    The tokenizer and model are loaded once per container in ``setup``. Each
    ``embed_genome`` call then downloads the assembly, predicts proteins,
    embeds (a sample of) them, and averages the per-protein vectors.
    """

    @modal.enter()
    def setup(self):
        """Load the tokenizer/model and read run configuration from env vars.

        Reads ``ESM2_MODEL``, ``ESM2_SAMPLE_N`` and ``ESM2_BATCH_SIZE`` (set by
        the local entrypoint) and ``NCBI_API_KEY`` (from the ``ncbi-key`` secret).
        """
        import os

        import numpy as np
        import torch
        from transformers import AutoModel, AutoTokenizer

        self.np = np
        self.torch = torch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU, full precision on CPU.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        # Read configurable knobs from env (set by the local entrypoint)
        self.model_name = os.environ.get("ESM2_MODEL", DEFAULT_MODEL)
        # Proteins sampled per genome; a value <= 0 disables subsampling.
        self.sample_n = int(os.environ.get("ESM2_SAMPLE_N", "50"))
        self.batch_size = int(os.environ.get("ESM2_BATCH_SIZE", "16"))
        # Base seed, combined with each bacdive_id in embed_genome so a genome's
        # protein sample does not depend on which container handles it or when.
        self.seed = 0
        print(f"[setup] loading {self.model_name} on {self.device}", flush=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # NOTE(review): `dtype=` is the newer spelling of `torch_dtype`; confirm the
        # transformers version installed in the image accepts it.
        self.model = AutoModel.from_pretrained(self.model_name, dtype=self.dtype)
        self.model.to(self.device)
        self.model.train(False)  # evaluation mode
        self.embed_dim = self.model.config.hidden_size
        self.ncbi_key = os.environ.get("NCBI_API_KEY")
        print(f"[setup] embed_dim={self.embed_dim}, "
              f"sample_n={self.sample_n}, batch={self.batch_size}, ready", flush=True)

    def _embed_proteins(self, proteins: list[str]):
        """Embed proteins with ESM-2; return a float32 array of shape (len(proteins), embed_dim).

        Each protein vector is the attention-mask-weighted mean of the final
        hidden states over its tokens. Presumably this includes the tokenizer's
        special tokens when the mask covers them; confirm that is intended.
        Sequences are truncated to 1024 tokens and processed ``batch_size`` at
        a time.
        """
        import torch

        if not proteins:
            return self.np.zeros((0, self.embed_dim), dtype=self.np.float32)
        out: list = []
        for i in range(0, len(proteins), self.batch_size):
            batch = proteins[i : i + self.batch_size]
            enc = self.tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True, max_length=1024,
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            with torch.inference_mode():
                outputs = self.model(**enc)
                # Pool in float32: summing up to 1024 fp16 activations per
                # sequence can overflow or lose precision.
                last_hidden = outputs.last_hidden_state.float()
                mask = enc["attention_mask"].unsqueeze(-1).float()
                pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            out.append(pooled.cpu().numpy())
        return self.np.concatenate(out, axis=0)

    @modal.method()
    def embed_genome(self, bacdive_id: int, accession: str) -> dict | None:
        """Embed one genome and return a JSON-serialisable record, or ``None``.

        Steps: download contigs, predict proteins, optionally subsample
        ``sample_n`` of them, embed, then average over proteins. Any exception
        is printed and turned into ``None`` instead of propagating. A
        non-finite result is also reported as ``None``.
        """
        try:
            contigs = _fetch_fasta_bytes(accession, self.ncbi_key)
            if not contigs:
                return None
            proteins = _predict_proteins(contigs)
            if not proteins:
                return None
            if self.sample_n is not None and 0 < self.sample_n < len(proteins):
                # Seed per genome so the sample is reproducible across runs.
                rng = self.np.random.default_rng([self.seed, int(bacdive_id)])
                idx = rng.choice(len(proteins), size=self.sample_n, replace=False)
                proteins = [proteins[i] for i in idx]
            matrix = self._embed_proteins(proteins)
            vec = matrix.mean(axis=0).astype(self.np.float32)
            if not self.np.isfinite(vec).all():
                print(f"  skip {accession}: non-finite embedding", flush=True)
                return None
            return {
                "bacdive_id": int(bacdive_id),
                "genome_accession": accession,
                "embed_dim": int(len(vec)),
                "embedding": vec.tolist(),
            }
        except Exception as exc:
            print(f"  skip {accession}: {type(exc).__name__}: {exc}", flush=True)
            return None
# --- Local entrypoint -------------------------------------------------------
@app.local_entrypoint()
def main(
    model: str = DEFAULT_MODEL,
    sample_n: int = 50,
    batch_size: int = 16,
    gpu: str = "A10G",
    out_path: str = "data/embeddings.jsonl",
    limit: int = 0,
):
    """Dispatch all training-ready genomes to Modal and stream results to disk.

    A strain is "training-ready" when it has a ``genome_accession`` and at
    least one non-null label column. Strains whose ``bacdive_id`` already
    appears in ``out_path`` are skipped, so re-running resumes a partial run.
    Results are appended in completion order (not input order), so an
    interrupted run keeps every genome that had finished.

    Args:
        model: Hugging Face model id loaded by each Embedder container.
        sample_n: Proteins sampled per genome (forwarded as ESM2_SAMPLE_N).
        batch_size: Proteins per ESM-2 forward pass (ESM2_BATCH_SIZE).
        gpu: Modal GPU type for the Embedder containers.
        out_path: JSONL file that results are appended to.
        limit: When positive, dispatch at most this many pending genomes.
    """
    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    done, truncated_tail = _scan_existing_output(out)
    tasks = _pending_tasks(done, limit)
    print(f"Embedding {len(tasks):,} genomes (skipping {len(done):,} cached)")
    print(f"Model: {model} sample_n={sample_n} batch={batch_size} gpu={gpu}")
    if not tasks:
        print("Nothing to do.")
        return
    config_secret = modal.Secret.from_dict({
        "ESM2_MODEL": model,
        "ESM2_SAMPLE_N": str(sample_n),
        "ESM2_BATCH_SIZE": str(batch_size),
    })
    embedder = Embedder.with_options(
        gpu=gpu,
        secrets=[
            modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"]),
            config_secret,
        ],
    )()
    n_ok = 0
    n_fail = 0
    with open(out, "a", encoding="utf-8") as log:
        if truncated_tail:
            # Every record is written with a trailing newline, so a missing one
            # most likely means a previous run was cut off mid-write. Terminate
            # that line so the first new record does not get glued onto it.
            log.write("\n")
        # order_outputs=False yields each result as soon as it finishes instead
        # of holding completed genomes behind a slow earlier one.
        results = embedder.embed_genome.starmap(
            tasks, order_outputs=False, return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                n_fail += 1
                print(f"  remote error: {type(result).__name__}: {result}")
                continue
            if result is None:
                n_fail += 1
                continue
            log.write(json.dumps(result) + "\n")
            log.flush()
            n_ok += 1
            if n_ok % 100 == 0:
                print(f" {n_ok:,} ok / {n_fail:,} fail")
    print(f"\nFinished. {n_ok:,} succeeded, {n_fail:,} failed.")
    print(f"Streamed to {out}")
    print("Run scripts/_materialize_embeddings.py (or the snippet at the bottom of "
          "scripts/11_extract_embeddings.py) to build the parquet from this JSONL.")


def _scan_existing_output(path: Path) -> tuple[set[int], bool]:
    """Read a previous run's JSONL output.

    Returns the ``bacdive_id`` values found on parseable lines, and whether the
    file's last line lacks a terminating newline. A missing file yields
    ``(set(), False)``.
    """
    done: set[int] = set()
    if not path.exists():
        return done, False
    last_line = ""
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            last_line = line
            try:
                done.add(int(json.loads(line)["bacdive_id"]))
            except (ValueError, KeyError, TypeError):
                continue  # blank, truncated, or otherwise unusable line
    return done, bool(last_line) and not last_line.endswith("\n")


def _pending_tasks(done: set[int], limit: int) -> list[tuple[int, str]]:
    """Return ``(bacdive_id, genome_accession)`` pairs that are not in *done*.

    Reads ``data/bacdive_phenotypes.parquet`` and keeps rows that have an
    accession and at least one label column. A positive ``limit`` keeps only
    the first ``limit`` pending rows; zero or a negative value means no limit.
    """
    import pandas as pd

    pheno = pd.read_parquet("data/bacdive_phenotypes.parquet")
    has_genome = pheno["genome_accession"].notna()
    label_cols = ["optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"]
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()
    ready["bacdive_id"] = ready["bacdive_id"].astype(int)
    pending = ready[~ready["bacdive_id"].isin(done)]
    if limit > 0:
        pending = pending.head(limit)
    return list(zip(pending["bacdive_id"], pending["genome_accession"].astype(str), strict=True))