# microbe-model / scripts /14_train_combined.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Train v3: hand-crafted features (v1) + ESM-2 embeddings (v2) + isolation tags.
Tests whether the embeddings carry signal complementary to the curated features,
even though they lose to them head-to-head. Uses the same train/test splits and
XGBoost hyperparameters as v1 and v2. Oxygen-requirement labels are collapsed
into coarse classes before training.
Reads:
data/bacdive_phenotypes.parquet
data/features.parquet
data/embeddings.parquet
Writes:
artifacts/combined_collapsed_results.json
artifacts/combined_collapsed_predictions.parquet
"""
from __future__ import annotations
import re
import time
from collections import Counter
import pandas as pd
from microbe_model import config
from microbe_model.train.baseline import save_results, train_all
# Map each raw oxygen-requirement label to a coarse class. The table is written
# as (coarse class, raw labels) groups and inverted into a flat
# raw-label -> coarse-class dict; keys keep the order they are listed in here.
# Labels absent from this mapping are left unchanged by main().
OXYGEN_COLLAPSE = {
    raw_label: coarse
    for coarse, raw_labels in (
        ("aerobe", ("aerobe", "obligate aerobe")),
        ("anaerobe", ("anaerobe", "obligate anaerobe")),
        (
            "facultative",
            (
                "facultative anaerobe",
                "facultative aerobe",
                "aerotolerant",
                "microaerotolerant",
            ),
        ),
        ("microaerophile", ("microaerophile",)),
    )
    for raw_label in raw_labels
}
def derive_group(row: pd.Series) -> str:
    """Return the taxonomic group label used for grouped train/test splitting.

    Falls back through ``family``, then ``genus``, then the first
    whitespace-separated token of ``species``. A row with none of these gets
    the sentinel ``"__unknown__"``. Non-string values (``None``, NaN), empty
    strings and whitespace-only ``species`` values are treated as missing.

    Args:
        row: One strain's record (applied via ``df.apply(..., axis=1)``). Only
            the ``family``, ``genus`` and ``species`` entries are read; absent
            keys count as missing.

    Returns:
        The group label string.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        # Returned verbatim (not stripped) so existing group labels are unchanged.
        if isinstance(val, str) and val:
            return val
    species = row.get("species")
    if isinstance(species, str):
        # str.split() returns [] for a whitespace-only string; guard before
        # indexing instead of raising IndexError.
        tokens = species.split()
        if tokens:
            return tokens[0]
    return "__unknown__"
def encode_isolation_categories(
    df: pd.DataFrame,
    *,
    min_count: int = 10,
) -> tuple[pd.DataFrame, list[str]]:
    """One-hot encode isolation_cat1/cat2 (pipe-joined multi-labels).

    Mirrors the encoder in scripts/03_train_baseline.py so v3 sees the same
    isolation-tag vocabulary as v1.

    For each of ``isolation_cat1`` / ``isolation_cat2`` present in ``df``, every
    ``|``-separated tag occurring at least ``min_count`` times gets a 0/1
    indicator column named ``iso_<cat1|cat2>_<slug>``. The slug is the
    lower-cased tag with ``>``/``<`` spelled ``gt``/``lt`` and other
    non-alphanumeric runs replaced by ``_``. Missing cells count as having no
    tags.

    Args:
        df: Input table. Indicator columns are added to it in place, and the
            same object is returned.
        min_count: Minimum number of tag occurrences required for a column.
            Occurrences are counted per token, so a tag repeated within one
            cell counts more than once.

    Returns:
        ``(df, new_cols)``: the same DataFrame and the added column names, in
        creation order (per level, tags in sorted order).
    """
    new_cols: list[str] = []
    for level in ("isolation_cat1", "isolation_cat2"):
        if level not in df.columns:
            continue
        tag_counts: Counter[str] = Counter()
        for v in df[level].dropna():
            tag_counts.update(v.split("|"))
        kept = sorted(t for t, n in tag_counts.items() if n >= min_count)
        if not kept:
            continue
        # Split each cell once; every per-tag membership test below reuses these
        # sets instead of re-splitting the whole column for each kept tag.
        row_tags = df[level].fillna("").map(lambda v: frozenset(v.split("|")))
        prefix = level.split("_")[1]
        seen_cols: set[str] = set()
        for tag in kept:
            slug = tag.lower().replace(">", "gt").replace("<", "lt")
            slug = re.sub(r"[^a-z0-9]+", "_", slug).strip("_")
            col = f"iso_{prefix}_{slug}"
            # Distinct tags can share a slug (e.g. differing only in case or
            # punctuation). Only the first in sorted order produces the column,
            # so rows carrying only a later colliding tag stay 0. This is kept
            # as-is to match the v1 encoder's vocabulary and values.
            if col in seen_cols:
                continue
            seen_cols.add(col)
            df[col] = row_tags.map(lambda s, t=tag: int(t in s))
            new_cols.append(col)
    return df, new_cols
def main() -> None:
    """Build the v3 training table, train every target, and write artifacts.

    Inner-joins phenotypes, hand-crafted features and embeddings on
    (bacdive_id, genome_accession). Derives a ``group`` label per strain (used
    as the split-group column for ``train_all``) and collapses
    oxygen-requirement labels via OXYGEN_COLLAPSE. One-hot encodes isolation
    tags, trains all targets, and saves results plus predictions under
    ``config.ARTIFACTS``.
    """
    t0 = time.time()
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(config.DATA / "features.parquet")
    embeds = pd.read_parquet(config.DATA / "embeddings.parquet")

    # Inner joins: only strains present in all three tables are kept.
    # NOTE(review): non-key columns shared between these tables would be
    # suffixed (_x/_y) by merge and no longer match the names in feature_cols
    # below. Confirm the tables are disjoint apart from the keys.
    keys = ["bacdive_id", "genome_accession"]
    df = pheno.merge(feats, on=keys, how="inner")
    df = df.merge(embeds, on=keys, how="inner")
    df["group"] = df.apply(derive_group, axis=1)

    if "oxygen_requirement" in df.columns:
        raw = df["oxygen_requirement"]
        before = raw.value_counts().to_dict()
        # Labels missing from OXYGEN_COLLAPSE map to NaN. fillna restores the
        # original value, so unmapped labels pass through unchanged.
        df["oxygen_requirement"] = raw.map(OXYGEN_COLLAPSE).fillna(raw)
        after = df["oxygen_requirement"].value_counts().to_dict()
        print(f"Oxygen labels collapsed: {len(before)} → {len(after)} classes")
        print(f" After: {after}")

    df, iso_cols = encode_isolation_categories(df)
    print(f"Encoded {len(iso_cols)} isolation-category features "
          f"({df[iso_cols].sum().sum():.0f} non-zero entries)")

    # Feature blocks: every non-key column of features.parquet (v1), the emb_*
    # columns of embeddings.parquet (v2), and the isolation-tag indicators.
    v1_cols = [c for c in feats.columns if c not in keys]
    v2_cols = [c for c in embeds.columns if c.startswith("emb_")]
    feature_cols = v1_cols + v2_cols + iso_cols
    print(f"Training table: {len(df):,} strains × {len(feature_cols)} features "
          f"({len(v1_cols)} hand-crafted + {len(v2_cols)} embedding dims + {len(iso_cols)} iso tags)")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print()

    results = train_all(df, feature_cols, group_col_override="group")
    out = config.ARTIFACTS / "combined_collapsed_results.json"
    predictions_out = config.ARTIFACTS / "combined_collapsed_predictions.parquet"
    save_results(results, out, predictions_path=predictions_out, feature_cols=feature_cols)
    print(f"\nTrained in {time.time() - t0:.1f}s. Wrote {out}\n")

    print("Results summary:")
    for target, r in results.items():
        if r.folds:
            metric = r.folds[0].metric_name
            print(f" {target:25s} {metric:10s} = {r.mean():.4f} (n_folds={len(r.folds)})")
        else:
            print(f" {target:25s} skipped")


if __name__ == "__main__":
    main()