File size: 4,603 Bytes
f0f1d93
d23315e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0f1d93
d23315e
f0f1d93
d23315e
 
 
 
 
 
 
0ed74db
 
 
 
 
 
 
 
 
 
 
 
 
d23315e
 
 
 
 
 
 
 
 
 
 
f0f1d93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d23315e
 
 
 
 
 
 
 
 
 
0ed74db
 
 
 
 
 
 
f0f1d93
 
 
 
d23315e
 
f0f1d93
d23315e
 
f0f1d93
d23315e
 
 
 
 
0ed74db
 
 
d23315e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Train v3: hand-crafted features (v1) + ESM-2 embeddings (v2) + isolation tags.

Tests whether embeddings carry complementary signal to the curated features even
when they lose head-to-head. Same train/test splits and XGBoost hyperparameters
as v1 and v2.

Reads:
  data/bacdive_phenotypes.parquet
  data/features.parquet
  data/embeddings.parquet

Writes:
  artifacts/combined_collapsed_results.json
  artifacts/combined_collapsed_predictions.parquet
"""
from __future__ import annotations

import re
import time
from collections import Counter

import pandas as pd

from microbe_model import config
from microbe_model.train.baseline import save_results, train_all


# Maps raw oxygen-requirement labels onto a smaller set of classes
# (aerobe / anaerobe / facultative / microaerophile). main() applies it with
# Series.map and falls back to the original label for any value not listed.
OXYGEN_COLLAPSE = {
    "aerobe": "aerobe",
    "obligate aerobe": "aerobe",
    "anaerobe": "anaerobe",
    "obligate anaerobe": "anaerobe",
    "facultative anaerobe": "facultative",
    "facultative aerobe": "facultative",
    # Tolerance labels are folded into the facultative class.
    "aerotolerant": "facultative",
    "microaerotolerant": "facultative",
    "microaerophile": "microaerophile",
}


def derive_group(row: pd.Series) -> str:
    """Return the grouping key for one strain row.

    Preference order is ``family``, then ``genus``, then the first
    whitespace-separated token of ``species``. Only non-empty ``str`` values
    count, so missing entries (``None``/NaN) fall through to the next
    candidate.

    Args:
        row: A mapping-like row (anything with ``.get``) that may carry
            ``family``, ``genus`` and ``species`` entries.

    Returns:
        The chosen group label, or ``"__unknown__"`` when no candidate is
        usable.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        if isinstance(val, str) and val:
            return val
    species = row.get("species")
    if isinstance(species, str):
        # A whitespace-only string is truthy but splits into zero tokens, so
        # check the token list instead of indexing [0] unconditionally.
        tokens = species.split()
        if tokens:
            return tokens[0]
    return "__unknown__"


def encode_isolation_categories(
    df: pd.DataFrame,
    *,
    min_count: int = 10,
) -> tuple[pd.DataFrame, list[str]]:
    """One-hot encode isolation_cat1/cat2 (pipe-joined multi-labels).

    For each of the two levels present in ``df``, every tag that occurs at
    least ``min_count`` times gets a 0/1 column named
    ``iso_<cat1|cat2>_<slug>``, where the slug is the lower-cased tag with
    ``>``/``<`` spelled out as ``gt``/``lt`` and other non-alphanumeric runs
    replaced by ``_``. If two tags produce the same slug, only the first one
    in sorted order gets a column.

    Empty tokens (from an empty cell or a doubled ``|``) are not treated as
    tags; otherwise a kept empty tag would also flag every missing row.

    Intended to mirror the encoder in scripts/03_train_baseline.py so v3 sees
    the same isolation-tag vocabulary as v1.
    NOTE(review): confirm the v1 encoder also ignores empty tokens.

    Args:
        df: Training table. New columns are added to it in place.
        min_count: Minimum number of occurrences for a tag to be encoded.

    Returns:
        ``(df, new_cols)`` - the same DataFrame object and the names of the
        columns that were added, in creation order.
    """
    new_cols: list[str] = []
    for level in ("isolation_cat1", "isolation_cat2"):
        if level not in df.columns:
            continue
        # Split every row once per level instead of once per tag. Missing
        # cells become "" and therefore contribute no tokens.
        row_tokens = [
            [t for t in v.split("|") if t] for v in df[level].fillna("")
        ]
        tag_counts: Counter[str] = Counter()
        for tokens in row_tokens:
            tag_counts.update(tokens)
        row_tag_sets = [frozenset(tokens) for tokens in row_tokens]

        kept = [t for t, n in tag_counts.items() if n >= min_count]
        seen_slugs: set[str] = set()
        for tag in sorted(kept):
            slug = tag.lower().replace(">", "gt").replace("<", "lt")
            slug = re.sub(r"[^a-z0-9]+", "_", slug).strip("_")
            col = f"iso_{level.split('_')[1]}_{slug}"
            if col in seen_slugs:
                # A different tag already claimed this column name.
                continue
            seen_slugs.add(col)
            df[col] = [int(tag in tags) for tags in row_tag_sets]
            new_cols.append(col)
    return df, new_cols


def main() -> None:
    """Train the combined (v3) models and save results and predictions.

    Loads the phenotype, hand-crafted feature and embedding tables,
    inner-joins them on ``bacdive_id``/``genome_accession``, derives a
    ``group`` column per strain, collapses oxygen labels via
    ``OXYGEN_COLLAPSE``, adds isolation-tag one-hot columns, then passes the
    table to ``train_all`` and writes the outcome with ``save_results``.
    Progress and a per-target summary are printed to stdout.
    """
    t0 = time.time()
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(config.DATA / "features.parquet")
    embeds = pd.read_parquet(config.DATA / "embeddings.parquet")

    # Inner joins: only strains present in all three tables are trained on.
    df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
    df = df.merge(embeds, on=["bacdive_id", "genome_accession"], how="inner")
    df["group"] = df.apply(derive_group, axis=1)

    if "oxygen_requirement" in df.columns:
        before = df["oxygen_requirement"].value_counts().to_dict()
        # Labels missing from OXYGEN_COLLAPSE map to NaN; fillna restores the
        # original label for those rows.
        df["oxygen_requirement"] = df["oxygen_requirement"].map(OXYGEN_COLLAPSE).fillna(df["oxygen_requirement"])
        after = df["oxygen_requirement"].value_counts().to_dict()
        # Separate the two counts; printed back-to-back they read as one number.
        print(f"Oxygen labels collapsed: {len(before)} → {len(after)} classes")
        print(f"  After: {after}")

    df, iso_cols = encode_isolation_categories(df)
    print(f"Encoded {len(iso_cols)} isolation-category features "
          f"({df[iso_cols].sum().sum():.0f} non-zero entries)")

    # Feature set = every non-key column of the hand-crafted table, the
    # emb_* embedding columns, and the isolation-tag indicators.
    v1_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
    v2_cols = [c for c in embeds.columns if c.startswith("emb_")]
    feature_cols = v1_cols + v2_cols + iso_cols

    print(f"Training table: {len(df):,} strains × {len(feature_cols)} features "
          f"({len(v1_cols)} hand-crafted + {len(v2_cols)} embedding dims + {len(iso_cols)} iso tags)")
    print(f"Distinct groups: {df['group'].nunique():,}")
    print()

    results = train_all(df, feature_cols, group_col_override="group")

    out = config.ARTIFACTS / "combined_collapsed_results.json"
    predictions_out = config.ARTIFACTS / "combined_collapsed_predictions.parquet"
    save_results(results, out, predictions_path=predictions_out, feature_cols=feature_cols)
    print(f"\nTrained in {time.time() - t0:.1f}s. Wrote {out}\n")

    print("Results summary:")
    for target, r in results.items():
        if r.folds:
            metric = r.folds[0].metric_name
            print(f"  {target:25s} {metric:10s} = {r.mean():.4f}  (n_folds={len(r.folds)})")
        else:
            # No folds recorded for this target.
            print(f"  {target:25s} skipped")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()