Spaces:
Running
Running
File size: 5,579 Bytes
8c28a61 0ed74db 8c28a61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """Extract per-genome ESM-2 embeddings for the full BacDive training corpus.
Designed to run on a CUDA GPU (Lightning AI T4 / A100). Falls back to MPS / CPU
for testing — but at scale you really want GPU.
Reads:
data/bacdive_phenotypes.parquet (strain list)
+ downloads each genome via the existing pipeline._fetch_fasta_bytes path
Writes:
data/embeddings.jsonl (one row per genome, append-only, resumable)
data/embeddings.parquet (materialized at end)
Resumability: re-running picks up where it left off (same JSONL pattern as
features.jsonl).
Usage:
# Full corpus on GPU (Lightning AI). ~3-5 hr on T4 with sample_n=50.
uv run --extra embeddings python scripts/11_extract_embeddings.py \\
--model facebook/esm2_t30_150M_UR50D --sample-n 50 --batch-size 32
# Smoke test on Mac MPS with smallest model
uv run --extra embeddings python scripts/11_extract_embeddings.py \\
--model facebook/esm2_t6_8M_UR50D --sample-n 20 --max 10
"""
from __future__ import annotations
import argparse
import json
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from microbe_model import config
from microbe_model.features.embeddings import embed_genome, load_esm2
from microbe_model.features.genome import predict_genes
from microbe_model.pipeline import _fetch_fasta_bytes
def _load_done_ids(path) -> set[int]:
    """Return the BacDive ids that already have a row in the JSONL log at *path*.

    This is what makes the script resumable: ids returned here are filtered
    out of the pending work list by the caller.

    Any line that cannot be read as a JSON object carrying an
    integer-convertible ``"bacdive_id"`` is skipped rather than raising, so a
    partially written or hand-edited log never blocks a re-run.

    Args:
        path: Location of the JSONL log. Must provide ``.exists()`` and be
            accepted by ``open()`` (e.g. a ``pathlib.Path``).

    Returns:
        The set of ids found; empty when the file does not exist.
    """
    if not path.exists():
        return set()
    ids: set[int] = set()
    # json.dumps emits ASCII by default, so UTF-8 reads it on every platform.
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            try:
                row = json.loads(line)
                ids.add(int(row["bacdive_id"]))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError, OverflowError):
                # JSONDecodeError: blank or truncated line.
                # KeyError: object without a "bacdive_id" field.
                # TypeError: valid JSON that is not an object (null, list, number)
                #            or an id of null.
                # ValueError / OverflowError: id not convertible to int.
                continue
    return ids
def _ends_mid_line(path) -> bool:
    """Return True when *path* is non-empty and its last byte is not a newline."""
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path, "rb") as fh:
        fh.seek(-1, 2)  # whence=2: offset is relative to end of file
        return fh.read(1) != b"\n"


def _materialize_parquet(jsonl_path, parquet_path) -> tuple[int, int]:
    """Flatten the JSONL log into one wide parquet table (one ``emb_<i>`` column per dim).

    Blank lines are ignored. Lines that are not a JSON object with
    ``bacdive_id``, ``genome_accession`` and a numeric ``embedding`` list are
    skipped and counted instead of aborting the export.

    Returns:
        ``(rows_written, lines_skipped)``.
    """
    rows = []
    n_bad = 0
    with open(jsonl_path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            try:
                row = json.loads(line)
                flat = {
                    "bacdive_id": row["bacdive_id"],
                    "genome_accession": row["genome_accession"],
                }
                flat.update({f"emb_{i}": float(v) for i, v in enumerate(row["embedding"])})
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                n_bad += 1
                continue
            rows.append(flat)
    df = pd.DataFrame(rows)
    df.to_parquet(parquet_path, index=False)
    return len(df), n_bad


def main() -> None:
    """Embed every pending genome, append results to the JSONL log, then export parquet.

    Steps:
      1. Parse CLI options (``--model``, ``--sample-n``, ``--batch-size``, ``--max``).
      2. Load ``bacdive_phenotypes.parquet`` and keep strains that have a genome
         accession and at least one non-null column from ``config.PHENOTYPE_TARGETS``.
      3. Drop strains whose id is already in ``embeddings.jsonl``.
      4. For each remaining strain: fetch the FASTA, predict genes, embed, and
         append one JSON line (flushed immediately so progress survives a kill).
      5. Rebuild ``embeddings.parquet`` from the whole JSONL log.

    Raises:
        SystemExit: if the phenotype parquet is missing, or ``--max`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="facebook/esm2_t30_150M_UR50D",
                        help="HF model id (esm2_t6_8M / t12_35M / t30_150M / t33_650M)")
    parser.add_argument("--sample-n", type=int, default=50,
                        help="Proteins per genome to embed (None = all). 50 is a fast default.")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="ESM-2 batch size (32 fits on T4 16GB; raise on A100).")
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many genomes to process (default: all training-ready).")
    args = parser.parse_args()
    if args.max is not None and args.max < 0:
        # DataFrame.head(-n) means "all but the last n", which is never what a cap intends.
        parser.error("--max must be >= 0")

    pheno_path = config.DATA / "bacdive_phenotypes.parquet"
    if not pheno_path.exists():
        raise SystemExit(f"Missing {pheno_path}. Run scripts/01_fetch_bacdive.py first.")
    pheno = pd.read_parquet(pheno_path)

    has_genome = pheno["genome_accession"].notna()
    label_cols = list(config.PHENOTYPE_TARGETS.keys())
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()

    out_path = config.DATA / "embeddings.jsonl"
    done_ids = _load_done_ids(out_path)
    pending = ready[~ready["bacdive_id"].astype(int).isin(done_ids)]
    # Compare against None so that an explicit `--max 0` really means "process nothing".
    if args.max is not None:
        pending = pending.head(args.max)
    print(f"Embedding {len(pending):,} genomes (skipping {len(done_ids):,} already done)")

    print(f"Loading {args.model}...")
    tokenizer, model, device = load_esm2(args.model)
    print(f"  device={device}, embed_dim={model.config.hidden_size}, "
          f"sample_n={args.sample_n}, batch_size={args.batch_size}")

    # Fixed seed; the generator is handed to embed_genome (presumably for protein
    # sampling -- confirm in microbe_model.features.embeddings).
    rng = np.random.default_rng(0)
    t0 = time.time()
    n_success = 0
    n_fail = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # A run killed mid-write leaves a fragment without a newline; appending straight
    # after it would glue the next record onto the fragment and lose that record too.
    needs_newline = _ends_mid_line(out_path)
    with open(out_path, "a", encoding="utf-8") as log:
        if needs_newline:
            log.write("\n")
        for _, row in tqdm(pending.iterrows(), total=len(pending), desc="embed", unit="genome"):
            bid = int(row["bacdive_id"])
            acc = str(row["genome_accession"])
            try:
                contigs = _fetch_fasta_bytes(acc)
                if not contigs:
                    n_fail += 1
                    continue
                proteins, _, _ = predict_genes(contigs)
                if not proteins:
                    n_fail += 1
                    continue
                vec = embed_genome(
                    proteins, tokenizer, model, device,
                    sample_n=args.sample_n, batch_size=args.batch_size, rng=rng,
                )
            except Exception as exc:  # noqa: BLE001 — single bad genome shouldn't kill the run
                print(f"  skip {acc}: {type(exc).__name__}: {exc}")
                n_fail += 1
                continue
            payload = {
                "bacdive_id": bid,
                "genome_accession": acc,
                "embed_dim": int(len(vec)),
                "embedding": vec.tolist(),
            }
            log.write(json.dumps(payload) + "\n")
            # Flush per genome so an interrupted run keeps everything finished so far.
            log.flush()
            n_success += 1

    elapsed = time.time() - t0
    print(f"\nFinished in {elapsed/60:.1f} min. {n_success} succeeded, {n_fail} failed.")

    # Materialize parquet — flatten the embedding list into per-dim columns
    print("Materializing parquet...")
    parquet_path = config.DATA / "embeddings.parquet"
    n_rows, n_bad = _materialize_parquet(out_path, parquet_path)
    if n_bad:
        print(f"  ignored {n_bad:,} unreadable line(s) in {out_path}")
    print(f"Wrote {n_rows:,} embeddings to {parquet_path}")
# Script entry point: run the extraction only when this file is executed
# directly, so importing the module has no side effects.
if __name__ == "__main__":
    main()
|