Spaces:
Running
Running
| """Train v2 baseline using ESM-2 embeddings as features. | |
| Compares head-to-head with the v0/v1 hand-crafted features. Output is a delta | |
| table — does the embedding really beat the curated features, and on which targets? | |
| Run after scripts/11_extract_embeddings.py has produced data/embeddings.parquet. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import pandas as pd | |
| from microbe_model import config | |
| from microbe_model.train.baseline import save_results, train_all | |
def derive_group(row: pd.Series) -> str:
    """Pick a coarse taxonomic label for one strain row.

    Precedence is ``family``, then ``genus``, then the first whitespace-separated
    token of ``species``. A column only counts when it holds a string with
    non-whitespace content, so missing values (``None``/NaN) and blank strings
    fall through to the next candidate.

    Args:
        row: Mapping-like row (anything with ``.get``) that may carry
            ``family``, ``genus`` and ``species`` entries.

    Returns:
        The chosen label with surrounding whitespace removed, or
        ``"__unknown__"`` when no usable value is present.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        # The isinstance check filters out None and float NaN placeholders.
        if isinstance(val, str):
            label = val.strip()
            if label:
                return label

    species = row.get("species")
    if isinstance(species, str):
        # split() yields [] for blank/whitespace-only input, so guard before
        # indexing instead of raising IndexError.
        tokens = species.split()
        if tokens:
            return tokens[0]

    return "__unknown__"
def main() -> None:
    """Train the embedding-feature baseline and print one score line per target.

    Steps, in order:
      1. Verify ``embeddings.parquet`` exists under ``config.DATA``.
      2. Load phenotypes and embeddings and inner-join them on
         ``bacdive_id`` + ``genome_accession``.
      3. Add a ``group`` column via :func:`derive_group`.
      4. Write the joined table to ``training_table_embeddings.parquet``.
      5. Call ``train_all`` on the ``emb_*`` columns and persist the results
         to ``embedding_results.json`` under ``config.ARTIFACTS``.
      6. Print the mean metric (or ``skipped``) for each target.

    Raises:
        SystemExit: If the embeddings file is missing, contains no ``emb_*``
            columns, or shares no strains with the phenotype table.
    """
    embed_path = config.DATA / "embeddings.parquet"
    # Check the prerequisite before reading anything so a missing file fails fast.
    if not embed_path.exists():
        raise SystemExit(f"Missing {embed_path}. Run scripts/11_extract_embeddings.py first.")

    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    embeds = pd.read_parquet(embed_path)

    feature_cols = [c for c in embeds.columns if c.startswith("emb_")]
    if not feature_cols:
        raise SystemExit(f"No 'emb_' feature columns found in {embed_path}.")

    df = pheno.merge(embeds, on=["bacdive_id", "genome_accession"], how="inner")
    if df.empty:
        # An inner join with no matching keys leaves nothing to train on.
        raise SystemExit(
            f"No strains shared between phenotypes and {embed_path}; nothing to train on."
        )
    df["group"] = df.apply(derive_group, axis=1)

    print(f"Training table: {len(df):,} strains × {len(feature_cols)} embedding dims")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print()
    df.to_parquet(config.DATA / "training_table_embeddings.parquet", index=False)

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long training run.
    t0 = time.perf_counter()
    results = train_all(df, feature_cols, group_col_override="group")
    out = config.ARTIFACTS / "embedding_results.json"
    save_results(results, out, feature_cols=feature_cols)
    print(f"\nTrained in {time.perf_counter() - t0:.1f}s. Wrote {out}\n")

    # NOTE(review): only this run's metrics are printed below; no v1 numbers
    # are loaded in this function, so the comparison itself happens elsewhere.
    print("Comparison vs v1 hand-crafted features:")
    for target, r in results.items():
        if r.folds:
            metric = r.folds[0].metric_name
            print(f"  {target:25s}  {metric:10s} = {r.mean():.4f}  (n_folds={len(r.folds)})")
        else:
            print(f"  {target:25s}  skipped")
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()