"""Apply the trained v1 model to uncultured (non-BacDive) GTDB representatives. Pipeline: 1. Read data/gtdb_candidates.parquet (created by scripts/06_fetch_gtdb_candidates.py) 2. Stream-fetch + featurize each candidate (writes data/uncultured_features.jsonl) 3. Re-train? No — load the existing baseline_results.json + reload an XGBoost model trained on the BacDive corpus. The output: artifacts/uncultured_predictions.parquet — one row per genome with all four phenotype targets predicted from genome features alone. This is the v0 research deliverable: predicted cultivation conditions for genomes that have never been cultured. """ from __future__ import annotations import argparse import pandas as pd import xgboost as xgb from sklearn.preprocessing import LabelEncoder from tqdm import tqdm from microbe_model import config from microbe_model.pipeline import stream_fetch_and_featurize def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--max", type=int, default=None, help="Cap how many candidates to process (default: all).") parser.add_argument("--workers", type=int, default=7) args = parser.parse_args() candidates_path = config.DATA / "gtdb_candidates.parquet" if not candidates_path.exists(): raise SystemExit(f"Missing {candidates_path}. Run scripts/06_fetch_gtdb_candidates.py first.") candidates = pd.read_parquet(candidates_path) if args.max: candidates = candidates.head(args.max) print(f"Predicting on {len(candidates):,} GTDB representatives not in BacDive") # Use a synthetic int "fake_id" since these aren't BacDive strains # (the streaming pipeline expects int IDs but doesn't care what they are) candidates = candidates.reset_index(drop=True) candidates["fake_id"] = candidates.index + 1_000_000_000 # avoid collision with BacDive IDs tasks = list(zip( candidates["fake_id"].astype(int), candidates["genome_accession"].astype(str), strict=True, )) feats_out = config.DATA / "uncultured_features.jsonl" print(f"Output: {feats_out}\n") with tqdm(total=len(tasks), desc="featurize uncultured", unit="strain") as bar: def progress(c, s, t): bar.n = c bar.set_postfix(success=s, fail=c - s) bar.refresh() stream_fetch_and_featurize( tasks, out_path=feats_out, n_workers=args.workers, on_progress=progress, ) # Materialize parquet feats = pd.read_json(feats_out, lines=True) feats_parquet = config.DATA / "uncultured_features.parquet" feats.to_parquet(feats_parquet, index=False) print(f"\nFeaturized {len(feats):,} genomes") # Now re-train one final XGBoost on ALL of BacDive (all folds combined) and predict. # We do this here (not loading a saved model) so the prediction surface uses every # labeled BacDive strain, not just one fold's training set. pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet") bacdive_feats = pd.read_parquet(config.DATA / "features.parquet") bacdive = pheno.merge(bacdive_feats, on=["bacdive_id", "genome_accession"], how="inner") feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}] common_cols = [c for c in feature_cols if c in bacdive.columns] print(f"Using {len(common_cols)} features common to BacDive + uncultured") output_rows = candidates[[ "genome_accession", "gtdb_taxonomy", "ncbi_organism_name", "checkm_completeness", ]].copy() # Index uncultured features by accession for lookup feats_by_acc = feats.set_index("genome_accession") for target, task in config.PHENOTYPE_TARGETS.items(): labeled = bacdive[bacdive[target].notna()] if len(labeled) < 100: continue X_train = labeled[common_cols] y_train_raw = labeled[target] if task == "regression": model = xgb.XGBRegressor( n_estimators=500, max_depth=5, learning_rate=0.05, tree_method="hist", n_jobs=-1, ) y_train = y_train_raw.to_numpy(dtype=float) model.fit(X_train, y_train) X_pred = feats_by_acc.reindex(candidates["genome_accession"])[common_cols] preds = model.predict(X_pred.fillna(0)) output_rows[f"pred_{target}"] = preds else: encoder = LabelEncoder() y_train_enc = encoder.fit_transform(y_train_raw.astype(str)) model = xgb.XGBClassifier( n_estimators=300, max_depth=5, learning_rate=0.05, tree_method="hist", n_jobs=-1, eval_metric="mlogloss", ) model.fit(X_train, y_train_enc) X_pred = feats_by_acc.reindex(candidates["genome_accession"])[common_cols] preds = model.predict(X_pred.fillna(0)) preds_proba = model.predict_proba(X_pred.fillna(0)) output_rows[f"pred_{target}"] = encoder.inverse_transform(preds) output_rows[f"pred_{target}_confidence"] = preds_proba.max(axis=1) print(f" {target}: trained on {len(labeled):,} BacDive strains, predicted {len(output_rows)}") out_path = config.ARTIFACTS / "uncultured_predictions.parquet" output_rows.to_parquet(out_path, index=False) print(f"\nWrote {len(output_rows)} predictions to {out_path}") # Quick summary print("\nPredicted T_opt distribution:") if "pred_optimal_temperature_c" in output_rows: s = output_rows["pred_optimal_temperature_c"].dropna() print(f" mean={s.mean():.1f} std={s.std():.1f} " f"p10={s.quantile(0.1):.1f} p90={s.quantile(0.9):.1f}") n_thermo = (s >= 50).sum() n_psychro = (s <= 15).sum() print(f" predicted thermophiles (T_opt >= 50°C): {n_thermo}") print(f" predicted psychrophiles (T_opt <= 15°C): {n_psychro}") if __name__ == "__main__": main()