File size: 4,175 Bytes
f0f1d93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Deduplicated featurize for species-resolved genomes.

When many BacDive strains share a single species-level representative genome (the
common case after scripts/18), naively running scripts/02 re-downloads + re-runs
pyrodigal on the same FASTA per-strain. This script downloads each unique accession
once, then replicates the resulting feature dict across all bacdive_ids that share it.

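Resumable via data/features.jsonl (skips bacdive_ids already in the log). At the end,
the log is collapsed to one row per bacdive_id (last write wins) and saved as
data/features.parquet.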

Usage:
    uv run python scripts/19_featurize_resolved.py --workers 7
"""
from __future__ import annotations

import argparse
import json
import os
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pandas as pd
from tqdm import tqdm

from microbe_model import config
from microbe_model.pipeline import _load_done_ids, _process_one


def _ready_pool(pheno: pd.DataFrame) -> pd.DataFrame:
    """Return the training-ready strains: any phenotype label plus a genome accession.

    ``bacdive_id`` is cast to ``int`` and ``genome_accession`` to ``str`` so the
    resume check and the per-accession grouping compare like with like.
    """
    label_cols = list(config.PHENOTYPE_TARGETS.keys())
    has_label = pheno[label_cols].notna().any(axis=1)
    has_genome = pheno["genome_accession"].notna()
    ready = pheno[has_label & has_genome].copy()
    ready["bacdive_id"] = ready["bacdive_id"].astype(int)
    ready["genome_accession"] = ready["genome_accession"].astype(str)
    return ready


def _featurize_and_append(
    by_acc: dict[str, list[int]],
    accessions: list[str],
    out_path: Path,
    workers: int,
) -> tuple[int, int, int]:
    """Featurize each accession once and append one JSONL row per sibling strain.

    Each accession is submitted to ``_process_one`` as ``(first_bacdive_id, accession)``.
    The returned feature dict is copied to every bacdive_id that shares the
    accession. Each copy has its ``bacdive_id`` / ``genome_accession`` keys
    overwritten.

    Returns:
        ``(n_success, n_failed, n_rows)``: the number of accessions featurized,
        the number that raised or returned no features, and the number of JSONL
        rows written.
    """
    rep_tasks = [(by_acc[acc][0], acc) for acc in accessions]
    n_success = n_failed = n_rows = 0
    with open(out_path, "a", encoding="utf-8") as fh, \
         ProcessPoolExecutor(max_workers=workers) as pool, \
         tqdm(total=len(rep_tasks), desc="featurize", unit="genome") as bar:
        futures = {pool.submit(_process_one, t): t for t in rep_tasks}
        for fut in as_completed(futures):
            _, acc = futures[fut]
            bar.update(1)
            try:
                feats = fut.result()
            except Exception as exc:  # best-effort: one bad genome must not abort the run
                tqdm.write(f"  failed {acc}: {type(exc).__name__}: {exc}")
                feats = None
            if not feats:
                n_failed += 1
                continue
            n_success += 1
            for bid in by_acc[acc]:
                row = dict(feats)
                row["bacdive_id"] = bid
                row["genome_accession"] = acc
                fh.write(json.dumps(row) + "\n")
                n_rows += 1
            # Flush per accession so an interrupted run keeps its progress for resume.
            fh.flush()
            bar.set_postfix(genomes_ok=n_success, failed=n_failed, rows=n_rows)
    return n_success, n_failed, n_rows


def _materialize_parquet(jsonl_path: Path, parquet_path: Path) -> None:
    """Collapse the append-only JSONL log to one row per bacdive_id and write parquet.

    Later log lines win over earlier ones for the same bacdive_id. An empty or
    missing log writes nothing; pandas cannot dedupe on a column that does not exist.
    """
    if not jsonl_path.exists() or jsonl_path.stat().st_size == 0:
        print(f"  {jsonl_path} is empty; not writing {parquet_path}")
        return
    df = pd.read_json(jsonl_path, lines=True)
    df = df.drop_duplicates(subset=["bacdive_id"], keep="last")
    df.to_parquet(parquet_path, index=False)
    print(f"  wrote {len(df):,} rows to {parquet_path}")


def main() -> None:
    """Featurize every unique, not-yet-featurized genome accession once.

    The features are then replicated to all strains sharing that accession.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 1))
    parser.add_argument("--max-accessions", type=int, default=None,
                        help="Cap how many unique accessions to process (debug).")
    args = parser.parse_args()
    if args.workers < 1:
        parser.error("--workers must be >= 1")
    if args.max_accessions is not None and args.max_accessions < 0:
        parser.error("--max-accessions must be >= 0")

    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    ready = _ready_pool(pheno)

    out_path = config.DATA / "features.jsonl"
    done_ids = _load_done_ids(out_path)
    todo = ready[~ready["bacdive_id"].isin(done_ids)]
    print(f"strains in pool: {len(ready):,}")
    # Count within the pool: the log may also contain ids that are not in this pool.
    print(f"  already featurized: {len(ready) - len(todo):,}")
    print(f"  remaining: {len(todo):,}")

    # Group remaining strains by accession.
    by_acc = todo.groupby("genome_accession")["bacdive_id"].apply(list).to_dict()
    accessions = sorted(by_acc)
    # `is not None` so that an explicit 0 means "process nothing", not "no cap".
    if args.max_accessions is not None:
        accessions = accessions[: args.max_accessions]
    n_strains = sum(len(by_acc[a]) for a in accessions)
    print(f"unique accessions to download: {len(accessions):,}")
    print(f"  avg strains per accession: {n_strains / max(1, len(accessions)):.1f}")
    print(f"workers: {args.workers}\n")

    start = time.time()
    n_success, n_failed, n_rows = _featurize_and_append(
        by_acc, accessions, out_path, args.workers
    )

    print(f"\nfinished in {(time.time() - start) / 60:.1f} min")
    print(f"  unique genomes featurized: {n_success:,}/{len(accessions):,}")
    if n_failed:
        print(f"  failed accessions: {n_failed:,}")
    print(f"  feature rows written: {n_rows:,}")

    _materialize_parquet(out_path, config.DATA / "features.parquet")


if __name__ == "__main__":
    main()