Spaces:
Running
Running
"""Modal app — extract ESM-2 embeddings for the full BacDive corpus.

Each Modal container loads ESM-2 once on its GPU, then processes a stream of
(bacdive_id, accession) tasks. The local entrypoint dispatches all training-ready
strains via Modal's parallel .starmap(), and streams results back to local
data/embeddings.jsonl as they complete (resumable: re-running skips finished IDs).

Usage:
    # one-time:
    modal setup                                      # OAuth Modal account
    modal secret create ncbi-key NCBI_API_KEY=...    # paste your NCBI key
    # run:
    modal run scripts/modal_embed.py
    # or with custom flags:
    modal run scripts/modal_embed.py --gpu A10G --sample-n 50 --batch-size 16

Cost (as of 2026, A10G at ~$1/hr):
    22K genomes × ~1 sec/genome on A10G ÷ 16 parallel containers ≈ 25 min wall time
    ≈ $7–10 total
"""
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| import modal | |
# --- Modal image ------------------------------------------------------------
# Pin Python and bundle the deps that genome → proteins → ESM-2 needs.
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "torch>=2.2",
    "transformers>=4.40",
    "accelerate>=0.30",
    "pyrodigal>=3.5",
    "biopython>=1.83",
    "requests>=2.32",
    "numpy>=1.26",
)

app = modal.App("microbe-esm2", image=image)

# Hugging Face model id used when ESM2_MODEL is not set in the container env.
DEFAULT_MODEL = "facebook/esm2_t30_150M_UR50D"

# NCBI Datasets v2 download endpoint; `{acc}` is filled with the accession.
DATASETS_URL = "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"

# Version suffixes tried, in order, when an accession carries no ".N" suffix.
VERSION_FALLBACKS = (".1", ".2", ".3", ".4")

# Responses smaller than this many bytes are treated as "no genome in the zip".
EMPTY_ZIP_BYTES = 2_000
| # --- Self-contained helpers (run inside the container) ---------------------- | |
| def _has_version(accession: str) -> bool: | |
| if "." not in accession: | |
| return False | |
| return accession.rsplit(".", 1)[-1].isdigit() | |
def _candidate_accessions(accession: str) -> list[str]:
    """List the accession strings to try against NCBI, in order.

    A versioned accession is tried as-is; an unversioned one is expanded with
    every suffix in ``VERSION_FALLBACKS``.
    """
    if not _has_version(accession):
        return [f"{accession}{suffix}" for suffix in VERSION_FALLBACKS]
    return [accession]
def _download_genome_zip(
    candidate: str,
    headers: dict[str, str],
    params: dict[str, str],
    rate: float,
    max_attempts: int = 3,
) -> bytes | None:
    """Download the NCBI Datasets zip for one accession, retrying transient errors.

    Returns the zip payload, or None when the accession is unknown (404), the
    payload is smaller than ``EMPTY_ZIP_BYTES``, or every attempt failed.
    """
    import time

    import requests

    for attempt in range(max_attempts):
        # Client-side throttle before every request.
        time.sleep(rate)
        try:
            resp = requests.get(
                DATASETS_URL.format(acc=candidate),
                params=params, headers=headers, timeout=120,
            )
            if resp.status_code == 404:
                return None
            retry = resp.status_code in (429, 502, 503)
            if not retry:
                resp.raise_for_status()
        except requests.RequestException:
            retry = True
        if retry:
            # Exponential backoff, but do not sleep after the final attempt:
            # nothing follows it, so the wait would be wasted container time.
            if attempt < max_attempts - 1:
                time.sleep(2 ** attempt)
            continue
        if len(resp.content) < EMPTY_ZIP_BYTES:
            return None
        return resp.content
    return None


def _read_first_fasta(zip_bytes: bytes) -> bytes | None:
    """Return the bytes of the first ``.fna`` member of a zip, or None.

    None is returned when the archive is corrupt or has no ``.fna`` member.
    """
    import io
    import zipfile

    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
            fasta_names = [n for n in archive.namelist() if n.endswith(".fna")]
            if not fasta_names:
                return None
            return archive.read(fasta_names[0])
    except zipfile.BadZipFile:
        return None


def _fetch_fasta_bytes(accession: str, ncbi_key: str | None) -> list[tuple[str, str]] | None:
    """Fetch a genome from NCBI Datasets and return its parsed contigs.

    Despite the name, the result is the output of ``_parse_fasta`` — a list of
    ``(contig_id, sequence)`` tuples — not raw bytes.

    Args:
        accession: Genome accession, with or without a ``.N`` version suffix.
            Unversioned accessions are tried with each fallback suffix in turn.
        ncbi_key: Optional NCBI API key; when given it is sent as the
            ``api-key`` header and a shorter delay between requests is used.

    Returns:
        The parsed contigs of the first candidate that yields a usable FASTA,
        or None when no candidate does.
    """
    rate = 0.1 if ncbi_key else 0.34
    headers: dict[str, str] = {"Accept": "application/zip"}
    if ncbi_key:
        headers["api-key"] = ncbi_key
    params = {"include_annotation_type": "GENOME_FASTA"}

    for candidate in _candidate_accessions(accession):
        zip_bytes = _download_genome_zip(candidate, headers, params, rate)
        if zip_bytes is None:
            continue
        raw = _read_first_fasta(zip_bytes)
        if raw is None:
            continue
        return _parse_fasta(raw)
    return None
| def _parse_fasta(raw: bytes) -> list[tuple[str, str]]: | |
| contigs: list[tuple[str, str]] = [] | |
| current_id: str | None = None | |
| chunks: list[str] = [] | |
| for line in raw.splitlines(): | |
| if not line: | |
| continue | |
| if line.startswith(b">"): | |
| if current_id is not None: | |
| contigs.append((current_id, "".join(chunks).upper())) | |
| current_id = line[1:].decode("ascii", errors="replace").split()[0] | |
| chunks = [] | |
| else: | |
| chunks.append(line.decode("ascii", errors="replace")) | |
| if current_id is not None: | |
| contigs.append((current_id, "".join(chunks).upper())) | |
| return contigs | |
def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Call genes on every contig with pyrodigal and return protein strings.

    When the contigs total at least 20,000 nucleotides, a single-genome gene
    finder is trained on all contigs joined by a ``TTAATTAATTAA`` linker; if
    the genome is smaller, or training raises, a metagenomic-mode finder is
    used instead. A trailing ``*`` is stripped from each translation.
    """
    import pyrodigal

    sequences = [seq.encode("ascii") for _name, seq in contigs]
    genome_size = sum(map(len, sequences))

    finder = None
    if genome_size >= 20_000:
        trained = pyrodigal.GeneFinder(meta=False)
        linked = b"TTAATTAATTAA".join(sequences)
        try:
            trained.train(linked)
        except Exception:
            trained = None
        finder = trained
    if finder is None:
        finder = pyrodigal.GeneFinder(meta=True)

    return [
        gene.translate().rstrip("*")
        for seq in sequences
        for gene in finder.find_genes(seq)
    ]
| # --- Modal class: loads ESM-2 once per container, batches embeddings -------- | |
# NOTE(review): no decorators were present on this class in the reviewed text,
# yet `main` calls `Embedder.with_options(...)` and `embed_genome.starmap(...)`,
# which only exist on Modal-registered classes/methods. Minimal decorators are
# restored here; confirm any container options (timeout, max containers, ...).
@app.cls()
class Embedder:
    """Embed genomes with ESM-2; the model is loaded once in ``setup``."""

    @modal.enter()
    def setup(self):
        """Load the tokenizer and model and cache runtime settings on ``self``.

        Settings are read from the environment: ``ESM2_MODEL``,
        ``ESM2_SAMPLE_N``, ``ESM2_BATCH_SIZE`` and ``NCBI_API_KEY``.
        """
        import os

        import numpy as np
        import torch
        from transformers import AutoModel, AutoTokenizer

        self.np = np
        self.torch = torch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        # Read configurable knobs from env (set by the local entrypoint)
        self.model_name = os.environ.get("ESM2_MODEL", DEFAULT_MODEL)
        self.sample_n = int(os.environ.get("ESM2_SAMPLE_N", "50"))
        # A batch size below 1 would make the batching range() invalid.
        self.batch_size = max(1, int(os.environ.get("ESM2_BATCH_SIZE", "16")))
        print(f"[setup] loading {self.model_name} on {self.device}", flush=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # NOTE(review): `dtype=` is presumably only accepted by recent
        # transformers releases (older ones use `torch_dtype=`) — confirm
        # against the version the image actually installs.
        self.model = AutoModel.from_pretrained(self.model_name, dtype=self.dtype)
        self.model.to(self.device)
        self.model.eval()
        self.embed_dim = self.model.config.hidden_size
        self.ncbi_key = os.environ.get("NCBI_API_KEY")
        # Base seed; combined with the bacdive_id for per-genome sampling.
        self.sample_seed = 0
        print(f"[setup] embed_dim={self.embed_dim}, "
              f"sample_n={self.sample_n}, batch={self.batch_size}, ready", flush=True)

    def _embed_proteins(self, proteins: list[str]):
        """Return one mean-pooled embedding per protein.

        Proteins are tokenized in batches of ``self.batch_size`` (truncated to
        1024 tokens) and each is pooled over its attention-masked positions.

        Returns:
            A float32 NumPy array of shape ``(len(proteins), self.embed_dim)``;
            shape ``(0, self.embed_dim)`` for an empty input.
        """
        if not proteins:
            return self.np.zeros((0, self.embed_dim), dtype=self.np.float32)
        torch = self.torch
        out: list = []
        for start in range(0, len(proteins), self.batch_size):
            batch = proteins[start : start + self.batch_size]
            enc = self.tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True, max_length=1024,
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            with torch.inference_mode():
                outputs = self.model(**enc)
                # Pool in float32: summing up to 1024 half-precision token
                # vectors can exceed the fp16 range and yield inf/NaN.
                hidden = outputs.last_hidden_state.float()
                mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            out.append(pooled.cpu().numpy())
        return self.np.concatenate(out, axis=0)

    @modal.method()
    def embed_genome(self, bacdive_id: int, accession: str) -> dict | None:
        """Fetch, gene-call and embed one genome.

        At most ``self.sample_n`` predicted proteins are embedded (all of them
        when ``sample_n`` is zero or negative) and their vectors are averaged.
        The sample is drawn from an RNG seeded with the bacdive_id, so the
        result for a genome does not depend on which container handles it or
        on what that container processed before.

        Args:
            bacdive_id: Strain identifier, echoed back in the result.
            accession: Genome accession to download.

        Returns:
            A dict with keys ``bacdive_id``, ``genome_accession``,
            ``embed_dim`` and ``embedding``, or None when the genome cannot be
            fetched, has no predicted proteins, yields a non-finite embedding,
            or any step raises (the error is printed, not propagated).
        """
        try:
            contigs = _fetch_fasta_bytes(accession, self.ncbi_key)
            if not contigs:
                return None
            proteins = _predict_proteins(contigs)
            if not proteins:
                return None
            if 0 < self.sample_n < len(proteins):
                rng = self.np.random.default_rng([self.sample_seed, int(bacdive_id)])
                idx = rng.choice(len(proteins), size=self.sample_n, replace=False)
                proteins = [proteins[i] for i in idx]
            matrix = self._embed_proteins(proteins)
            vec = matrix.mean(axis=0).astype(self.np.float32)
            if not self.np.isfinite(vec).all():
                # NaN/inf would be serialized as invalid JSON downstream.
                print(f"  skip {accession}: non-finite embedding", flush=True)
                return None
            return {
                "bacdive_id": int(bacdive_id),
                "genome_accession": accession,
                "embed_dim": int(len(vec)),
                "embedding": vec.tolist(),
            }
        except Exception as exc:
            # Task boundary: one bad genome must not fail the whole map call.
            print(f"  skip {accession}: {type(exc).__name__}: {exc}", flush=True)
            return None
| # --- Local entrypoint ------------------------------------------------------- | |
def _load_done_ids(out: Path) -> set[int]:
    """Collect the bacdive_ids already recorded in the JSONL file at *out*.

    Lines that are not valid JSON objects with an integer-like ``bacdive_id``
    are ignored, so a truncated record does not block a resume.
    """
    done: set[int] = set()
    if not out.exists():
        return done
    with open(out, encoding="utf-8", errors="replace") as fh:
        for line in fh:
            try:
                done.add(int(json.loads(line)["bacdive_id"]))
            except (ValueError, KeyError, TypeError):
                continue
    return done


def _ends_mid_line(out: Path) -> bool:
    """Return True when *out* is non-empty and its last byte is not a newline."""
    import os

    if not out.exists() or out.stat().st_size == 0:
        return False
    with open(out, "rb") as fh:
        fh.seek(-1, os.SEEK_END)
        return fh.read(1) != b"\n"


# NOTE(review): the entrypoint decorator was absent in the reviewed text; it is
# what `modal run` needs in order to invoke `main` — confirm.
@app.local_entrypoint()
def main(
    model: str = DEFAULT_MODEL,
    sample_n: int = 50,
    batch_size: int = 16,
    gpu: str = "A10G",
    out_path: str = "data/embeddings.jsonl",
    limit: int = 0,
):
    """Dispatch all training-ready genomes to Modal and stream results to disk.

    Rows of ``data/bacdive_phenotypes.parquet`` that have a genome accession
    and at least one label are embedded remotely; each successful result is
    appended to *out_path* as one JSON line. IDs already in *out_path* are
    skipped, so the command can be re-run after an interruption. Failed
    genomes are not recorded and will be retried on the next run.

    Args:
        model: Hugging Face model id, passed to containers as ``ESM2_MODEL``.
        sample_n: Passed to containers as ``ESM2_SAMPLE_N``.
        batch_size: Passed to containers as ``ESM2_BATCH_SIZE``.
        gpu: GPU type requested for the containers.
        out_path: JSONL file that results are appended to.
        limit: When non-zero, process at most this many pending genomes.
    """
    import pandas as pd

    pheno = pd.read_parquet("data/bacdive_phenotypes.parquet")
    has_genome = pheno["genome_accession"].notna()
    label_cols = ["optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"]
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()
    ready["bacdive_id"] = ready["bacdive_id"].astype(int)

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    done = _load_done_ids(out)

    pending = ready[~ready["bacdive_id"].isin(done)]
    if limit:
        pending = pending.head(limit)
    # .tolist() yields plain Python ints rather than NumPy scalars.
    tasks = list(zip(
        pending["bacdive_id"].tolist(),
        pending["genome_accession"].astype(str),
        strict=True,
    ))
    print(f"Embedding {len(tasks):,} genomes (skipping {len(done):,} cached)")
    print(f"Model: {model} sample_n={sample_n} batch={batch_size} gpu={gpu}")
    if not tasks:
        print("Nothing to do.")
        return

    config_secret = modal.Secret.from_dict({
        "ESM2_MODEL": model,
        "ESM2_SAMPLE_N": str(sample_n),
        "ESM2_BATCH_SIZE": str(batch_size),
    })
    embedder = Embedder.with_options(
        gpu=gpu,
        secrets=[
            modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"]),
            config_secret,
        ],
    )()

    n_ok = 0
    n_fail = 0
    # A run killed mid-write can leave a partial last line; start on a fresh
    # line so the next record is not glued onto it and lost.
    needs_newline = _ends_mid_line(out)
    with open(out, "a", encoding="utf-8") as log:
        if needs_newline:
            log.write("\n")
        # order_outputs=False yields each result as soon as it finishes rather
        # than in input order, so one slow genome does not hold back writes.
        # Every record carries its own bacdive_id, so order is irrelevant.
        results = embedder.embed_genome.starmap(
            tasks, order_outputs=False, return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                print(f"  error: {type(result).__name__}: {result}", flush=True)
                n_fail += 1
                continue
            if result is None:
                n_fail += 1
                continue
            log.write(json.dumps(result) + "\n")
            log.flush()
            n_ok += 1
            if n_ok % 100 == 0:
                print(f" {n_ok:,} ok / {n_fail:,} fail")

    print(f"\nFinished. {n_ok:,} succeeded, {n_fail:,} failed.")
    print(f"Streamed to {out}")
    print("Run scripts/_materialize_embeddings.py (or the snippet at the bottom of "
          "scripts/11_extract_embeddings.py) to build the parquet from this JSONL.")