Miyu Horiuchi Claude Opus 4.7 (1M context) commited on
Commit
4c18dfd
·
1 Parent(s): 9365561

Add unified strain catalog (100K rows w/ provenance) + selective weak supervision for pH

Browse files

Three new artifacts:

scripts/21_build_strain_catalog.py
Emits data/strain_catalog.parquet — every BacDive strain (n=100,866) with each
phenotype as {value, source}. source ∈ {bacdive, mediadive_weak, unknown}.
Coverage: T_opt 50% (BacDive only), pH 29% (5,794 BacDive + 23,574 MediaDive),
oxygen 23%, salt 30% (4,242 BacDive + 26,055 MediaDive). 50K strains have an
explicitly 'unknown' temperature, the largest single bucket.

scripts/23_weak_label_apples_to_apples.py + artifacts/weak_label_test.log
Honest semi-supervised test: train on (curated + MediaDive-weak), evaluate on
held-out *curated* test rows only — does weak supervision help generalization
to the gold-standard distribution?
optimal_ph curated 0.5133 → curated+weak 0.4934 (-3.9%, HELPS)
salt_tolerance_pct curated 2.1060 → curated+weak 2.1859 (+3.8%, HURTS)
Matches the pre-experiment correlation probe: pH↔MediaDive corr 0.62 helps,
salt↔MediaDive corr 0.42 hurts.

scripts/15_train_phenotype_heads.py
Backfills pH labels from MediaDive (5,103 → 26,181) before training the deployed
pH quantile heads. Salt deliberately stays curated-only. Re-saved .ubj heads.

scripts/22_train_with_weak_labels.py kept for reproducibility (full-table weak
training; superseded by scripts/23 for the rigorous comparison).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

models/phenotype/optimal_ph_q10.ubj CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ee3dc1e9911b70ea915af55dcfdf5cf73b815d6a454ae8279dbac2142c7b596
3
- size 827905
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16498ff43324f45392404520ef74f6e29f3f40572b57662247f4c03f46c2401d
3
+ size 937045
models/phenotype/optimal_ph_q50.ubj CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:99652ebd8c1d5bbdaefd59d4724e15f6126cd56313dded2af8f1c32fc14ead66
3
- size 892163
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e9682110e2d5ee8e1e8bd36c9dc48911566f946df2a2b8b656d658290e471bf
3
+ size 1058561
models/phenotype/optimal_ph_q90.ubj CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:136173f080b8ea4b0bab00c23a340c0c5da4eb0dbbba9032701e12fc1fdd91cd
3
- size 833683
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c680f7a908fa9ffe63f63c376a734799f5df5fd6a37baf0de52900d2e95a3cbc
3
+ size 904405
scripts/15_train_phenotype_heads.py CHANGED
@@ -38,6 +38,23 @@ def main() -> None:
38
  df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
39
  feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  out_dir = config.ROOT / "models" / "phenotype"
42
  out_dir.mkdir(parents=True, exist_ok=True)
43
  (out_dir / "feature_cols.json").write_text(json.dumps(feature_cols))
 
38
  df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
39
  feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
40
 
41
+ # Backfill pH with MediaDive-derived weak labels where BacDive has none.
42
+ # Apples-to-apples test (scripts/23) showed this nets -3.9% MAE on held-out
43
+ # curated test rows (corr 0.62 with optima). Salt did NOT pass the same test
44
+ # (+3.8% MAE, corr only 0.42), so we deliberately don't backfill salt.
45
+ catalog_path = config.DATA / "strain_catalog.parquet"
46
+ if catalog_path.exists():
47
+ catalog = pd.read_parquet(catalog_path)[["bacdive_id", "optimal_ph", "optimal_ph_source"]]
48
+ catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
49
+ catalog["optimal_ph"] = pd.to_numeric(catalog["optimal_ph"], errors="coerce")
50
+ df["bacdive_id"] = df["bacdive_id"].astype(int)
51
+ df = df.merge(catalog, on="bacdive_id", how="left", suffixes=("", "_cat"))
52
+ n_before = df["optimal_ph"].notna().sum()
53
+ ph_missing = df["optimal_ph"].isna() & df["optimal_ph_cat"].notna() & df["optimal_ph_source"].eq("mediadive_weak")
54
+ df.loc[ph_missing, "optimal_ph"] = df.loc[ph_missing, "optimal_ph_cat"]
55
+ n_after = df["optimal_ph"].notna().sum()
56
+ print(f"pH labels: {n_before:,} curated → {n_after:,} after MediaDive backfill (+{n_after - n_before:,})")
57
+
58
  out_dir = config.ROOT / "models" / "phenotype"
59
  out_dir.mkdir(parents=True, exist_ok=True)
60
  (out_dir / "feature_cols.json").write_text(json.dumps(feature_cols))
scripts/21_build_strain_catalog.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a unified per-strain catalog with label provenance.
2
+
3
+ Every BacDive strain in data/bacdive_phenotypes.parquet gets a row. For each of the
4
+ 4 phenotype targets, two columns are emitted:
5
+ - <target> numeric/categorical value (or NaN if unknown)
6
+ - <target>_source one of: 'bacdive' | 'mediadive_weak' | 'unknown'
7
+
8
+ `bacdive` means the value came from BacDive's curated optimum / oxygen tolerance
9
+ (high-quality). `mediadive_weak` means we derived it from the median pH/NaCl% of
10
+ the DSMZ media the strain has been recorded as growing on (lower-quality fallback,
11
+ only filled in when BacDive has no value). `unknown` means we have nothing.
12
+
13
+ Saves to data/strain_catalog.parquet.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import pandas as pd
18
+
19
+ from microbe_model import config
20
+
21
+
22
+ def main() -> None:
23
+ pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet").copy()
24
+ pheno["bacdive_id"] = pheno["bacdive_id"].astype(int)
25
+
26
+ # Pull MediaDive-derived signals (per-strain median across grown media)
27
+ md_path = config.DATA / "mediadive_features.parquet"
28
+ md = pd.read_parquet(md_path) if md_path.exists() else pd.DataFrame()
29
+ if len(md):
30
+ md["bacdive_id"] = md["bacdive_id"].astype(int)
31
+
32
+ # Build per-target value + source
33
+ out_cols: dict[str, pd.Series] = {}
34
+ for col in ("bacdive_id", "species", "genus", "family", "ncbi_taxon_id",
35
+ "genome_accession", "genome_source"):
36
+ if col in pheno.columns:
37
+ out_cols[col] = pheno[col]
38
+
39
+ # Direct BacDive labels (canonical). For each, fall back to MediaDive when missing.
40
+ targets = {
41
+ "optimal_temperature_c": None, # no MediaDive proxy for temperature
42
+ "optimal_ph": "md_ph_median" if len(md) else None,
43
+ "oxygen_requirement": None, # no media-based proxy
44
+ "salt_tolerance_pct": "md_nacl_pct_median" if len(md) else None,
45
+ }
46
+
47
+ md_indexed = md.set_index("bacdive_id") if len(md) else None
48
+ for target, md_col in targets.items():
49
+ bacdive_vals = pheno.set_index("bacdive_id")[target] if target in pheno.columns else None
50
+ # Reindex to match pheno order
51
+ ordered_bacdive = (
52
+ bacdive_vals.reindex(pheno["bacdive_id"]).values if bacdive_vals is not None else None
53
+ )
54
+ # MediaDive-derived fallback (numeric only; salt% capped at 30 already)
55
+ weak_vals = (
56
+ md_indexed[md_col].reindex(pheno["bacdive_id"]).values
57
+ if md_col and md_indexed is not None and md_col in md_indexed.columns
58
+ else None
59
+ )
60
+
61
+ values = []
62
+ sources = []
63
+ for i in range(len(pheno)):
64
+ v = ordered_bacdive[i] if ordered_bacdive is not None else None
65
+ # pandas-aware NaN check
66
+ v_is_na = (v is None) or (isinstance(v, float) and pd.isna(v)) or (
67
+ isinstance(v, str) and not v
68
+ )
69
+ if not v_is_na:
70
+ values.append(v)
71
+ sources.append("bacdive")
72
+ continue
73
+ wv = weak_vals[i] if weak_vals is not None else None
74
+ wv_is_na = wv is None or (isinstance(wv, float) and pd.isna(wv))
75
+ if not wv_is_na:
76
+ values.append(wv)
77
+ sources.append("mediadive_weak")
78
+ continue
79
+ values.append(None)
80
+ sources.append("unknown")
81
+
82
+ out_cols[target] = pd.Series(values, dtype="object")
83
+ out_cols[f"{target}_source"] = pd.Series(sources, dtype="string")
84
+
85
+ catalog = pd.DataFrame(out_cols)
86
+ out = config.DATA / "strain_catalog.parquet"
87
+ catalog.to_parquet(out, index=False)
88
+ print(f"wrote {len(catalog):,} strains to {out}\n")
89
+
90
+ # Summary per target
91
+ for target in targets:
92
+ src_col = f"{target}_source"
93
+ counts = catalog[src_col].value_counts().to_dict()
94
+ n_known = counts.get("bacdive", 0) + counts.get("mediadive_weak", 0)
95
+ print(f"{target}")
96
+ print(f" bacdive (curated): {counts.get('bacdive', 0):>7,}")
97
+ print(f" mediadive_weak: {counts.get('mediadive_weak', 0):>7,}")
98
+ print(f" unknown: {counts.get('unknown', 0):>7,}")
99
+ print(f" ─ any-known: {n_known:>7,} ({100*n_known/len(catalog):.0f}%)")
100
+ print()
101
+
102
+
103
+ if __name__ == "__main__":
104
+ main()
scripts/22_train_with_weak_labels.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train baseline using BacDive-first + MediaDive-weak fallback labels for pH/salt.
2
+
3
+ Compares vs the curated-only baseline (artifacts/baseline_results.json) to see whether
4
+ the weak labels are net-helpful for the model. T_opt and oxygen are unaffected (no
5
+ weak source). pH and salt get many more training rows but with noisier labels.
6
+
7
+ Output: artifacts/baseline_results_weak.json
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import time
12
+
13
+ import pandas as pd
14
+
15
+ from microbe_model import config
16
+ from microbe_model.train.baseline import save_results, train_all
17
+
18
+ # Reuse the encoders from scripts/03 — copy locally to avoid sys.path gymnastics
19
+ import importlib.util
20
+ spec = importlib.util.spec_from_file_location("train03", config.ROOT / "scripts" / "03_train_baseline.py")
21
+ train03 = importlib.util.module_from_spec(spec)
22
+ spec.loader.exec_module(train03)
23
+
24
+
25
+ def main() -> None:
26
+ t0 = time.time()
27
+ catalog = pd.read_parquet(config.DATA / "strain_catalog.parquet")
28
+ feats = pd.read_parquet(config.DATA / "features.parquet")
29
+ pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
30
+ # Isolation categories live in the original phenotype table — graft them onto the catalog
31
+ iso_in = [c for c in pheno.columns if c.startswith("isolation_cat")]
32
+ pheno_iso = pheno[["bacdive_id"] + iso_in].copy()
33
+ pheno_iso["bacdive_id"] = pheno_iso["bacdive_id"].astype(int)
34
+
35
+ catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
36
+ feats["bacdive_id"] = feats["bacdive_id"].astype(int)
37
+ catalog = catalog.merge(pheno_iso, on="bacdive_id", how="left")
38
+ df = catalog.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
39
+ df["group"] = df.apply(train03.derive_group, axis=1)
40
+
41
+ # Coerce numeric targets (catalog stored as object)
42
+ for col in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
43
+ if col in df.columns:
44
+ df[col] = pd.to_numeric(df[col], errors="coerce")
45
+
46
+ df, iso_cols = train03.encode_isolation_categories(df)
47
+ print(f"Encoded {len(iso_cols)} isolation features")
48
+
49
+ # IMPORTANT: do NOT add MediaDive features here. The weak labels for pH and salt
50
+ # are *derived from those same features* (per-strain median across DSMZ media), so
51
+ # including them as inputs leaks the target — the model trivially predicts the
52
+ # matching feature column. The honest test is: do MediaDive-derived weak labels
53
+ # help a genome+isolation model generalize?
54
+ feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
55
+ feature_cols = feature_cols + iso_cols
56
+
57
+ print(f"\nTraining table: {len(df):,} strains × {len(feature_cols)} features")
58
+ for tgt in ("optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"):
59
+ n = df[tgt].notna().sum() if tgt in df.columns else 0
60
+ print(f" {tgt:25s} labeled={n:>6,}")
61
+ print()
62
+
63
+ results = train_all(df, feature_cols, group_col_override="group")
64
+ out = config.ARTIFACTS / "baseline_results_weak.json"
65
+ save_results(results, out, predictions_path=None, feature_cols=feature_cols)
66
+
67
+ print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
68
+ for target, r in results.items():
69
+ if r.folds:
70
+ print(f" {target:25s} {r.folds[0].metric_name:10s} = {r.mean():.4f}")
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
scripts/23_weak_label_apples_to_apples.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Honest test: do MediaDive-derived weak labels help generalize to BacDive-curated optima?
2
+
3
+ For pH and salt, two training regimes — same held-out test rows (curated only):
4
+ A. CURATED-ONLY: train on BacDive curated labels.
5
+ B. CURATED + WEAK: train on BacDive curated + MediaDive-derived weak labels.
6
+
7
+ In both, the test set per fold is the *intersection* of the held-out group with the
8
+ curated subset. This isolates whether weak labels help the model do better on the
9
+ gold-standard distribution, rather than just helping it predict the medium pH/salt
10
+ of the strain itself (which would be circular).
11
+
12
+ No MediaDive features are used (deployed model parity).
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import importlib.util
17
+ import sys
18
+ import time
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import xgboost as xgb
24
+ from sklearn.metrics import mean_absolute_error
25
+ from sklearn.model_selection import GroupKFold
26
+
27
+ from microbe_model import config
28
+
29
+ ROOT = Path(__file__).resolve().parent.parent
30
+ sys.path.insert(0, str(ROOT / "scripts"))
31
+
32
+ spec = importlib.util.spec_from_file_location("train03", ROOT / "scripts" / "03_train_baseline.py")
33
+ train03 = importlib.util.module_from_spec(spec)
34
+ spec.loader.exec_module(train03)
35
+
36
+
37
+ def cv_mae(
38
+ df: pd.DataFrame,
39
+ feature_cols: list[str],
40
+ target: str,
41
+ *,
42
+ train_mask: pd.Series,
43
+ test_mask: pd.Series,
44
+ n_splits: int = 5,
45
+ ) -> tuple[float, int]:
46
+ """5-fold GroupKFold by family. Train on (train_mask & target.notna() & not test fold).
47
+ Evaluate on (test_mask & target.notna() & test fold). Returns mean MAE across folds.
48
+ """
49
+ eligible = df[df[target].notna()].copy()
50
+ eligible[target] = pd.to_numeric(eligible[target], errors="coerce")
51
+ eligible = eligible[eligible[target].notna()]
52
+ groups = eligible["group"].fillna("__unknown__")
53
+ splits = min(n_splits, max(2, groups.nunique()))
54
+ kf = GroupKFold(n_splits=splits)
55
+ maes = []
56
+ n_eval_total = 0
57
+ for tr_idx, te_idx in kf.split(eligible, eligible[target], groups):
58
+ tr = eligible.iloc[tr_idx]
59
+ te = eligible.iloc[te_idx]
60
+ # Apply masks to the row indices we're using
61
+ tr = tr[train_mask.reindex(tr.index, fill_value=False).values]
62
+ te = te[test_mask.reindex(te.index, fill_value=False).values]
63
+ if len(tr) < 100 or len(te) < 50:
64
+ continue
65
+ m = xgb.XGBRegressor(
66
+ n_estimators=400, max_depth=5, learning_rate=0.05,
67
+ tree_method="hist", n_jobs=-1,
68
+ )
69
+ m.fit(tr[feature_cols], tr[target].astype(float))
70
+ preds = m.predict(te[feature_cols])
71
+ maes.append(mean_absolute_error(te[target].astype(float), preds))
72
+ n_eval_total += len(te)
73
+ return (float(np.mean(maes)) if maes else float("nan")), n_eval_total
74
+
75
+
76
+ def main() -> None:
77
+ t0 = time.time()
78
+ catalog = pd.read_parquet(config.DATA / "strain_catalog.parquet")
79
+ feats = pd.read_parquet(config.DATA / "features.parquet")
80
+ pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
81
+ iso_in = [c for c in pheno.columns if c.startswith("isolation_cat")]
82
+ pheno_iso = pheno[["bacdive_id"] + iso_in].copy()
83
+ pheno_iso["bacdive_id"] = pheno_iso["bacdive_id"].astype(int)
84
+
85
+ catalog["bacdive_id"] = catalog["bacdive_id"].astype(int)
86
+ feats["bacdive_id"] = feats["bacdive_id"].astype(int)
87
+ catalog = catalog.merge(pheno_iso, on="bacdive_id", how="left")
88
+ df = catalog.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
89
+ df["group"] = df.apply(train03.derive_group, axis=1)
90
+
91
+ # Numeric coercion
92
+ for col in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
93
+ df[col] = pd.to_numeric(df[col], errors="coerce")
94
+
95
+ df, iso_cols = train03.encode_isolation_categories(df)
96
+ feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
97
+ feature_cols = feature_cols + iso_cols
98
+
99
+ print(f"\nTraining table: {len(df):,} strains × {len(feature_cols)} features")
100
+ print("Held-out test rows are always BacDive-curated only.\n")
101
+
102
+ for target in ("optimal_ph", "salt_tolerance_pct"):
103
+ src_col = f"{target}_source"
104
+ curated = (df[src_col] == "bacdive")
105
+ weak = (df[src_col] == "mediadive_weak")
106
+ print(f"=== {target} ===")
107
+ print(f" curated rows: {curated.sum():,}")
108
+ print(f" weak rows: {weak.sum():,}")
109
+
110
+ # A) CURATED-ONLY training
111
+ mae_a, n_a = cv_mae(df, feature_cols, target,
112
+ train_mask=curated, test_mask=curated)
113
+ # B) CURATED + WEAK training
114
+ mae_b, n_b = cv_mae(df, feature_cols, target,
115
+ train_mask=(curated | weak), test_mask=curated)
116
+ delta_pct = 100 * (mae_b - mae_a) / mae_a
117
+ verdict = "HELPS" if mae_b < mae_a - 0.001 else (
118
+ "HURTS" if mae_b > mae_a + 0.001 else "WASH"
119
+ )
120
+ print(f" A. curated-only MAE = {mae_a:.4f} (eval n={n_a:,})")
121
+ print(f" B. curated+weak MAE = {mae_b:.4f} (eval n={n_b:,})")
122
+ print(f" → Δ = {delta_pct:+.1f}% [{verdict}]\n")
123
+
124
+ print(f"({time.time() - t0:.1f}s total)")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()