"""Train v3: hand-crafted features (v1) + ESM-2 embeddings (v2) + isolation tags.
Tests whether embeddings carry complementary signal to the curated features even
when they lose head-to-head. Same train/test splits and XGBoost hyperparameters
as v1 and v2.
Reads:
data/bacdive_phenotypes.parquet
data/features.parquet
data/embeddings.parquet
Writes:
artifacts/combined_collapsed_results.json
artifacts/combined_collapsed_predictions.parquet
"""
from __future__ import annotations
import re
import time
from collections import Counter
import pandas as pd
from microbe_model import config
from microbe_model.train.baseline import save_results, train_all
# Lookup from a raw ``oxygen_requirement`` label to its collapsed class.
# Written as (collapsed class, raw labels) groups so it is obvious which raw
# labels are merged together; the expansion keeps the raw labels in the order
# listed here.
OXYGEN_COLLAPSE = {
    raw_label: collapsed
    for collapsed, raw_labels in (
        ("aerobe", ("aerobe", "obligate aerobe")),
        ("anaerobe", ("anaerobe", "obligate anaerobe")),
        (
            "facultative",
            (
                "facultative anaerobe",
                "facultative aerobe",
                "aerotolerant",
                "microaerotolerant",
            ),
        ),
        ("microaerophile", ("microaerophile",)),
    )
    for raw_label in raw_labels
}
def derive_group(row: pd.Series) -> str:
    """Pick a coarse taxonomic label for one row.

    Preference order: ``family``, then ``genus``, then the first
    whitespace-separated word of ``species``. A field is used only when it
    holds a string with at least one non-whitespace character; missing
    values, non-strings (e.g. NaN) and blank strings are skipped.

    Args:
        row: Any mapping-like object exposing ``.get`` (a DataFrame row when
            used with ``df.apply(..., axis=1)``).

    Returns:
        The chosen label, or ``"__unknown__"`` when no field is usable.
    """
    for col in ("family", "genus"):
        val = row.get(col)
        # Blank strings count as missing so they do not become a group of
        # their own; a usable value is returned unchanged.
        if isinstance(val, str) and val.strip():
            return val
    species = row.get("species")
    if isinstance(species, str):
        words = species.split()
        # split() yields [] for empty or whitespace-only text; indexing it
        # directly would raise IndexError.
        if words:
            return words[0]
    return "__unknown__"
def encode_isolation_categories(
    df: pd.DataFrame,
    *,
    min_count: int = 10,
) -> tuple[pd.DataFrame, list[str]]:
    """One-hot encode isolation_cat1/cat2 (pipe-joined multi-labels).

    Mirrors the encoder in scripts/03_train_baseline.py so v3 sees the same
    isolation-tag vocabulary as v1.

    For each of ``isolation_cat1`` / ``isolation_cat2`` present in ``df``,
    every tag that occurs at least ``min_count`` times gets a 0/1 column named
    ``iso_<cat1|cat2>_<slug>``, where the slug is the lower-cased tag with
    ``>``/``<`` spelled out as ``gt``/``lt`` and every other run of
    non-alphanumerics replaced by ``_``. Empty tags produced by stray pipes
    (``"a|"``, ``"a||b"``) are ignored.

    Args:
        df: Table to encode. It is modified in place: indicator columns are
            added to it.
        min_count: Minimum number of occurrences for a tag to get a column.

    Returns:
        ``(df, new_cols)``: the same ``df`` object and the names of the added
        columns, in creation order.
    """
    new_cols: list[str] = []
    for level in ("isolation_cat1", "isolation_cat2"):
        if level not in df.columns:
            continue

        # Split each cell once up front instead of once per kept tag.
        # Dropping "" matters: a missing cell becomes "" and "".split("|")
        # is [""], so a kept empty tag would flag every missing row.
        row_tags = [
            [tag for tag in cell.split("|") if tag]
            for cell in df[level].fillna("")
        ]
        tag_counts: Counter[str] = Counter(
            tag for tags in row_tags for tag in tags
        )
        kept = sorted(t for t, n in tag_counts.items() if n >= min_count)
        row_tag_sets = [frozenset(tags) for tags in row_tags]

        prefix = f"iso_{level.split('_')[1]}"
        seen_cols: set[str] = set()
        for tag in kept:
            slug = tag.lower().replace(">", "gt").replace("<", "lt")
            slug = re.sub(r"[^a-z0-9]+", "_", slug).strip("_")
            col = f"{prefix}_{slug}"
            # Two tags can slugify to the same name; the first in sorted
            # order wins and later ones are skipped.
            if col in seen_cols:
                continue
            seen_cols.add(col)
            df[col] = [int(tag in tags) for tags in row_tag_sets]
            new_cols.append(col)
    return df, new_cols
def main() -> None:
    """Build the combined v3 training table, train on it, and save the results.

    Reads the phenotype, hand-crafted feature and embedding tables,
    inner-joins them on (bacdive_id, genome_accession), adds a ``group``
    column via ``derive_group``, collapses ``oxygen_requirement`` labels
    through ``OXYGEN_COLLAPSE`` when that column exists, one-hot encodes the
    isolation tags, then passes the table to ``train_all`` and writes the
    results with ``save_results``. Progress and a per-target summary are
    printed to stdout.
    """
    started = time.time()
    join_keys = ["bacdive_id", "genome_accession"]

    phenotypes = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    features = pd.read_parquet(config.DATA / "features.parquet")
    embeddings = pd.read_parquet(config.DATA / "embeddings.parquet")

    df = phenotypes.merge(features, on=join_keys, how="inner").merge(
        embeddings, on=join_keys, how="inner"
    )
    df["group"] = df.apply(derive_group, axis=1)

    if "oxygen_requirement" in df.columns:
        raw_oxygen = df["oxygen_requirement"]
        before = raw_oxygen.value_counts().to_dict()
        # Labels absent from OXYGEN_COLLAPSE map to NaN; fillna puts the raw
        # label back so they pass through unchanged.
        df["oxygen_requirement"] = raw_oxygen.map(OXYGEN_COLLAPSE).fillna(raw_oxygen)
        after = df["oxygen_requirement"].value_counts().to_dict()
        print(f"Oxygen labels collapsed: {len(before)} → {len(after)} classes")
        print(f" After: {after}")

    df, iso_cols = encode_isolation_categories(df)
    nonzero_entries = df[iso_cols].sum().sum()
    print(
        f"Encoded {len(iso_cols)} isolation-category features "
        f"({nonzero_entries:.0f} non-zero entries)"
    )

    # Feature set = every non-key feature column + embedding dims + iso tags.
    id_cols = set(join_keys)
    v1_cols = [name for name in features.columns if name not in id_cols]
    v2_cols = [name for name in embeddings.columns if name.startswith("emb_")]
    feature_cols = v1_cols + v2_cols + iso_cols

    n_groups = df["group"].nunique()
    print(
        f"Training table: {len(df):,} strains × {len(feature_cols)} features "
        f"({len(v1_cols)} hand-crafted + {len(v2_cols)} embedding dims + {len(iso_cols)} iso tags)"
    )
    print(f"Distinct groups: {n_groups:,}")
    print()

    results = train_all(df, feature_cols, group_col_override="group")

    out = config.ARTIFACTS / "combined_collapsed_results.json"
    predictions_out = config.ARTIFACTS / "combined_collapsed_predictions.parquet"
    save_results(results, out, predictions_path=predictions_out, feature_cols=feature_cols)

    elapsed = time.time() - started
    print(f"\nTrained in {elapsed:.1f}s. Wrote {out}\n")
    print("Results summary:")
    for target, r in results.items():
        # A target with no folds is reported as skipped.
        if not r.folds:
            print(f" {target:25s} skipped")
            continue
        metric = r.folds[0].metric_name
        print(f" {target:25s} {metric:10s} = {r.mean():.4f} (n_folds={len(r.folds)})")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|