File size: 5,579 Bytes
8c28a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ed74db
 
8c28a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Extract per-genome ESM-2 embeddings for the full BacDive training corpus.

Designed to run on a CUDA GPU (Lightning AI T4 / A100). Falls back to MPS / CPU
for testing — but at scale you really want GPU.

Reads:
  data/bacdive_phenotypes.parquet (strain list)
  + downloads each genome via the existing pipeline._fetch_fasta_bytes path

Writes:
  data/embeddings.jsonl  (one row per genome, append-only, resumable)
  data/embeddings.parquet (materialized at end)

Resumability: re-running picks up where it left off (same JSONL pattern as
features.jsonl).

Usage:
    # Full corpus on GPU (Lightning AI). ~3-5 hr on T4 with sample_n=50.
    uv run --extra embeddings python scripts/11_extract_embeddings.py \\
        --model facebook/esm2_t30_150M_UR50D --sample-n 50 --batch-size 32

    # Smoke test on Mac MPS with smallest model
    uv run --extra embeddings python scripts/11_extract_embeddings.py \\
        --model facebook/esm2_t6_8M_UR50D --sample-n 20 --max 10
"""
from __future__ import annotations

import argparse
import json
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

from microbe_model import config
from microbe_model.features.embeddings import embed_genome, load_esm2
from microbe_model.features.genome import predict_genes
from microbe_model.pipeline import _fetch_fasta_bytes


def _load_done_ids(path) -> set[int]:
    """Return the ``bacdive_id`` values already recorded in the JSONL log at *path*.

    Used for resumability: ids returned here are skipped on the next run.

    Args:
        path: Path-like object with an ``exists()`` method (e.g. ``pathlib.Path``)
            pointing at the append-only JSONL log.

    Returns:
        The set of integer ids found. Empty if the file does not exist.

    Lines that cannot be used are skipped rather than raising, so a log left
    with a truncated or corrupt record by an interrupted run never blocks a
    resume. That covers invalid JSON, JSON that is not an object, a missing
    ``bacdive_id`` key, and ids that are null or non-numeric.
    """
    if not path.exists():
        return set()
    ids: set[int] = set()
    # errors="replace": stray bytes in a damaged line turn into an unparsable
    # line (skipped below) instead of a UnicodeDecodeError for the whole scan.
    with open(path, encoding="utf-8", errors="replace") as fh:
        for line in fh:
            try:
                row = json.loads(line)
                ids.add(int(row["bacdive_id"]))
            # TypeError covers non-dict JSON values (null, list, number) and
            # a null id, which previously escaped the handler.
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                continue
    return ids


def main() -> None:
    """Embed every pending training-ready genome, then rebuild ``embeddings.parquet``.

    Steps:
      1. Parse the CLI flags (``--model``, ``--sample-n``, ``--batch-size``, ``--max``).
      2. Load ``bacdive_phenotypes.parquet`` and keep rows that have a genome
         accession and at least one non-null column from
         ``config.PHENOTYPE_TARGETS``.
      3. Drop rows whose ``bacdive_id`` is already present in
         ``embeddings.jsonl``, then optionally cap the remainder with ``--max``.
      4. For each remaining row: fetch the genome, predict genes, embed the
         proteins, and append one JSON line to the log, flushed immediately
         so progress survives an interruption.
      5. Re-read the whole log and write it to ``embeddings.parquet`` with one
         ``emb_<i>`` column per embedding dimension.

    A genome that fails at any stage is reported, counted, and skipped; it
    does not stop the run.

    Raises:
        SystemExit: if the phenotype parquet file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="facebook/esm2_t30_150M_UR50D",
                        help="HF model id (esm2_t6_8M / t12_35M / t30_150M / t33_650M)")
    parser.add_argument("--sample-n", type=int, default=50,
                        help="Proteins per genome to embed (None = all). 50 is a fast default.")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="ESM-2 batch size (32 fits on T4 16GB; raise on A100).")
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many genomes to process (default: all training-ready).")
    args = parser.parse_args()

    pheno_path = config.DATA / "bacdive_phenotypes.parquet"
    if not pheno_path.exists():
        raise SystemExit(f"Missing {pheno_path}. Run scripts/01_fetch_bacdive.py first.")
    pheno = pd.read_parquet(pheno_path)

    # "Training-ready" = has a genome to download AND at least one label.
    has_genome = pheno["genome_accession"].notna()
    label_cols = list(config.PHENOTYPE_TARGETS.keys())
    has_label = pheno[label_cols].notna().any(axis=1)
    ready = pheno[has_genome & has_label].copy()

    out_path = config.DATA / "embeddings.jsonl"
    done_ids = _load_done_ids(out_path)
    pending = ready[~ready["bacdive_id"].astype(int).isin(done_ids)]
    # Compare against None so an explicit `--max 0` means "process nothing"
    # instead of silently meaning "no cap".
    if args.max is not None:
        pending = pending.head(args.max)
    print(f"Embedding {len(pending):,} genomes (skipping {len(done_ids):,} already done)")

    print(f"Loading {args.model}...")
    tokenizer, model, device = load_esm2(args.model)
    print(f"  device={device}, embed_dim={model.config.hidden_size}, "
          f"sample_n={args.sample_n}, batch_size={args.batch_size}")

    # One seeded generator is shared across all genomes and handed to embed_genome.
    rng = np.random.default_rng(0)
    t0 = time.time()
    n_success = 0
    n_fail = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # A run killed mid-write can leave a partial last line with no trailing
    # newline. Appending straight after it would fuse the next valid record
    # onto the broken one and lose it, so terminate the partial line first.
    needs_newline = False
    if out_path.exists() and out_path.stat().st_size > 0:
        with open(out_path, "rb") as tail:
            tail.seek(-1, 2)  # whence=2: offset relative to end of file
            needs_newline = tail.read(1) != b"\n"

    with open(out_path, "a", encoding="utf-8") as log:
        if needs_newline:
            log.write("\n")
        for _, row in tqdm(pending.iterrows(), total=len(pending), desc="embed", unit="genome"):
            bid = int(row["bacdive_id"])
            acc = str(row["genome_accession"])
            try:
                contigs = _fetch_fasta_bytes(acc)
                if not contigs:
                    print(f"  skip {acc}: no contigs fetched")
                    n_fail += 1
                    continue
                proteins, _, _ = predict_genes(contigs)
                if not proteins:
                    print(f"  skip {acc}: no proteins predicted")
                    n_fail += 1
                    continue
                vec = embed_genome(
                    proteins, tokenizer, model, device,
                    sample_n=args.sample_n, batch_size=args.batch_size, rng=rng,
                )
            except Exception as exc:  # noqa: BLE001 — single bad genome shouldn't kill the run
                print(f"  skip {acc}: {type(exc).__name__}: {exc}")
                n_fail += 1
                continue
            payload = {
                "bacdive_id": bid,
                "genome_accession": acc,
                "embed_dim": int(len(vec)),
                "embedding": vec.tolist(),
            }
            log.write(json.dumps(payload) + "\n")
            # Flush per genome so an interruption loses at most the record in flight.
            log.flush()
            n_success += 1

    elapsed = time.time() - t0
    print(f"\nFinished in {elapsed/60:.1f} min. {n_success} succeeded, {n_fail} failed.")

    # Materialize parquet — flatten the embedding list into per-dim columns
    print("Materializing parquet...")
    rows = []
    n_bad = 0
    with open(out_path, encoding="utf-8", errors="replace") as fh:
        for line in fh:
            if not line.strip():
                continue
            # Skip malformed records the same way _load_done_ids does, so a
            # truncated line from an interrupted run cannot abort this step.
            try:
                row = json.loads(line)
                d = {"bacdive_id": row["bacdive_id"], "genome_accession": row["genome_accession"]}
                d.update({f"emb_{i}": float(v) for i, v in enumerate(row["embedding"])})
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                n_bad += 1
                continue
            rows.append(d)
    if n_bad:
        print(f"  skipped {n_bad} malformed line(s) in {out_path}")
    df = pd.DataFrame(rows)
    parquet_path = config.DATA / "embeddings.parquet"
    df.to_parquet(parquet_path, index=False)
    print(f"Wrote {len(df):,} embeddings to {parquet_path}")


# Run the extraction only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()