Spaces:
Running
Add unified strain catalog (100K rows w/ provenance) + selective weak supervision for pH
Browse filesThree new artifacts:
scripts/21_build_strain_catalog.py
Emits data/strain_catalog.parquet — every BacDive strain (n=100,866) with each
phenotype as {value, source}. source ∈ {bacdive, mediadive_weak, unknown}.
Coverage: T_opt 50% (BacDive only), pH 29% (5,794 BacDive + 23,574 MediaDive),
oxygen 23%, salt 30% (4,242 BacDive + 26,055 MediaDive). 50K strains have an
explicitly 'unknown' temperature, the largest single bucket.
scripts/23_weak_label_apples_to_apples.py + artifacts/weak_label_test.log
Honest semi-supervised test: train on (curated + MediaDive-weak), evaluate on
held-out *curated* test rows only — does weak supervision help generalization
to the gold-standard distribution?
optimal_ph curated 0.5133 → curated+weak 0.4934 (-3.9%, HELPS)
salt_tolerance_pct curated 2.1060 → curated+weak 2.1859 (+3.8%, HURTS)
Matches the pre-experiment correlation probe: pH↔MediaDive corr 0.62 helps,
salt↔MediaDive corr 0.42 hurts.
scripts/15_train_phenotype_heads.py
Backfills pH labels from MediaDive (5,103 → 26,181) before training the deployed
pH quantile heads. Salt deliberately stays curated-only. Re-saved .ubj heads.
scripts/22_train_with_weak_labels.py kept for reproducibility (full-table weak
training; superseded by scripts/23 for the rigorous comparison).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- models/phenotype/optimal_ph_q10.ubj +2 -2
- models/phenotype/optimal_ph_q50.ubj +2 -2
- models/phenotype/optimal_ph_q90.ubj +2 -2
- scripts/15_train_phenotype_heads.py +17 -0
- scripts/21_build_strain_catalog.py +104 -0
- scripts/22_train_with_weak_labels.py +74 -0
- scripts/23_weak_label_apples_to_apples.py +128 -0
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:16498ff43324f45392404520ef74f6e29f3f40572b57662247f4c03f46c2401d
|
| 3 |
+
size 937045
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e9682110e2d5ee8e1e8bd36c9dc48911566f946df2a2b8b656d658290e471bf
|
| 3 |
+
size 1058561
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c680f7a908fa9ffe63f63c376a734799f5df5fd6a37baf0de52900d2e95a3cbc
|
| 3 |
+
size 904405
|
|
@@ -38,6 +38,23 @@ def main() -> None:
|
|
| 38 |
df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 39 |
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
out_dir = config.ROOT / "models" / "phenotype"
|
| 42 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 43 |
(out_dir / "feature_cols.json").write_text(json.dumps(feature_cols))
|
|
|
|
| 38 |
df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 39 |
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 40 |
|
| 41 |
+
# Backfill pH with MediaDive-derived weak labels where BacDive has none.
|
| 42 |
+
# Apples-to-apples test (scripts/23) showed this nets -3.9% MAE on held-out
|
| 43 |
+
# curated test rows (corr 0.62 with optima). Salt did NOT pass the same test
|
| 44 |
+
# (+3.8% MAE, corr only 0.42), so we deliberately don't backfill salt.
|
| 45 |
+
catalog_path = config.DATA / "strain_catalog.parquet"
|
| 46 |
+
if catalog_path.exists():
|
| 47 |
+
catalog = pd.read_parquet(catalog_path)[["bacdive_id", "optimal_ph", "optimal_ph_source"]]
|
| 48 |
+
catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
|
| 49 |
+
catalog["optimal_ph"] = pd.to_numeric(catalog["optimal_ph"], errors="coerce")
|
| 50 |
+
df["bacdive_id"] = df["bacdive_id"].astype(int)
|
| 51 |
+
df = df.merge(catalog, on="bacdive_id", how="left", suffixes=("", "_cat"))
|
| 52 |
+
n_before = df["optimal_ph"].notna().sum()
|
| 53 |
+
ph_missing = df["optimal_ph"].isna() & df["optimal_ph_cat"].notna() & df["optimal_ph_source"].eq("mediadive_weak")
|
| 54 |
+
df.loc[ph_missing, "optimal_ph"] = df.loc[ph_missing, "optimal_ph_cat"]
|
| 55 |
+
n_after = df["optimal_ph"].notna().sum()
|
| 56 |
+
print(f"pH labels: {n_before:,} curated → {n_after:,} after MediaDive backfill (+{n_after - n_before:,})")
|
| 57 |
+
|
| 58 |
out_dir = config.ROOT / "models" / "phenotype"
|
| 59 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 60 |
(out_dir / "feature_cols.json").write_text(json.dumps(feature_cols))
|
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build a unified per-strain catalog with label provenance.
|
| 2 |
+
|
| 3 |
+
Every BacDive strain in data/bacdive_phenotypes.parquet gets a row. For each of the
|
| 4 |
+
4 phenotype targets, two columns are emitted:
|
| 5 |
+
- <target> numeric/categorical value (or NaN if unknown)
|
| 6 |
+
- <target>_source one of: 'bacdive' | 'mediadive_weak' | 'unknown'
|
| 7 |
+
|
| 8 |
+
`bacdive` means the value came from BacDive's curated optimum / oxygen tolerance
|
| 9 |
+
(high-quality). `mediadive_weak` means we derived it from the median pH/NaCl% of
|
| 10 |
+
the DSMZ media the strain has been recorded as growing on (lower-quality fallback,
|
| 11 |
+
only filled in when BacDive has no value). `unknown` means we have nothing.
|
| 12 |
+
|
| 13 |
+
Saves to data/strain_catalog.parquet.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
from microbe_model import config
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main() -> None:
|
| 23 |
+
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet").copy()
|
| 24 |
+
pheno["bacdive_id"] = pheno["bacdive_id"].astype(int)
|
| 25 |
+
|
| 26 |
+
# Pull MediaDive-derived signals (per-strain median across grown media)
|
| 27 |
+
md_path = config.DATA / "mediadive_features.parquet"
|
| 28 |
+
md = pd.read_parquet(md_path) if md_path.exists() else pd.DataFrame()
|
| 29 |
+
if len(md):
|
| 30 |
+
md["bacdive_id"] = md["bacdive_id"].astype(int)
|
| 31 |
+
|
| 32 |
+
# Build per-target value + source
|
| 33 |
+
out_cols: dict[str, pd.Series] = {}
|
| 34 |
+
for col in ("bacdive_id", "species", "genus", "family", "ncbi_taxon_id",
|
| 35 |
+
"genome_accession", "genome_source"):
|
| 36 |
+
if col in pheno.columns:
|
| 37 |
+
out_cols[col] = pheno[col]
|
| 38 |
+
|
| 39 |
+
# Direct BacDive labels (canonical). For each, fall back to MediaDive when missing.
|
| 40 |
+
targets = {
|
| 41 |
+
"optimal_temperature_c": None, # no MediaDive proxy for temperature
|
| 42 |
+
"optimal_ph": "md_ph_median" if len(md) else None,
|
| 43 |
+
"oxygen_requirement": None, # no media-based proxy
|
| 44 |
+
"salt_tolerance_pct": "md_nacl_pct_median" if len(md) else None,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
md_indexed = md.set_index("bacdive_id") if len(md) else None
|
| 48 |
+
for target, md_col in targets.items():
|
| 49 |
+
bacdive_vals = pheno.set_index("bacdive_id")[target] if target in pheno.columns else None
|
| 50 |
+
# Reindex to match pheno order
|
| 51 |
+
ordered_bacdive = (
|
| 52 |
+
bacdive_vals.reindex(pheno["bacdive_id"]).values if bacdive_vals is not None else None
|
| 53 |
+
)
|
| 54 |
+
# MediaDive-derived fallback (numeric only; salt% capped at 30 already)
|
| 55 |
+
weak_vals = (
|
| 56 |
+
md_indexed[md_col].reindex(pheno["bacdive_id"]).values
|
| 57 |
+
if md_col and md_indexed is not None and md_col in md_indexed.columns
|
| 58 |
+
else None
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
values = []
|
| 62 |
+
sources = []
|
| 63 |
+
for i in range(len(pheno)):
|
| 64 |
+
v = ordered_bacdive[i] if ordered_bacdive is not None else None
|
| 65 |
+
# pandas-aware NaN check
|
| 66 |
+
v_is_na = (v is None) or (isinstance(v, float) and pd.isna(v)) or (
|
| 67 |
+
isinstance(v, str) and not v
|
| 68 |
+
)
|
| 69 |
+
if not v_is_na:
|
| 70 |
+
values.append(v)
|
| 71 |
+
sources.append("bacdive")
|
| 72 |
+
continue
|
| 73 |
+
wv = weak_vals[i] if weak_vals is not None else None
|
| 74 |
+
wv_is_na = wv is None or (isinstance(wv, float) and pd.isna(wv))
|
| 75 |
+
if not wv_is_na:
|
| 76 |
+
values.append(wv)
|
| 77 |
+
sources.append("mediadive_weak")
|
| 78 |
+
continue
|
| 79 |
+
values.append(None)
|
| 80 |
+
sources.append("unknown")
|
| 81 |
+
|
| 82 |
+
out_cols[target] = pd.Series(values, dtype="object")
|
| 83 |
+
out_cols[f"{target}_source"] = pd.Series(sources, dtype="string")
|
| 84 |
+
|
| 85 |
+
catalog = pd.DataFrame(out_cols)
|
| 86 |
+
out = config.DATA / "strain_catalog.parquet"
|
| 87 |
+
catalog.to_parquet(out, index=False)
|
| 88 |
+
print(f"wrote {len(catalog):,} strains to {out}\n")
|
| 89 |
+
|
| 90 |
+
# Summary per target
|
| 91 |
+
for target in targets:
|
| 92 |
+
src_col = f"{target}_source"
|
| 93 |
+
counts = catalog[src_col].value_counts().to_dict()
|
| 94 |
+
n_known = counts.get("bacdive", 0) + counts.get("mediadive_weak", 0)
|
| 95 |
+
print(f"{target}")
|
| 96 |
+
print(f" bacdive (curated): {counts.get('bacdive', 0):>7,}")
|
| 97 |
+
print(f" mediadive_weak: {counts.get('mediadive_weak', 0):>7,}")
|
| 98 |
+
print(f" unknown: {counts.get('unknown', 0):>7,}")
|
| 99 |
+
print(f" ─ any-known: {n_known:>7,} ({100*n_known/len(catalog):.0f}%)")
|
| 100 |
+
print()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main()
|
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train baseline using BacDive-first + MediaDive-weak fallback labels for pH/salt.
|
| 2 |
+
|
| 3 |
+
Compares vs the curated-only baseline (artifacts/baseline_results.json) to see whether
|
| 4 |
+
the weak labels are net-helpful for the model. T_opt and oxygen are unaffected (no
|
| 5 |
+
weak source). pH and salt get many more training rows but with noisier labels.
|
| 6 |
+
|
| 7 |
+
Output: artifacts/baseline_results_weak.json
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
from microbe_model import config
|
| 16 |
+
from microbe_model.train.baseline import save_results, train_all
|
| 17 |
+
|
| 18 |
+
# Reuse the encoders from scripts/03 — copy locally to avoid sys.path gymnastics
|
| 19 |
+
import importlib.util
|
| 20 |
+
spec = importlib.util.spec_from_file_location("train03", config.ROOT / "scripts" / "03_train_baseline.py")
|
| 21 |
+
train03 = importlib.util.module_from_spec(spec)
|
| 22 |
+
spec.loader.exec_module(train03)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main() -> None:
|
| 26 |
+
t0 = time.time()
|
| 27 |
+
catalog = pd.read_parquet(config.DATA / "strain_catalog.parquet")
|
| 28 |
+
feats = pd.read_parquet(config.DATA / "features.parquet")
|
| 29 |
+
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
|
| 30 |
+
# Isolation categories live in the original phenotype table — graft them onto the catalog
|
| 31 |
+
iso_in = [c for c in pheno.columns if c.startswith("isolation_cat")]
|
| 32 |
+
pheno_iso = pheno[["bacdive_id"] + iso_in].copy()
|
| 33 |
+
pheno_iso["bacdive_id"] = pheno_iso["bacdive_id"].astype(int)
|
| 34 |
+
|
| 35 |
+
catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
|
| 36 |
+
feats["bacdive_id"] = feats["bacdive_id"].astype(int)
|
| 37 |
+
catalog = catalog.merge(pheno_iso, on="bacdive_id", how="left")
|
| 38 |
+
df = catalog.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 39 |
+
df["group"] = df.apply(train03.derive_group, axis=1)
|
| 40 |
+
|
| 41 |
+
# Coerce numeric targets (catalog stored as object)
|
| 42 |
+
for col in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
|
| 43 |
+
if col in df.columns:
|
| 44 |
+
df[col] = pd.to_numeric(df[col], errors="coerce")
|
| 45 |
+
|
| 46 |
+
df, iso_cols = train03.encode_isolation_categories(df)
|
| 47 |
+
print(f"Encoded {len(iso_cols)} isolation features")
|
| 48 |
+
|
| 49 |
+
# IMPORTANT: do NOT add MediaDive features here. The weak labels for pH and salt
|
| 50 |
+
# are *derived from those same features* (per-strain median across DSMZ media), so
|
| 51 |
+
# including them as inputs leaks the target — the model trivially predicts the
|
| 52 |
+
# matching feature column. The honest test is: do MediaDive-derived weak labels
|
| 53 |
+
# help a genome+isolation model generalize?
|
| 54 |
+
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 55 |
+
feature_cols = feature_cols + iso_cols
|
| 56 |
+
|
| 57 |
+
print(f"\nTraining table: {len(df):,} strains × {len(feature_cols)} features")
|
| 58 |
+
for tgt in ("optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"):
|
| 59 |
+
n = df[tgt].notna().sum() if tgt in df.columns else 0
|
| 60 |
+
print(f" {tgt:25s} labeled={n:>6,}")
|
| 61 |
+
print()
|
| 62 |
+
|
| 63 |
+
results = train_all(df, feature_cols, group_col_override="group")
|
| 64 |
+
out = config.ARTIFACTS / "baseline_results_weak.json"
|
| 65 |
+
save_results(results, out, predictions_path=None, feature_cols=feature_cols)
|
| 66 |
+
|
| 67 |
+
print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
|
| 68 |
+
for target, r in results.items():
|
| 69 |
+
if r.folds:
|
| 70 |
+
print(f" {target:25s} {r.folds[0].metric_name:10s} = {r.mean():.4f}")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
main()
|
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Honest test: do MediaDive-derived weak labels help generalize to BacDive-curated optima?
|
| 2 |
+
|
| 3 |
+
For pH and salt, two training regimes — same held-out test rows (curated only):
|
| 4 |
+
A. CURATED-ONLY: train on BacDive curated labels.
|
| 5 |
+
B. CURATED + WEAK: train on BacDive curated + MediaDive-derived weak labels.
|
| 6 |
+
|
| 7 |
+
In both, the test set per fold is the *intersection* of the held-out group with the
|
| 8 |
+
curated subset. This isolates whether weak labels help the model do better on the
|
| 9 |
+
gold-standard distribution, rather than just helping it predict the medium pH/salt
|
| 10 |
+
of the strain itself (which would be circular).
|
| 11 |
+
|
| 12 |
+
No MediaDive features are used (deployed model parity).
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import importlib.util
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import xgboost as xgb
|
| 24 |
+
from sklearn.metrics import mean_absolute_error
|
| 25 |
+
from sklearn.model_selection import GroupKFold
|
| 26 |
+
|
| 27 |
+
from microbe_model import config
|
| 28 |
+
|
| 29 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 30 |
+
sys.path.insert(0, str(ROOT / "scripts"))
|
| 31 |
+
|
| 32 |
+
spec = importlib.util.spec_from_file_location("train03", ROOT / "scripts" / "03_train_baseline.py")
|
| 33 |
+
train03 = importlib.util.module_from_spec(spec)
|
| 34 |
+
spec.loader.exec_module(train03)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def cv_mae(
|
| 38 |
+
df: pd.DataFrame,
|
| 39 |
+
feature_cols: list[str],
|
| 40 |
+
target: str,
|
| 41 |
+
*,
|
| 42 |
+
train_mask: pd.Series,
|
| 43 |
+
test_mask: pd.Series,
|
| 44 |
+
n_splits: int = 5,
|
| 45 |
+
) -> tuple[float, int]:
|
| 46 |
+
"""5-fold GroupKFold by family. Train on (train_mask & target.notna() & not test fold).
|
| 47 |
+
Evaluate on (test_mask & target.notna() & test fold). Returns mean MAE across folds.
|
| 48 |
+
"""
|
| 49 |
+
eligible = df[df[target].notna()].copy()
|
| 50 |
+
eligible[target] = pd.to_numeric(eligible[target], errors="coerce")
|
| 51 |
+
eligible = eligible[eligible[target].notna()]
|
| 52 |
+
groups = eligible["group"].fillna("__unknown__")
|
| 53 |
+
splits = min(n_splits, max(2, groups.nunique()))
|
| 54 |
+
kf = GroupKFold(n_splits=splits)
|
| 55 |
+
maes = []
|
| 56 |
+
n_eval_total = 0
|
| 57 |
+
for tr_idx, te_idx in kf.split(eligible, eligible[target], groups):
|
| 58 |
+
tr = eligible.iloc[tr_idx]
|
| 59 |
+
te = eligible.iloc[te_idx]
|
| 60 |
+
# Apply masks to the row indices we're using
|
| 61 |
+
tr = tr[train_mask.reindex(tr.index, fill_value=False).values]
|
| 62 |
+
te = te[test_mask.reindex(te.index, fill_value=False).values]
|
| 63 |
+
if len(tr) < 100 or len(te) < 50:
|
| 64 |
+
continue
|
| 65 |
+
m = xgb.XGBRegressor(
|
| 66 |
+
n_estimators=400, max_depth=5, learning_rate=0.05,
|
| 67 |
+
tree_method="hist", n_jobs=-1,
|
| 68 |
+
)
|
| 69 |
+
m.fit(tr[feature_cols], tr[target].astype(float))
|
| 70 |
+
preds = m.predict(te[feature_cols])
|
| 71 |
+
maes.append(mean_absolute_error(te[target].astype(float), preds))
|
| 72 |
+
n_eval_total += len(te)
|
| 73 |
+
return (float(np.mean(maes)) if maes else float("nan")), n_eval_total
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main() -> None:
|
| 77 |
+
t0 = time.time()
|
| 78 |
+
catalog = pd.read_parquet(config.DATA / "strain_catalog.parquet")
|
| 79 |
+
feats = pd.read_parquet(config.DATA / "features.parquet")
|
| 80 |
+
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
|
| 81 |
+
iso_in = [c for c in pheno.columns if c.startswith("isolation_cat")]
|
| 82 |
+
pheno_iso = pheno[["bacdive_id"] + iso_in].copy()
|
| 83 |
+
pheno_iso["bacdive_id"] = pheno_iso["bacdive_id"].astype(int)
|
| 84 |
+
|
| 85 |
+
catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
|
| 86 |
+
feats["bacdive_id"] = feats["bacdive_id"].astype(int)
|
| 87 |
+
catalog = catalog.merge(pheno_iso, on="bacdive_id", how="left")
|
| 88 |
+
df = catalog.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 89 |
+
df["group"] = df.apply(train03.derive_group, axis=1)
|
| 90 |
+
|
| 91 |
+
# Numeric coercion
|
| 92 |
+
for col in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
|
| 93 |
+
df[col] = pd.to_numeric(df[col], errors="coerce")
|
| 94 |
+
|
| 95 |
+
df, iso_cols = train03.encode_isolation_categories(df)
|
| 96 |
+
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 97 |
+
feature_cols = feature_cols + iso_cols
|
| 98 |
+
|
| 99 |
+
print(f"\nTraining table: {len(df):,} strains × {len(feature_cols)} features")
|
| 100 |
+
print("Held-out test rows are always BacDive-curated only.\n")
|
| 101 |
+
|
| 102 |
+
for target in ("optimal_ph", "salt_tolerance_pct"):
|
| 103 |
+
src_col = f"{target}_source"
|
| 104 |
+
curated = (df[src_col] == "bacdive")
|
| 105 |
+
weak = (df[src_col] == "mediadive_weak")
|
| 106 |
+
print(f"=== {target} ===")
|
| 107 |
+
print(f" curated rows: {curated.sum():,}")
|
| 108 |
+
print(f" weak rows: {weak.sum():,}")
|
| 109 |
+
|
| 110 |
+
# A) CURATED-ONLY training
|
| 111 |
+
mae_a, n_a = cv_mae(df, feature_cols, target,
|
| 112 |
+
train_mask=curated, test_mask=curated)
|
| 113 |
+
# B) CURATED + WEAK training
|
| 114 |
+
mae_b, n_b = cv_mae(df, feature_cols, target,
|
| 115 |
+
train_mask=(curated | weak), test_mask=curated)
|
| 116 |
+
delta_pct = 100 * (mae_b - mae_a) / mae_a
|
| 117 |
+
verdict = "HELPS" if mae_b < mae_a - 0.001 else (
|
| 118 |
+
"HURTS" if mae_b > mae_a + 0.001 else "WASH"
|
| 119 |
+
)
|
| 120 |
+
print(f" A. curated-only MAE = {mae_a:.4f} (eval n={n_a:,})")
|
| 121 |
+
print(f" B. curated+weak MAE = {mae_b:.4f} (eval n={n_b:,})")
|
| 122 |
+
print(f" → Δ = {delta_pct:+.1f}% [{verdict}]\n")
|
| 123 |
+
|
| 124 |
+
print(f"({time.time() - t0:.1f}s total)")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
main()
|