File size: 2,157 Bytes
8c28a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Train v2 baseline using ESM-2 embeddings as features.

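Meant for a head-to-head comparison with the v0/v1 hand-crafted features: does
the embedding really beat the curated features, and on which targets? This
script prints the per-target embedding metrics and saves them to
embedding_results.json; it does not load the v0/v1 results itself.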

Run after scripts/11_extract_embeddings.py has produced data/embeddings.parquet.
"""
from __future__ import annotations

import time

import pandas as pd

from microbe_model import config
from microbe_model.train.baseline import save_results, train_all


def derive_group(row: pd.Series) -> str:
    """Pick the taxonomic label used to group a strain.

    Tries ``family`` first, then ``genus``, then the first whitespace-separated
    token of ``species``. A field counts as present only if it is a string with
    at least one non-whitespace character, so missing values (``None``, NaN)
    and blank strings fall through to the next field.

    Args:
        row: One row of the training table. Only ``.get(key)`` is used, so any
            mapping-like object works.

    Returns:
        The chosen label with surrounding whitespace removed, or
        ``"__unknown__"`` when no field is usable.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        if isinstance(val, str):
            # Strip so " Bacillaceae" and "Bacillaceae" land in the same group
            # and a whitespace-only value is treated as missing.
            label = val.strip()
            if label:
                return label
    species = row.get("species")
    if isinstance(species, str):
        # split() with no argument ignores leading/trailing whitespace and
        # yields an empty list for a blank string; indexing [0] directly
        # raised IndexError on whitespace-only input.
        tokens = species.split()
        if tokens:
            return tokens[0]
    return "__unknown__"


def main() -> None:
    """Train the baseline models on embedding features and report metrics.

    Steps:
      1. Inner-join the phenotype table with the embeddings table on
         ``(bacdive_id, genome_accession)``.
      2. Attach a ``group`` label to every row via ``derive_group``.
      3. Save the joined table, run ``train_all`` on the ``emb_*`` columns and
         write the results with ``save_results``.
      4. Print one line per target: the mean metric, or ``skipped`` when the
         target has no folds.

    Raises:
        SystemExit: If the embeddings file is missing, has no ``emb_*``
            columns, or shares no rows with the phenotype table.
    """
    embed_path = config.DATA / "embeddings.parquet"
    # Check the prerequisite first so a missing file fails before any data is loaded.
    if not embed_path.exists():
        raise SystemExit(f"Missing {embed_path}. Run scripts/11_extract_embeddings.py first.")
    embeds = pd.read_parquet(embed_path)
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")

    feature_cols = [c for c in embeds.columns if c.startswith("emb_")]
    if not feature_cols:
        # Without features there is nothing to train on; stop with a clear message.
        raise SystemExit(f"No emb_* feature columns found in {embed_path}.")

    df = pheno.merge(embeds, on=["bacdive_id", "genome_accession"], how="inner")
    if df.empty:
        raise SystemExit(
            "No rows left after joining phenotypes with embeddings on "
            "(bacdive_id, genome_accession)."
        )
    df["group"] = df.apply(derive_group, axis=1)

    print(f"Training table: {len(df):,} strains × {len(feature_cols)} embedding dims")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print()

    df.to_parquet(config.DATA / "training_table_embeddings.parquet", index=False)

    # perf_counter is monotonic; stop the clock before serializing results so
    # the reported duration covers training only.
    t0 = time.perf_counter()
    results = train_all(df, feature_cols, group_col_override="group")
    elapsed = time.perf_counter() - t0

    out = config.ARTIFACTS / "embedding_results.json"
    save_results(results, out, feature_cols=feature_cols)
    print(f"\nTrained in {elapsed:.1f}s. Wrote {out}\n")

    # NOTE(review): only the embedding metrics are printed below; no v1 results
    # are loaded in this function, so the comparison itself happens elsewhere.
    print("Comparison vs v1 hand-crafted features:")
    for target, r in results.items():
        if r.folds:
            # Every fold of a target is assumed to share one metric name, so
            # the first fold's name labels the line — TODO confirm in train_all.
            metric = r.folds[0].metric_name
            print(f"  {target:25s} {metric:10s} = {r.mean():.4f}  (n_folds={len(r.folds)})")
        else:
            print(f"  {target:25s} skipped")


# Script entry point: run the training pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()