File size: 4,318 Bytes
4c18dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Build a unified per-strain catalog with label provenance.

Every BacDive strain in data/bacdive_phenotypes.parquet gets a row. For each of the
4 phenotype targets, two columns are emitted:
  - <target>          numeric/categorical value (or NaN if unknown)
  - <target>_source   one of: 'bacdive' | 'mediadive_weak' | 'unknown'

`bacdive` means the value came from BacDive's curated records (optimum temperature /
pH, oxygen tolerance, salt tolerance; high-quality). `mediadive_weak` means we derived
it from the median pH/NaCl% of the DSMZ media the strain has been recorded as growing
on (lower-quality fallback, only filled in when BacDive has no value). `unknown` means
we have nothing.

Saves to data/strain_catalog.parquet.
"""
from __future__ import annotations

import pandas as pd

from microbe_model import config


def main() -> None:
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet").copy()
    pheno["bacdive_id"] = pheno["bacdive_id"].astype(int)

    # Pull MediaDive-derived signals (per-strain median across grown media)
    md_path = config.DATA / "mediadive_features.parquet"
    md = pd.read_parquet(md_path) if md_path.exists() else pd.DataFrame()
    if len(md):
        md["bacdive_id"] = md["bacdive_id"].astype(int)

    # Build per-target value + source
    out_cols: dict[str, pd.Series] = {}
    for col in ("bacdive_id", "species", "genus", "family", "ncbi_taxon_id",
                "genome_accession", "genome_source"):
        if col in pheno.columns:
            out_cols[col] = pheno[col]

    # Direct BacDive labels (canonical). For each, fall back to MediaDive when missing.
    targets = {
        "optimal_temperature_c": None,        # no MediaDive proxy for temperature
        "optimal_ph": "md_ph_median" if len(md) else None,
        "oxygen_requirement": None,           # no media-based proxy
        "salt_tolerance_pct": "md_nacl_pct_median" if len(md) else None,
    }

    md_indexed = md.set_index("bacdive_id") if len(md) else None
    for target, md_col in targets.items():
        bacdive_vals = pheno.set_index("bacdive_id")[target] if target in pheno.columns else None
        # Reindex to match pheno order
        ordered_bacdive = (
            bacdive_vals.reindex(pheno["bacdive_id"]).values if bacdive_vals is not None else None
        )
        # MediaDive-derived fallback (numeric only; salt% capped at 30 already)
        weak_vals = (
            md_indexed[md_col].reindex(pheno["bacdive_id"]).values
            if md_col and md_indexed is not None and md_col in md_indexed.columns
            else None
        )

        values = []
        sources = []
        for i in range(len(pheno)):
            v = ordered_bacdive[i] if ordered_bacdive is not None else None
            # pandas-aware NaN check
            v_is_na = (v is None) or (isinstance(v, float) and pd.isna(v)) or (
                isinstance(v, str) and not v
            )
            if not v_is_na:
                values.append(v)
                sources.append("bacdive")
                continue
            wv = weak_vals[i] if weak_vals is not None else None
            wv_is_na = wv is None or (isinstance(wv, float) and pd.isna(wv))
            if not wv_is_na:
                values.append(wv)
                sources.append("mediadive_weak")
                continue
            values.append(None)
            sources.append("unknown")

        out_cols[target] = pd.Series(values, dtype="object")
        out_cols[f"{target}_source"] = pd.Series(sources, dtype="string")

    catalog = pd.DataFrame(out_cols)
    out = config.DATA / "strain_catalog.parquet"
    catalog.to_parquet(out, index=False)
    print(f"wrote {len(catalog):,} strains to {out}\n")

    # Summary per target
    for target in targets:
        src_col = f"{target}_source"
        counts = catalog[src_col].value_counts().to_dict()
        n_known = counts.get("bacdive", 0) + counts.get("mediadive_weak", 0)
        print(f"{target}")
        print(f"  bacdive (curated):   {counts.get('bacdive', 0):>7,}")
        print(f"  mediadive_weak:      {counts.get('mediadive_weak', 0):>7,}")
        print(f"  unknown:             {counts.get('unknown', 0):>7,}")
        print(f"  ─ any-known:         {n_known:>7,} ({100*n_known/len(catalog):.0f}%)")
        print()


if __name__ == "__main__":
    # Script entry point: build the catalog only when run directly, not on import.
    main()