# microbe-model/scripts/11_extract_embeddings.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Extract per-genome ESM-2 embeddings for the full BacDive training corpus.

Designed to run on a CUDA GPU (Lightning AI T4 / A100). Falls back to MPS / CPU
for testing — but at scale you really want a GPU.

Reads:
    data/bacdive_phenotypes.parquet  (strain list)
    + downloads each genome via the existing pipeline._fetch_fasta_bytes path

Writes:
    data/embeddings.jsonl    (one row per genome, append-only, resumable)
    data/embeddings.parquet  (materialized at end)

Resumability: re-running picks up where it left off (same JSONL pattern as
features.jsonl).

Usage:
    # Full corpus on GPU (Lightning AI). ~3-5 hr on T4 with sample_n=50.
    uv run --extra embeddings python scripts/11_extract_embeddings.py \\
        --model facebook/esm2_t30_150M_UR50D --sample-n 50 --batch-size 32

    # Smoke test on Mac MPS with smallest model
    uv run --extra embeddings python scripts/11_extract_embeddings.py \\
        --model facebook/esm2_t6_8M_UR50D --sample-n 20 --max 10
"""
from __future__ import annotations
import argparse
import json
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from microbe_model import config
from microbe_model.features.embeddings import embed_genome, load_esm2
from microbe_model.features.genome import predict_genes
from microbe_model.pipeline import _fetch_fasta_bytes
def _load_done_ids(path) -> set[int]:
    """Return the ``bacdive_id`` values already recorded in the JSONL log.

    Args:
        path: Location of the JSONL log (anything ``open()`` accepts).

    Returns:
        The set of integer ids found. Empty when the file does not exist.

    Lines that cannot be used are skipped rather than raised on, so a log
    whose last line was cut short by a crash can still be resumed from:
    invalid JSON, JSON that is not an object, a missing or null
    ``bacdive_id``, or an id that cannot be converted to ``int``.
    """
    ids: set[int] = set()
    try:
        fh = open(path, encoding="utf-8")
    except FileNotFoundError:
        # No log yet: nothing has been processed.
        return ids
    with fh:
        for line in fh:
            try:
                ids.add(int(json.loads(line)["bacdive_id"]))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                # TypeError covers non-object JSON (null, list, number) and
                # a null id, both of which would otherwise abort the resume.
                continue
    return ids
def _ends_mid_line(path) -> bool:
    """Return True if *path* is a non-empty file whose last byte is not a newline."""
    import os

    try:
        with open(path, "rb") as fh:
            if fh.seek(0, os.SEEK_END) == 0:
                return False
            fh.seek(-1, os.SEEK_END)
            return fh.read(1) != b"\n"
    except FileNotFoundError:
        return False


def _materialize_parquet(jsonl_path, parquet_path) -> int:
    """Flatten the JSONL log into a parquet file with one ``emb_<i>`` column per dimension.

    Blank lines are ignored. Malformed lines (e.g. a record truncated by a
    crash) are skipped and counted instead of aborting, mirroring what
    ``_load_done_ids`` tolerates when resuming.

    Returns:
        The number of rows written.
    """
    rows = []
    n_bad = 0
    with open(jsonl_path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            try:
                row = json.loads(line)
                flat = {
                    "bacdive_id": row["bacdive_id"],
                    "genome_accession": row["genome_accession"],
                }
                flat.update({f"emb_{i}": float(v) for i, v in enumerate(row["embedding"])})
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                n_bad += 1
                continue
            rows.append(flat)
    if n_bad:
        print(f" skipped {n_bad} malformed line(s) in {jsonl_path}")
    df = pd.DataFrame(rows)
    df.to_parquet(parquet_path, index=False)
    return len(df)


def main() -> None:
    """Embed every pending genome, append results to JSONL, then write parquet.

    Steps:
      1. Parse CLI options.
      2. Load the phenotype table and keep strains that have a genome
         accession and at least one non-null ``config.PHENOTYPE_TARGETS`` column.
      3. Drop strains whose ``bacdive_id`` is already in ``embeddings.jsonl``.
      4. For each remaining strain: fetch contigs, predict genes, embed, and
         append one JSON line (flushed immediately so progress survives a crash).
      5. Rebuild ``embeddings.parquet`` from the whole JSONL log.

    Raises:
        SystemExit: If the phenotype parquet is missing, or ``--max`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="facebook/esm2_t30_150M_UR50D",
                        help="HF model id (esm2_t6_8M / t12_35M / t30_150M / t33_650M)")
    # NOTE(review): the help text mentions None, but with type=int the CLI
    # cannot express it; confirm how embed_genome treats sample_n before
    # adding a sentinel value.
    parser.add_argument("--sample-n", type=int, default=50,
                        help="Proteins per genome to embed (None = all). 50 is a fast default.")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="ESM-2 batch size (32 fits on T4 16GB; raise on A100).")
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many genomes to process (default: all training-ready).")
    args = parser.parse_args()
    if args.max is not None and args.max < 0:
        # head(-n) would silently mean "all but the last n".
        parser.error("--max must be >= 0")

    pheno_path = config.DATA / "bacdive_phenotypes.parquet"
    if not pheno_path.exists():
        raise SystemExit(f"Missing {pheno_path}. Run scripts/01_fetch_bacdive.py first.")
    pheno = pd.read_parquet(pheno_path)

    has_genome = pheno["genome_accession"].notna()
    label_cols = list(config.PHENOTYPE_TARGETS.keys())
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()

    out_path = config.DATA / "embeddings.jsonl"
    done_ids = _load_done_ids(out_path)
    pending = ready[~ready["bacdive_id"].astype(int).isin(done_ids)]
    # Compare against None so that `--max 0` means "process nothing" rather
    # than being treated as "no cap".
    if args.max is not None:
        pending = pending.head(args.max)
    print(f"Embedding {len(pending):,} genomes (skipping {len(done_ids):,} already done)")

    print(f"Loading {args.model}...")
    tokenizer, model, device = load_esm2(args.model)
    print(f" device={device}, embed_dim={model.config.hidden_size}, "
          f"sample_n={args.sample_n}, batch_size={args.batch_size}")

    rng = np.random.default_rng(0)
    t0 = time.time()
    n_success = 0
    n_fail = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # A crash mid-write leaves a partial last line; without a separator the
    # first record appended now would be glued onto it and lost as well.
    needs_separator = _ends_mid_line(out_path)
    with open(out_path, "a", encoding="utf-8") as log:
        if needs_separator:
            log.write("\n")
        for _, row in tqdm(pending.iterrows(), total=len(pending), desc="embed", unit="genome"):
            bid = int(row["bacdive_id"])
            acc = str(row["genome_accession"])
            try:
                contigs = _fetch_fasta_bytes(acc)
                if not contigs:
                    tqdm.write(f" skip {acc}: no contigs fetched")
                    n_fail += 1
                    continue
                proteins, _, _ = predict_genes(contigs)
                if not proteins:
                    tqdm.write(f" skip {acc}: no proteins predicted")
                    n_fail += 1
                    continue
                vec = embed_genome(
                    proteins, tokenizer, model, device,
                    sample_n=args.sample_n, batch_size=args.batch_size, rng=rng,
                )
            except Exception as exc:  # noqa: BLE001 — single bad genome shouldn't kill the run
                # tqdm.write keeps the message from breaking the progress bar.
                tqdm.write(f" skip {acc}: {type(exc).__name__}: {exc}")
                n_fail += 1
                continue
            # vec must support len() and .tolist(); presumably a 1-D numpy
            # array — verify against embed_genome.
            payload = {
                "bacdive_id": bid,
                "genome_accession": acc,
                "embed_dim": int(len(vec)),
                "embedding": vec.tolist(),
            }
            log.write(json.dumps(payload) + "\n")
            # Flush per genome so an interrupted run loses at most one record.
            log.flush()
            n_success += 1

    elapsed = time.time() - t0
    print(f"\nFinished in {elapsed/60:.1f} min. {n_success} succeeded, {n_fail} failed.")

    # Materialize parquet — flatten the embedding list into per-dim columns
    print("Materializing parquet...")
    parquet_path = config.DATA / "embeddings.parquet"
    n_rows = _materialize_parquet(out_path, parquet_path)
    print(f"Wrote {n_rows:,} embeddings to {parquet_path}")
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()