File size: 5,993 Bytes
30e65bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Apply the trained v1 model to uncultured (non-BacDive) GTDB representatives.

Pipeline:
  1. Read data/gtdb_candidates.parquet (created by scripts/06_fetch_gtdb_candidates.py)
  2. Stream-fetch + featurize each candidate (writes data/uncultured_features.jsonl)
  3. Train one XGBoost model per phenotype target on all labeled BacDive strains (every
     fold combined) and predict each candidate from its genome features.

Output: artifacts/uncultured_predictions.parquet — one row per candidate genome, with each
phenotype target that has enough labeled BacDive strains predicted from genome features alone.
deliverable: predicted cultivation conditions for genomes that have never been cultured.
"""
from __future__ import annotations

import argparse

import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from microbe_model import config
from microbe_model.pipeline import stream_fetch_and_featurize


def _load_uncultured_features(jsonl_path) -> pd.DataFrame:
    """Read featurizer output into a DataFrame with one row per genome accession.

    Args:
        jsonl_path: Path to the JSON-lines file written by ``stream_fetch_and_featurize``.

    Returns:
        The features, de-duplicated on ``genome_accession`` (last row wins). When the file
        is missing, empty, or has no ``genome_accession`` column, an empty DataFrame with
        just that column is returned so callers can still index on it.
    """
    empty = pd.DataFrame({"genome_accession": pd.Series(dtype=str)})
    if not jsonl_path.exists() or jsonl_path.stat().st_size == 0:
        return empty
    feats = pd.read_json(jsonl_path, lines=True)
    if "genome_accession" not in feats.columns:
        return empty
    # Duplicate accessions would make the later ``DataFrame.reindex`` raise.
    # NOTE(review): presumably they come from re-runs appending to the same JSONL — confirm
    # against stream_fetch_and_featurize. Keep the most recent row for each genome.
    return feats.drop_duplicates(subset="genome_accession", keep="last").reset_index(drop=True)


def _print_temperature_summary(output_rows: pd.DataFrame) -> None:
    """Print distribution stats for predicted optimal temperature, if that target was predicted."""
    if "pred_optimal_temperature_c" not in output_rows:
        return
    print("\nPredicted T_opt distribution:")
    s = output_rows["pred_optimal_temperature_c"].dropna()
    print(f"  mean={s.mean():.1f}  std={s.std():.1f}  "
          f"p10={s.quantile(0.1):.1f}  p90={s.quantile(0.9):.1f}")
    n_thermo = (s >= 50).sum()
    n_psychro = (s <= 15).sum()
    print(f"  predicted thermophiles (T_opt >= 50°C): {n_thermo}")
    print(f"  predicted psychrophiles (T_opt <= 15°C): {n_psychro}")


def main() -> None:
    """Featurize GTDB candidates, fit one model per phenotype on all of BacDive, and predict.

    Reads ``config.DATA / "gtdb_candidates.parquet"``, ``"bacdive_phenotypes.parquet"`` and
    ``"features.parquet"``. Writes ``config.DATA / "uncultured_features.jsonl"`` and
    ``".parquet"``, plus ``config.ARTIFACTS / "uncultured_predictions.parquet"``.

    Genomes that could not be featurized keep NaN (or None) predictions. They are not scored
    from a placeholder feature vector.

    Raises:
        SystemExit: if the candidates file is missing, no candidate was featurized, or
            BacDive and the uncultured features share no feature columns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many candidates to process (default: all).")
    parser.add_argument("--workers", type=int, default=7)
    args = parser.parse_args()

    candidates_path = config.DATA / "gtdb_candidates.parquet"
    if not candidates_path.exists():
        raise SystemExit(f"Missing {candidates_path}. Run scripts/06_fetch_gtdb_candidates.py first.")

    candidates = pd.read_parquet(candidates_path)
    # Compare against None so an explicit ``--max 0`` is honoured instead of meaning "all".
    if args.max is not None:
        candidates = candidates.head(args.max)
    print(f"Predicting on {len(candidates):,} GTDB representatives not in BacDive")

    # Use a synthetic int "fake_id" since these aren't BacDive strains
    # (the streaming pipeline expects int IDs but doesn't care what they are)
    candidates = candidates.reset_index(drop=True)
    candidates["fake_id"] = candidates.index + 1_000_000_000  # avoid collision with BacDive IDs

    tasks = list(zip(
        candidates["fake_id"].astype(int),
        candidates["genome_accession"].astype(str),
        strict=True,
    ))
    feats_out = config.DATA / "uncultured_features.jsonl"
    print(f"Output: {feats_out}\n")

    with tqdm(total=len(tasks), desc="featurize uncultured", unit="strain") as bar:
        def progress(c, s, t):
            # c = completed, s = succeeded; t (total) is already known to the bar.
            bar.n = c
            bar.set_postfix(success=s, fail=c - s)
            bar.refresh()

        stream_fetch_and_featurize(
            tasks, out_path=feats_out, n_workers=args.workers, on_progress=progress,
        )

    # Materialize parquet
    feats = _load_uncultured_features(feats_out)
    feats_parquet = config.DATA / "uncultured_features.parquet"
    feats.to_parquet(feats_parquet, index=False)
    print(f"\nFeaturized {len(feats):,} genomes")

    # Mask of candidates that actually have a feature row. The others keep missing predictions.
    feats_by_acc = feats.set_index("genome_accession")
    featurized = candidates["genome_accession"].isin(feats_by_acc.index).to_numpy()
    n_featurized = int(featurized.sum())
    if n_featurized == 0:
        raise SystemExit(f"No candidate genomes were featurized into {feats_out}; nothing to predict.")

    # Now re-train one final XGBoost on ALL of BacDive (all folds combined) and predict.
    # We do this here (not loading a saved model) so the prediction surface uses every
    # labeled BacDive strain, not just one fold's training set.
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    bacdive_feats = pd.read_parquet(config.DATA / "features.parquet")
    bacdive = pheno.merge(bacdive_feats, on=["bacdive_id", "genome_accession"], how="inner")

    feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
    common_cols = [c for c in feature_cols if c in bacdive.columns]
    print(f"Using {len(common_cols)} features common to BacDive + uncultured")
    if not common_cols:
        raise SystemExit("No feature columns are shared between BacDive and the uncultured genomes.")

    output_rows = candidates[[
        "genome_accession", "gtdb_taxonomy", "ncbi_organism_name", "checkm_completeness",
    ]].copy()

    # The prediction matrix is the same for every target, so build it once.
    # It holds featurized candidates only, in candidate order. Any remaining NaNs are left
    # for XGBoost to treat as missing, which is how NaNs in the training matrix
    # (``labeled[common_cols]``) are treated during fit. The original fillna(0) caused
    # train/serve skew.
    X_pred = feats_by_acc.reindex(candidates["genome_accession"])[common_cols][featurized]

    for target, task in config.PHENOTYPE_TARGETS.items():
        labeled = bacdive[bacdive[target].notna()]
        if len(labeled) < 100:
            continue
        X_train = labeled[common_cols]
        y_train_raw = labeled[target]

        if task == "regression":
            model = xgb.XGBRegressor(
                n_estimators=500, max_depth=5, learning_rate=0.05,
                tree_method="hist", n_jobs=-1,
            )
            model.fit(X_train, y_train_raw.to_numpy(dtype=float))
            pred = pd.Series(float("nan"), index=output_rows.index)
            pred.loc[featurized] = model.predict(X_pred)
            output_rows[f"pred_{target}"] = pred
        else:
            encoder = LabelEncoder()
            y_train_enc = encoder.fit_transform(y_train_raw.astype(str))
            if len(encoder.classes_) < 2:
                # A classifier cannot be fit on a single class.
                print(f"  {target}: skipped, only one class among {len(labeled):,} BacDive strains")
                continue
            model = xgb.XGBClassifier(
                n_estimators=300, max_depth=5, learning_rate=0.05,
                tree_method="hist", n_jobs=-1, eval_metric="mlogloss",
            )
            model.fit(X_train, y_train_enc)
            labels = pd.Series(None, index=output_rows.index, dtype=object)
            labels.loc[featurized] = encoder.inverse_transform(model.predict(X_pred))
            confidence = pd.Series(float("nan"), index=output_rows.index)
            confidence.loc[featurized] = model.predict_proba(X_pred).max(axis=1)
            output_rows[f"pred_{target}"] = labels
            output_rows[f"pred_{target}_confidence"] = confidence

        print(f"  {target}: trained on {len(labeled):,} BacDive strains, predicted {n_featurized:,}")

    out_path = config.ARTIFACTS / "uncultured_predictions.parquet"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    output_rows.to_parquet(out_path, index=False)
    print(f"\nWrote {len(output_rows)} predictions to {out_path}")

    # Quick summary
    _print_temperature_summary(output_rows)


if __name__ == "__main__":
    main()