Spaces:
Running
Running
| """Extract per-genome ESM-2 embeddings for the full BacDive training corpus. | |
| Designed to run on a CUDA GPU (Lightning AI T4 / A100). Falls back to MPS / CPU | |
| for testing — but at scale you really want GPU. | |
| Reads: | |
| data/bacdive_phenotypes.parquet (strain list) | |
| + downloads each genome via the existing pipeline._fetch_fasta_bytes path | |
| Writes: | |
| data/embeddings.jsonl (one row per genome, append-only, resumable) | |
| data/embeddings.parquet (materialized at end) | |
| Resumability: re-running picks up where it left off (same JSONL pattern as | |
| features.jsonl). | |
| Usage: | |
| # Full corpus on GPU (Lightning AI). ~3-5 hr on T4 with sample_n=50. | |
| uv run --extra embeddings python scripts/11_extract_embeddings.py \\ | |
| --model facebook/esm2_t30_150M_UR50D --sample-n 50 --batch-size 32 | |
| # Smoke test on Mac MPS with smallest model | |
| uv run --extra embeddings python scripts/11_extract_embeddings.py \\ | |
| --model facebook/esm2_t6_8M_UR50D --sample-n 20 --max 10 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from microbe_model import config | |
| from microbe_model.features.embeddings import embed_genome, load_esm2 | |
| from microbe_model.features.genome import predict_genes | |
| from microbe_model.pipeline import _fetch_fasta_bytes | |
def _load_done_ids(path) -> set[int]:
    """Return the BacDive ids already recorded in the JSONL log at *path*.

    Each usable line is a JSON object with a ``bacdive_id`` key whose value
    converts to ``int``. Anything else — blank or truncated lines, JSON
    values that are not objects, objects without the key, null or
    non-numeric ids — is skipped, so a single damaged record never blocks
    a resume.

    Args:
        path: Location of the JSONL log (anything ``open`` accepts).

    Returns:
        The set of ids found; empty when the file does not exist.
    """
    ids: set[int] = set()
    try:
        fh = open(path, encoding="utf-8")
    except FileNotFoundError:
        # First run: nothing has been written yet.
        return ids
    with fh:
        for line in fh:
            try:
                ids.add(int(json.loads(line)["bacdive_id"]))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                # TypeError covers valid JSON that is not an object (a bare
                # number, a list, ...) as well as a null id.
                continue
    return ids
def _ensure_trailing_newline(path) -> None:
    """Terminate a half-written final record in the JSONL log at *path*.

    If an earlier run died in the middle of a write, the file ends without
    a newline and the next appended record would be glued onto the broken
    one, making both unreadable. Appending a single newline confines the
    damage to the line that was already broken.

    Args:
        path: ``pathlib.Path`` of the JSONL log; may not exist yet.
    """
    import os  # local import: only this helper needs it

    if not path.exists() or path.stat().st_size == 0:
        return
    with open(path, "rb") as fh:
        fh.seek(-1, os.SEEK_END)
        ends_clean = fh.read(1) == b"\n"
    if not ends_clean:
        with open(path, "ab") as fh:
            fh.write(b"\n")


def _materialize_parquet(jsonl_path, parquet_path) -> int:
    """Flatten the JSONL log into a parquet table, one column per dimension.

    Every readable record becomes a row with ``bacdive_id``,
    ``genome_accession`` and ``emb_0 .. emb_{d-1}`` columns. Unreadable
    records (truncated JSON, missing keys, non-numeric values) are skipped
    and counted instead of aborting the whole materialization.

    Args:
        jsonl_path: The append-only embeddings log to read.
        parquet_path: Destination parquet file (overwritten).

    Returns:
        Number of rows written.
    """
    rows = []
    n_bad = 0
    with open(jsonl_path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                flat = {
                    "bacdive_id": record["bacdive_id"],
                    "genome_accession": record["genome_accession"],
                }
                flat.update(
                    {f"emb_{i}": float(v) for i, v in enumerate(record["embedding"])}
                )
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                n_bad += 1
                continue
            rows.append(flat)
    if n_bad:
        print(f" skipped {n_bad} unreadable line(s) in {jsonl_path}")
    df = pd.DataFrame(rows)
    df.to_parquet(parquet_path, index=False)
    return len(df)


def main() -> None:
    """Embed every pending genome, append results to JSONL, then write parquet.

    Steps:
      1. Parse CLI options (``--model``, ``--sample-n``, ``--batch-size``,
         ``--max``).
      2. Select strains from ``bacdive_phenotypes.parquet`` that have a
         genome accession and at least one non-null phenotype target, minus
         the ids already present in ``embeddings.jsonl``.
      3. For each remaining strain: fetch the genome, predict genes, embed
         the proteins, and append one JSON line (flushed immediately so an
         interrupted run can resume).
      4. Rebuild ``embeddings.parquet`` from the complete JSONL log.

    Raises:
        SystemExit: If the phenotype parquet is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="facebook/esm2_t30_150M_UR50D",
                        help="HF model id (esm2_t6_8M / t12_35M / t30_150M / t33_650M)")
    parser.add_argument("--sample-n", type=int, default=50,
                        help="Proteins per genome to embed (None = all). 50 is a fast default.")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="ESM-2 batch size (32 fits on T4 16GB; raise on A100).")
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many genomes to process (default: all training-ready).")
    args = parser.parse_args()

    pheno_path = config.DATA / "bacdive_phenotypes.parquet"
    if not pheno_path.exists():
        raise SystemExit(f"Missing {pheno_path}. Run scripts/01_fetch_bacdive.py first.")
    pheno = pd.read_parquet(pheno_path)

    # Training-ready = has a genome to download AND at least one label.
    has_genome = pheno["genome_accession"].notna()
    label_cols = list(config.PHENOTYPE_TARGETS.keys())
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()

    out_path = config.DATA / "embeddings.jsonl"
    done_ids = _load_done_ids(out_path)
    pending = ready[~ready["bacdive_id"].astype(int).isin(done_ids)]
    # `is not None` so that an explicit `--max 0` really means "process none"
    # (a plain truthiness test would silently treat 0 as "no cap").
    if args.max is not None:
        pending = pending.head(args.max)
    print(f"Embedding {len(pending):,} genomes (skipping {len(done_ids):,} already done)")

    print(f"Loading {args.model}...")
    tokenizer, model, device = load_esm2(args.model)
    print(f" device={device}, embed_dim={model.config.hidden_size}, "
          f"sample_n={args.sample_n}, batch_size={args.batch_size}")

    # One generator shared across genomes, seeded for reproducible sampling.
    rng = np.random.default_rng(0)
    t0 = time.time()
    n_success = 0
    n_fail = 0

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Repair a log left without a final newline by an interrupted run, so the
    # first record appended below is not merged into the broken line.
    _ensure_trailing_newline(out_path)
    with open(out_path, "a", encoding="utf-8") as log:
        strains = zip(pending["bacdive_id"], pending["genome_accession"])
        for raw_id, raw_acc in tqdm(strains, total=len(pending), desc="embed", unit="genome"):
            bid = int(raw_id)
            acc = str(raw_acc)
            try:
                contigs = _fetch_fasta_bytes(acc)
                if not contigs:
                    n_fail += 1
                    continue
                proteins, _, _ = predict_genes(contigs)
                if not proteins:
                    n_fail += 1
                    continue
                vec = embed_genome(
                    proteins, tokenizer, model, device,
                    sample_n=args.sample_n, batch_size=args.batch_size, rng=rng,
                )
            except Exception as exc:  # noqa: BLE001 — single bad genome shouldn't kill the run
                print(f" skip {acc}: {type(exc).__name__}: {exc}")
                n_fail += 1
                continue

            payload = {
                "bacdive_id": bid,
                "genome_accession": acc,
                "embed_dim": int(len(vec)),
                # presumably a NumPy array (it exposes .tolist()) — verify
                # against embed_genome.
                "embedding": vec.tolist(),
            }
            log.write(json.dumps(payload) + "\n")
            # Flush per record: the log is the resume checkpoint.
            log.flush()
            n_success += 1

    elapsed = time.time() - t0
    print(f"\nFinished in {elapsed/60:.1f} min. {n_success} succeeded, {n_fail} failed.")

    # Materialize parquet — flatten the embedding list into per-dim columns
    print("Materializing parquet...")
    parquet_path = config.DATA / "embeddings.parquet"
    n_rows = _materialize_parquet(out_path, parquet_path)
    print(f"Wrote {n_rows:,} embeddings to {parquet_path}")


# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()