File size: 3,104 Bytes
383bb62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Streaming fetch + feature extraction.

Reads training-ready strains (those with a genome accession and the label selected by
--require-target, or ≥1 phenotype label when --require-target is 'any') from
data/bacdive_phenotypes.parquet, then in parallel:
  - downloads each genome FASTA from NCBI Datasets into memory
  - runs pyrodigal + AA-composition feature extraction
  - appends the result to data/features.jsonl

Resumable — re-running picks up where it left off (skips bacdive_ids already in the log).
"""
from __future__ import annotations

import argparse
import os
import time

import pandas as pd
from tqdm import tqdm

from microbe_model import config
from microbe_model.pipeline import stream_fetch_and_featurize


def main() -> None:
    """Run the streaming fetch + featurize pipeline from the command line.

    Steps:
      1. Parse ``--workers``, ``--max`` and ``--require-target``.
      2. Load ``bacdive_phenotypes.parquet`` from ``config.DATA`` and keep the
         rows that have a genome accession and the requested label (or at
         least one of the columns named in ``config.PHENOTYPE_TARGETS`` when
         the target is ``'any'``).
      3. Hand the ``(bacdive_id, genome_accession)`` pairs to
         ``stream_fetch_and_featurize`` while driving a tqdm progress bar.
      4. Count the rows in ``features.jsonl`` and, if there are any, write
         them out again as ``features.parquet``.

    Raises:
        SystemExit: if the phenotype parquet is missing, if a requested label
            column is not present in it, or (via ``parser.error``) if
            ``--max`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 1))
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many strains to process (default: all training-ready).")
    parser.add_argument("--require-target", default="optimal_temperature_c",
                        help="Filter to strains with this label populated (or 'any' for any label).")
    args = parser.parse_args()

    # A negative cap would reach DataFrame.head(-n), which means "all but the
    # last n rows" -- never what a cap is meant to do, so reject it up front.
    if args.max is not None and args.max < 0:
        parser.error("--max must be a non-negative integer")

    pheno_path = config.DATA / "bacdive_phenotypes.parquet"
    if not pheno_path.exists():
        raise SystemExit(f"Missing {pheno_path}. Run scripts/01_fetch_bacdive.py first.")
    pheno = pd.read_parquet(pheno_path)

    if args.require_target == "any":
        label_cols = list(config.PHENOTYPE_TARGETS.keys())
    else:
        label_cols = [args.require_target]

    # Fail with a readable message instead of a KeyError traceback when the
    # requested label column(s) are not in the parquet.
    missing = [col for col in label_cols if col not in pheno.columns]
    if missing:
        raise SystemExit(
            f"{pheno_path} has no column(s): {', '.join(map(str, missing))} "
            "(check --require-target)."
        )

    has_genome = pheno["genome_accession"].notna()
    # With a single column this is the same mask as pheno[col].notna().
    has_label = pheno[label_cols].notna().any(axis=1)

    ready = pheno[has_genome & has_label].copy()
    # `is not None` so that an explicit `--max 0` is honoured as "process
    # nothing" rather than being mistaken for "no cap".
    if args.max is not None:
        ready = ready.head(args.max)

    tasks = list(zip(ready["bacdive_id"].astype(int), ready["genome_accession"].astype(str), strict=True))
    out_path = config.DATA / "features.jsonl"
    print(f"Training-ready strains: {len(tasks)}")
    print(f"Workers: {args.workers}")
    print(f"Output: {out_path}")
    print()

    start = time.time()
    with tqdm(total=len(tasks), desc="featurize", unit="strain") as bar:
        def progress(n_completed, n_success, n_total):
            # The callback reports absolute counts, so set the bar position
            # directly instead of incrementing it.
            bar.n = n_completed
            bar.set_postfix(success=n_success, fail=n_completed - n_success)
            bar.refresh()

        stream_fetch_and_featurize(
            tasks,
            out_path=out_path,
            n_workers=args.workers,
            on_progress=progress,
        )

    elapsed = time.time() - start
    if out_path.exists():
        # Binary mode: only line count matters, so skip text decoding and any
        # dependence on the locale's default encoding.
        with open(out_path, "rb") as fh:
            rows_in_log = sum(1 for _ in fh)
    else:
        rows_in_log = 0
    print(f"\nFinished in {elapsed/60:.1f} min. {rows_in_log} feature rows in {out_path.name}.")

    # Materialize parquet from JSONL for downstream training
    if rows_in_log:
        df = pd.read_json(out_path, lines=True)
        parquet = config.DATA / "features.parquet"
        df.to_parquet(parquet, index=False)
        print(f"Wrote {len(df)} rows to {parquet}")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()