# microbe-model/scripts/12_train_with_embeddings.py
# Miyu Horiuchi
# v2 scaffolding: ESM-2 embedding extraction + GPU runner doc
# 8c28a61
"""Train v2 baseline using ESM-2 embeddings as features.
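Intended for a head-to-head comparison with the v0/v1 hand-crafted features:
does the embedding really beat the curated features, and on which targets?
This script itself writes and prints the per-target embedding metrics; the
delta table against v0/v1 is built from those results.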
Run after scripts/11_extract_embeddings.py has produced data/embeddings.parquet.
"""
from __future__ import annotations
import time
import pandas as pd
from microbe_model import config
from microbe_model.train.baseline import save_results, train_all
def derive_group(row: pd.Series) -> str:
    """Return a grouping label for one strain row.

    The first usable value wins, in this order: the ``family`` cell, the
    ``genus`` cell, then the first whitespace-separated word of the
    ``species`` cell. A cell is usable only if it is a string containing at
    least one non-whitespace character, so missing values (``None``/NaN),
    empty strings and whitespace-only strings are all skipped.

    Args:
        row: A mapping-like row (anything with ``.get``) that may contain
            ``family``, ``genus`` and ``species`` entries.

    Returns:
        The chosen label, or ``"__unknown__"`` when no column is usable.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        # isinstance() rejects None/NaN; strip() rejects blank cells so they
        # do not become a group label of their own.
        if isinstance(val, str) and val.strip():
            return val

    species = row.get("species")
    if isinstance(species, str):
        # split() yields [] for empty or whitespace-only text; indexing [0]
        # unconditionally would raise IndexError on such a cell.
        words = species.split()
        if words:
            return words[0]

    return "__unknown__"
def main() -> None:
    """Train the baseline on embedding features and report per-target metrics.

    Steps:
      1. Load the phenotype table and the embeddings table from
         ``config.DATA`` and inner-join them on ``bacdive_id`` and
         ``genome_accession``.
      2. Add a ``group`` column via ``derive_group`` and save the joined
         table as ``training_table_embeddings.parquet``.
      3. Call ``train_all`` on the ``emb_*`` columns with ``group`` as the
         group column override, then write the results to
         ``config.ARTIFACTS / "embedding_results.json"``.
      4. Print one summary line per target (mean metric and fold count, or
         ``skipped`` when a target has no folds).

    Raises:
        SystemExit: If the embeddings file is missing, if it has no
            ``emb_*`` columns, or if the join yields no rows.
    """
    embed_path = config.DATA / "embeddings.parquet"
    # Fail fast: check for the embeddings before loading the phenotype table.
    if not embed_path.exists():
        raise SystemExit(f"Missing {embed_path}. Run scripts/11_extract_embeddings.py first.")

    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    embeds = pd.read_parquet(embed_path)

    feature_cols = [c for c in embeds.columns if c.startswith("emb_")]
    if not feature_cols:
        raise SystemExit(f"No 'emb_*' feature columns found in {embed_path}.")

    df = pheno.merge(embeds, on=["bacdive_id", "genome_accession"], how="inner")
    if df.empty:
        # An empty inner join means no phenotype row shares its keys with an
        # embedding row; stop with a clear message instead of training on
        # an empty table.
        raise SystemExit(
            "No rows left after joining phenotypes with embeddings on "
            "bacdive_id/genome_accession."
        )
    df["group"] = df.apply(derive_group, axis=1)

    print(f"Training table: {len(df):,} strains × {len(feature_cols)} embedding dims")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print()
    df.to_parquet(config.DATA / "training_table_embeddings.parquet", index=False)

    # perf_counter() is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments. The span covers training and saving results.
    t0 = time.perf_counter()
    results = train_all(df, feature_cols, group_col_override="group")
    out = config.ARTIFACTS / "embedding_results.json"
    save_results(results, out, feature_cols=feature_cols)
    print(f"\nTrained in {time.perf_counter() - t0:.1f}s. Wrote {out}\n")

    # NOTE(review): only the embedding metrics are printed here; the v1
    # numbers are not loaded in this function.
    print("Comparison vs v1 hand-crafted features:")
    for target, r in results.items():
        if r.folds:
            # The metric name is taken from the first fold.
            metric = r.folds[0].metric_name
            print(f" {target:25s} {metric:10s} = {r.mean():.4f} (n_folds={len(r.folds)})")
        else:
            print(f" {target:25s} skipped")
# Run the training pipeline only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()