Spaces:
Running
Running
File size: 6,987 Bytes
52cf5ab d082ced 52cf5ab d082ced 52cf5ab d082ced 52cf5ab f0f1d93 52cf5ab d082ced 52cf5ab d082ced 52cf5ab f0f1d93 5df9ef8 0ed74db 52cf5ab 0ed74db 52cf5ab d082ced 52cf5ab 4b79970 bbbea9d 4b79970 d082ced 52cf5ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | """Train the multi-task XGBoost baseline.
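Joins phenotypes + features, derives a stable group column for GroupKFold, one-hot encodes
isolation categories, left-joins whichever optional feature tables exist on disk (MediaDive,
HMM, KEGG modules, per-marker embeddings, isolation metadata), trains, saves the merged
training table for the eval renderer, and writes per-target metrics and per-strain predictions.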
"""
from __future__ import annotations

import re
import time
from collections import Counter

import pandas as pd

from microbe_model import config
from microbe_model.train.baseline import save_results, train_all
def derive_group(row: pd.Series) -> str:
    """Return the GroupKFold key for one strain row.

    Prefers the LPSN family (from BacDive), then the genus, then the first word of the
    species name. Missing, non-string, empty, or whitespace-only values are skipped.

    Args:
        row: One row of the merged training table. Only the ``family``, ``genus`` and
            ``species`` entries are read; absent keys are treated as missing.

    Returns:
        The group label, or ``"__unknown__"`` when no usable value is present.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        # isinstance() also rejects NaN/None coming from missing cells.
        if isinstance(val, str) and val.strip():
            return val
    species = row.get("species")
    if isinstance(species, str):
        # A whitespace-only string splits to [], so check before indexing.
        words = species.split()
        if words:
            return words[0]
    return "__unknown__"
def _split_tags(cell: object) -> frozenset[str]:
    """Return the distinct, whitespace-trimmed, non-empty tags of one pipe-joined cell.

    Non-string cells (NaN, None, ...) have no tags.
    """
    if not isinstance(cell, str):
        return frozenset()
    stripped = (tag.strip() for tag in cell.split("|"))
    return frozenset(tag for tag in stripped if tag)


def encode_isolation_categories(
    df: pd.DataFrame,
    *,
    min_count: int = 10,
) -> tuple[pd.DataFrame, list[str]]:
    """One-hot encode isolation_cat1/cat2 (pipe-joined multi-labels).

    Each strain's category cell is "Tag1|Tag2|..." (or NaN). Cells are split on "|",
    tags are whitespace-trimmed, and empty tags (from "", "A||B" or a trailing "|") are
    ignored. One ``iso_<level>_<slug>`` column is created per tag that appears in
    >= ``min_count`` rows; each row counts a tag at most once. Tags whose slugs collide
    share one column, which is 1 when the row carries any of them. Strains without any
    isolation info get all-zero rows for these features (XGBoost treats this as
    "no signal" rather than missing).

    Args:
        df: Training table, modified in place by adding the new 0/1 columns. Levels
            whose column is absent are skipped.
        min_count: Minimum number of rows a tag must appear in to get a column.

    Returns:
        ``(df, new_cols)``: the same DataFrame object and the names of the added
        columns, in creation order.
    """
    new_cols: list[str] = []
    for level in ("isolation_cat1", "isolation_cat2"):
        if level not in df.columns:
            continue
        # Parse every cell once; the result is reused for counting and for each column.
        row_tags = [_split_tags(cell) for cell in df[level]]
        tag_counts: Counter[str] = Counter(tag for tags in row_tags for tag in tags)
        prefix = f"iso_{level.split('_')[1]}_"
        # Map each output column to the kept tags whose slug produces it, so colliding
        # tags are merged rather than dropped.
        col_tags: dict[str, set[str]] = {}
        for tag in sorted(t for t, n in tag_counts.items() if n >= min_count):
            slug = tag.lower().replace(">", "gt").replace("<", "lt")
            slug = re.sub(r"[^a-z0-9]+", "_", slug).strip("_")
            col_tags.setdefault(prefix + slug, set()).add(tag)
        for col, tags in col_tags.items():
            wanted = frozenset(tags)
            df[col] = [int(not present.isdisjoint(wanted)) for present in row_tags]
            new_cols.append(col)
    return df, new_cols
def _left_join_unique(df: pd.DataFrame, other: pd.DataFrame, key: str) -> pd.DataFrame:
    """Left-join *other* onto *df* on *key*, keeping only the first row per key in *other*.

    De-duplicating the right-hand table ensures the join cannot multiply training rows;
    a repeated key would otherwise repeat every matching strain.
    """
    return df.merge(other.drop_duplicates(key), on=key, how="left")


def main() -> None:
    """Build the training table, train every target, and write results.

    Steps:
      1. Inner-join ``bacdive_phenotypes.parquet`` with ``features.parquet`` (both under
         ``config.DATA``) on ``bacdive_id`` + ``genome_accession``.
      2. Derive the GroupKFold ``group`` column and one-hot encode isolation categories.
      3. Left-join each optional feature table that exists on disk (MediaDive, HMM,
         KEGG modules, per-marker embeddings, isolation metadata), de-duplicated on its
         join key.
      4. Write ``training_table.parquet``, call ``train_all``, and save
         ``baseline_results.json`` / ``predictions.parquet`` under ``config.ARTIFACTS``.
    """
    t0 = time.perf_counter()
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(config.DATA / "features.parquet")
    df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
    df["group"] = df.apply(derive_group, axis=1)
    df, iso_cols = encode_isolation_categories(df)
    print(f"Encoded {len(iso_cols)} isolation-category features "
          f"({df[iso_cols].sum().sum():.0f} non-zero entries)")

    md_path = config.DATA / "mediadive_features.parquet"
    md_cols: list[str] = []
    if md_path.exists():
        md = pd.read_parquet(md_path)
        # Cast both sides so the join key dtypes agree.
        md["bacdive_id"] = md["bacdive_id"].astype(int)
        df["bacdive_id"] = df["bacdive_id"].astype(int)
        md_cols = [c for c in md.columns if c != "bacdive_id"]
        df = _left_join_unique(df, md, "bacdive_id")
        n_with_md = df[md_cols[0]].notna().sum() if md_cols else 0
        print(f"Joined MediaDive features ({len(md_cols)} cols) — "
              f"{n_with_md:,}/{len(df):,} training rows have MediaDive data")

    hmm_path = config.DATA / "hmm_features.parquet"
    hmm_cols: list[str] = []
    if hmm_path.exists():
        hmm = pd.read_parquet(hmm_path)
        hmm_cols = [c for c in hmm.columns if c != "genome_accession"]
        df = _left_join_unique(df, hmm, "genome_accession")
        n_with_hmm = df[hmm_cols[0]].notna().sum() if hmm_cols else 0
        print(f"Joined HMM features ({len(hmm_cols)} cols) — "
              f"{n_with_hmm:,}/{len(df):,} training rows have HMM data")

    kegg_path = config.DATA / "kegg_modules.parquet"
    kegg_cols: list[str] = []
    if kegg_path.exists():
        kegg = pd.read_parquet(kegg_path)
        kegg_cols = [c for c in kegg.columns if c != "genome_accession"]
        df = _left_join_unique(df, kegg, "genome_accession")
        n_with_kegg = df[kegg_cols[0]].notna().sum() if kegg_cols else 0
        print(f"Joined KEGG module completeness ({len(kegg_cols)} cols) — "
              f"{n_with_kegg:,}/{len(df):,} training rows have KEGG data")

    pme_path = config.DATA / "per_marker_embeddings.parquet"
    pme_cols: list[str] = []
    if pme_path.exists():
        pme = pd.read_parquet(pme_path)
        pme_cols = [c for c in pme.columns if c.startswith("pme_")]
        df = _left_join_unique(df, pme[["genome_accession"] + pme_cols], "genome_accession")
        n_with_pme = df[pme_cols[0]].notna().sum() if pme_cols else 0
        print(f"Joined per-marker ESM-2 embeddings ({len(pme_cols)} cols) — "
              f"{n_with_pme:,}/{len(df):,} training rows have PME data")

    iso_meta_path = config.DATA / "isolation_metadata.parquet"
    iso_meta_cols: list[str] = []
    if iso_meta_path.exists():
        iso_meta = pd.read_parquet(iso_meta_path)
        iso_meta["bacdive_id"] = iso_meta["bacdive_id"].astype(int)
        df["bacdive_id"] = df["bacdive_id"].astype(int)
        # Use only the numeric/binary columns; leave free-text out of XGBoost
        keep = ["iso_lat", "iso_lon", "iso_collection_year"]
        keep += [c for c in iso_meta.columns
                 if c.startswith(("iso_continent_", "iso_country_", "iso_host_kingdom_"))]
        iso_meta_cols = [c for c in keep if c in iso_meta.columns]
        df = _left_join_unique(df, iso_meta[["bacdive_id"] + iso_meta_cols], "bacdive_id")
        print(f"Joined isolation metadata ({len(iso_meta_cols)} cols)")

    feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
    feature_cols = feature_cols + iso_cols + md_cols + hmm_cols + kegg_cols + iso_meta_cols + pme_cols
    print(f"Training table: {len(df):,} strains × {len(feature_cols)} features")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print(f"Group sizes (top 10): {df['group'].value_counts().head(10).to_dict()}")
    print()

    training_table = config.DATA / "training_table.parquet"
    df.to_parquet(training_table, index=False)
    print(f"Wrote training table to {training_table}")

    results = train_all(df, feature_cols, group_col_override="group")
    out = config.ARTIFACTS / "baseline_results.json"
    predictions_out = config.ARTIFACTS / "predictions.parquet"
    save_results(results, out, predictions_path=predictions_out, feature_cols=feature_cols)
    print(f"Wrote per-strain predictions to {predictions_out}")

    # perf_counter is monotonic, so the elapsed time is immune to wall-clock changes.
    print(f"\nResults summary ({time.perf_counter() - t0:.1f}s):\n")
    for target, r in results.items():
        if r.folds:
            metric = r.folds[0].metric_name
            print(f"  {target:25s} {metric:10s} = {r.mean():.4f}  (n_folds={len(r.folds)})")
        else:
            print(f"  {target:25s} skipped (insufficient data)")


if __name__ == "__main__":
    main()
|