File size: 3,514 Bytes
5df9ef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Build per-strain MediaDive features from strain_media + media_recipes + raw JSON.

For each BacDive strain, compute the median pH and NaCl% across all DSMZ media that
strain has been recorded as growing on. These are NOT labels — they're additional
features the model can use to predict the actual phenotype optima. Saves to
data/mediadive_features.parquet (joined into the training table by scripts/03).

Per-strain features written:
  - md_n_media:         count of media the strain grows on
  - md_ph_median:       median midpoint(min_pH, max_pH) across those media
  - md_ph_range:        max - min of the medium midpoint pH across those media
  - md_nacl_pct_median: median NaCl % w/v across those media
  - md_nacl_pct_max:    max NaCl % w/v across those media (the saltiest medium it grows on)

Media with no NaCl compound in their recipe, or no recipe at all, count as 0% NaCl.

Sanity check: where a BacDive optimum_pH or salt_tolerance_pct exists, we expect
moderate (not perfect) correlation with the corresponding MediaDive feature.
"""
from __future__ import annotations

import json
import math
from pathlib import Path

import pandas as pd

from microbe_model import config

# Upper bound, in % w/v, applied to each medium's summed NaCl. Recipes above it
# are treated as parse artifacts rather than real concentrations.
NACL_CAP_PCT: float = 30.0


def build_medium_ph_map() -> dict[str, float]:
    """Return {medium_id: midpoint pH} from raw MediaDive cache."""
    out: dict[str, float] = {}
    for path in Path(config.DATA / "mediadive").glob("*.json"):
        try:
            d = json.loads(path.read_text())
        except json.JSONDecodeError:
            continue
        if not isinstance(d, dict):
            continue
        m = d.get("medium")
        if not isinstance(m, dict):
            continue
        mid = m.get("id")
        min_ph = m.get("min_pH")
        max_ph = m.get("max_pH")
        if mid is None or min_ph is None or max_ph is None:
            continue
        try:
            out[str(mid)] = (float(min_ph) + float(max_ph)) / 2
        except (ValueError, TypeError):
            continue
    return out


def build_medium_nacl_map() -> dict[str, float]:
    """Return {medium_id: NaCl % w/v} summed from recipe compounds (clipped)."""
    mr = pd.read_parquet(config.DATA / "media_recipes.parquet")
    nacl = mr[mr["compound"].str.contains(r"sodium chlor|^nacl$", case=False, na=False, regex=True)]
    pct = (nacl.groupby("medium_id")["g_l"].sum() / 10).clip(upper=NACL_CAP_PCT)
    return pct.astype(float).to_dict()


def main() -> None:
    """Build the per-strain MediaDive feature table and write it to parquet.

    Steps:

    1. Read ``config.DATA / "strain_media.parquet"`` and keep rows whose
       ``growth`` is "yes" (case-insensitive, surrounding whitespace ignored).
    2. Keep one row per (``bacdive_id``, ``medium_id``) pair.
    3. Attach each medium's midpoint pH and NaCl % w/v.
    4. Aggregate per ``bacdive_id`` and write
       ``config.DATA / "mediadive_features.parquet"``.

    Prints the lookup-map sizes and a ``describe()`` summary of the output.
    """
    sm = pd.read_parquet(config.DATA / "strain_media.parquet")
    is_growing = sm["growth"].str.strip().str.lower() == "yes"
    sm = sm.loc[is_growing].copy()
    sm["medium_id"] = sm["medium_id"].astype(str)
    # A medium recorded more than once for the same strain should count once,
    # both in md_n_media and in the per-medium medians.
    sm = sm.drop_duplicates(subset=["bacdive_id", "medium_id"]).copy()

    # Key both lookups by str so they line up with the str medium_id column.
    # A key dtype mismatch (e.g. int ids) would otherwise map every row to NaN.
    ph_map = {str(k): v for k, v in build_medium_ph_map().items()}
    nacl_map = {str(k): v for k, v in build_medium_nacl_map().items()}
    print(f"medium pH map: {len(ph_map):,} media")
    print(f"medium NaCl map: {len(nacl_map):,} media")

    sm["m_ph"] = sm["medium_id"].map(ph_map)
    # nacl_map only holds media whose recipe has a NaCl compound. Media without
    # one, or missing from the recipe table entirely, are treated as 0% NaCl.
    sm["m_nacl"] = sm["medium_id"].map(nacl_map).fillna(0.0)

    # Aggregate per strain.
    by_strain = sm.groupby("bacdive_id")
    ph = by_strain["m_ph"]
    nacl = by_strain["m_nacl"]
    feat = pd.DataFrame({
        "md_n_media": by_strain.size(),
        "md_ph_median": ph.median(),
        "md_ph_range": ph.max() - ph.min(),
        "md_nacl_pct_median": nacl.median(),
        "md_nacl_pct_max": nacl.max(),
    }).reset_index()

    out = config.DATA / "mediadive_features.parquet"
    feat.to_parquet(out, index=False)
    print(f"\nwrote {len(feat):,} strains to {out}")
    print(feat.describe().round(2).to_string())


if __name__ == "__main__":
    # Build and write the feature table only when run as a script, not on import.
    main()