Spaces:
Running
Running
File size: 5,993 Bytes
30e65bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | """Apply the trained v1 model to uncultured (non-BacDive) GTDB representatives.
Pipeline:
1. Read data/gtdb_candidates.parquet (created by scripts/06_fetch_gtdb_candidates.py)
2. Stream-fetch + featurize each candidate (writes data/uncultured_features.jsonl)
3. Re-train one final XGBoost model per phenotype target on all labeled BacDive
strains (no saved model is loaded), then predict on the uncultured genomes.
The output: artifacts/uncultured_predictions.parquet — one row per genome with all
four phenotype targets predicted from genome features alone. This is the v0 research
deliverable: predicted cultivation conditions for genomes that have never been cultured.
"""
from __future__ import annotations
import argparse
import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from microbe_model import config
from microbe_model.pipeline import stream_fetch_and_featurize
# Minimum number of labeled BacDive strains required before a target is modeled.
_MIN_LABELED_STRAINS = 100


def _featurize_candidates(candidates: pd.DataFrame, n_workers: int) -> pd.DataFrame:
    """Stream-fetch and featurize every candidate genome, returning the features.

    Writes ``data/uncultured_features.jsonl`` via ``stream_fetch_and_featurize``
    and reads it back as a DataFrame.

    Args:
        candidates: Frame with ``fake_id`` and ``genome_accession`` columns.
        n_workers: Worker count forwarded to the streaming pipeline.

    Raises:
        SystemExit: If the pipeline produced no usable feature rows.
    """
    tasks = list(zip(
        candidates["fake_id"].astype(int),
        candidates["genome_accession"].astype(str),
        strict=True,
    ))
    feats_out = config.DATA / "uncultured_features.jsonl"
    print(f"Output: {feats_out}\n")

    with tqdm(total=len(tasks), desc="featurize uncultured", unit="strain") as bar:

        # NOTE(review): presumably called as (completed, succeeded, total);
        # only the first two are used here — confirm against the pipeline.
        def progress(c, s, t):
            bar.n = c
            bar.set_postfix(success=s, fail=c - s)
            bar.refresh()

        stream_fetch_and_featurize(
            tasks, out_path=feats_out, n_workers=n_workers, on_progress=progress,
        )

    feats = pd.read_json(feats_out, lines=True)
    # Fail with a clear message instead of a KeyError further down when
    # nothing could be featurized.
    if feats.empty or "genome_accession" not in feats.columns:
        raise SystemExit(f"No genomes were featurized; {feats_out} has no usable rows.")
    return feats


def _fit_and_predict(
    task: str,
    X_train: pd.DataFrame,
    y_train_raw: pd.Series,
    X_pred: pd.DataFrame,
) -> dict[str, object]:
    """Fit one XGBoost model and predict on ``X_pred``.

    Returns:
        Mapping of output-column suffix to per-row values: ``""`` for the
        prediction itself and, for classifiers, ``"_confidence"`` for the
        maximum class probability.
    """
    if task == "regression":
        model = xgb.XGBRegressor(
            n_estimators=500, max_depth=5, learning_rate=0.05,
            tree_method="hist", n_jobs=-1,
        )
        model.fit(X_train, y_train_raw.to_numpy(dtype=float))
        return {"": model.predict(X_pred)}

    # Any non-regression task is treated as classification on stringified labels.
    encoder = LabelEncoder()
    y_train_enc = encoder.fit_transform(y_train_raw.astype(str))
    model = xgb.XGBClassifier(
        n_estimators=300, max_depth=5, learning_rate=0.05,
        tree_method="hist", n_jobs=-1, eval_metric="mlogloss",
    )
    model.fit(X_train, y_train_enc)
    preds = model.predict(X_pred)
    preds_proba = model.predict_proba(X_pred)
    return {
        "": encoder.inverse_transform(preds),
        "_confidence": preds_proba.max(axis=1),
    }


def _print_temperature_summary(output_rows: pd.DataFrame) -> None:
    """Print summary statistics of the predicted optimal temperature, if present."""
    if "pred_optimal_temperature_c" not in output_rows:
        return
    print("\nPredicted T_opt distribution:")
    s = output_rows["pred_optimal_temperature_c"].dropna()
    print(f" mean={s.mean():.1f} std={s.std():.1f} "
          f"p10={s.quantile(0.1):.1f} p90={s.quantile(0.9):.1f}")
    n_thermo = (s >= 50).sum()
    n_psychro = (s <= 15).sum()
    print(f" predicted thermophiles (T_opt >= 50°C): {n_thermo}")
    print(f" predicted psychrophiles (T_opt <= 15°C): {n_psychro}")


def main() -> None:
    """Featurize GTDB candidates, train on BacDive, and write phenotype predictions.

    Steps:
      1. Read ``data/gtdb_candidates.parquet`` (optionally capped by ``--max``).
      2. Featurize each candidate and materialize
         ``data/uncultured_features.parquet``.
      3. For each entry of ``config.PHENOTYPE_TARGETS`` with at least
         ``_MIN_LABELED_STRAINS`` labeled BacDive strains, train an XGBoost
         model on the features shared by both datasets and predict.
      4. Write ``artifacts/uncultured_predictions.parquet`` with one row per
         candidate. Candidates that have no feature row get null predictions
         rather than predictions computed from an all-zero feature vector.

    Raises:
        SystemExit: If the candidates file is missing or nothing was featurized.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max", type=int, default=None,
                        help="Cap how many candidates to process (default: all).")
    parser.add_argument("--workers", type=int, default=7)
    args = parser.parse_args()

    candidates_path = config.DATA / "gtdb_candidates.parquet"
    if not candidates_path.exists():
        raise SystemExit(f"Missing {candidates_path}. Run scripts/06_fetch_gtdb_candidates.py first.")
    candidates = pd.read_parquet(candidates_path)
    if args.max:
        candidates = candidates.head(args.max)
    print(f"Predicting on {len(candidates):,} GTDB representatives not in BacDive")

    # Use a synthetic int "fake_id" since these aren't BacDive strains
    # (the streaming pipeline expects int IDs but doesn't care what they are).
    candidates = candidates.reset_index(drop=True)
    candidates["fake_id"] = candidates.index + 1_000_000_000  # avoid collision with BacDive IDs

    feats = _featurize_candidates(candidates, args.workers)

    # Materialize parquet
    feats_parquet = config.DATA / "uncultured_features.parquet"
    feats.to_parquet(feats_parquet, index=False)
    print(f"\nFeaturized {len(feats):,} genomes")

    # Train one final XGBoost per target on ALL of BacDive (all folds combined)
    # rather than loading a saved model, so the prediction surface uses every
    # labeled BacDive strain, not just one fold's training set.
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    bacdive_feats = pd.read_parquet(config.DATA / "features.parquet")
    bacdive = pheno.merge(bacdive_feats, on=["bacdive_id", "genome_accession"], how="inner")

    feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
    common_cols = [c for c in feature_cols if c in bacdive.columns]
    print(f"Using {len(common_cols)} features common to BacDive + uncultured")

    output_rows = candidates[[
        "genome_accession", "gtdb_taxonomy", "ncbi_organism_name", "checkm_completeness",
    ]].copy()

    # Index uncultured features by accession for lookup. Duplicates are dropped
    # (last row wins) because reindexing from a non-unique index raises.
    feats_by_acc = (
        feats.drop_duplicates(subset="genome_accession", keep="last")
        .set_index("genome_accession")
    )
    # Candidates that failed featurization have no row in feats; remember which
    # ones so their predictions can be blanked instead of fabricated.
    featurized = candidates["genome_accession"].isin(feats_by_acc.index)
    n_featurized = int(featurized.sum())

    # The prediction matrix is the same for every target, so build it once.
    # NOTE(review): missing values are zero-filled here but not in X_train;
    # kept as in the original — confirm this asymmetry is intended.
    X_pred = feats_by_acc.reindex(candidates["genome_accession"])[common_cols].fillna(0)

    for target, task in config.PHENOTYPE_TARGETS.items():
        labeled = bacdive[bacdive[target].notna()]
        if len(labeled) < _MIN_LABELED_STRAINS:
            print(f" {target}: skipped, only {len(labeled):,} labeled BacDive strains")
            continue

        predictions = _fit_and_predict(task, labeled[common_cols], labeled[target], X_pred)
        for suffix, values in predictions.items():
            column = pd.Series(values, index=output_rows.index)
            # Null out rows that had no real features behind them.
            output_rows[f"pred_{target}{suffix}"] = column.where(featurized)
        print(f" {target}: trained on {len(labeled):,} BacDive strains, predicted {n_featurized}")

    out_path = config.ARTIFACTS / "uncultured_predictions.parquet"
    output_rows.to_parquet(out_path, index=False)
    print(f"\nWrote {len(output_rows)} predictions to {out_path}")
    n_missing = len(output_rows) - n_featurized
    if n_missing:
        print(f" ({n_missing} genomes had no features; their predictions are null)")

    # Quick summary
    _print_temperature_summary(output_rows)
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|