Miyu Horiuchi, Claude Opus 4.7 (1M context) committed
Commit 52cf5ab · 0 Parent(s)

Scaffold v0: BacDive + NCBI ingestion, genome feature extractor, XGBoost baseline

Sets up the data + training pipeline for predicting cultivation conditions
(optimal T, pH, oxygen requirement, salt tolerance) from genome sequence:

- src/microbe_model/data/bacdive.py — BacDive REST client + phenotype extraction
- src/microbe_model/data/ncbi.py — NCBI Datasets v2 genome fetcher
- src/microbe_model/features/genome.py — pyrodigal CDS prediction + amino-acid
composition features (IVYWREL, hydrophobicity, isoelectric point — all
biologically motivated for the targets we predict)
- src/microbe_model/train/baseline.py — multi-task XGBoost with group K-fold
by family to prevent leakage from closely related strains
- scripts/01..04 — runnable pipeline entry points
- tests/test_features.py — smoke test on synthetic FASTA, passes

No trained model yet. Real BacDive ingestion needs BACDIVE_USER/PASSWORD env vars.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

.env.example ADDED
@@ -0,0 +1,7 @@
+ # BacDive API credentials - register at https://bacdive.dsmz.de/
+ BACDIVE_USER=
+ BACDIVE_PASSWORD=
+
+ # NCBI API key - optional, raises rate limit from 3 req/s to 10 req/s
+ # Get one at https://www.ncbi.nlm.nih.gov/account/settings/
+ NCBI_API_KEY=
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.ubj filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,39 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.egg-info/
+ .pytest_cache/
+ .ruff_cache/
+ .mypy_cache/
+
+ # Virtual env
+ .venv/
+ venv/
+ .env
+ .env.local
+
+ # Data / artifacts (large files, kept out of git)
+ /data/
+ /artifacts/
+ /models/
+ *.parquet
+ *.fna
+ *.fna.gz
+ *.faa
+ *.gbff
+ *.gbff.gz
+
+ # Notebooks
+ .ipynb_checkpoints/
+ notebooks/scratch/
+
+ # Editor
+ .vscode/
+ .idea/
+ *.swp
+ .DS_Store
+
+ # Agent / tool state
+ .claude/
+ .letta/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md ADDED
@@ -0,0 +1,93 @@
+ # microbe-model
+
+ Predict cultivation conditions (optimal temperature, pH, oxygen requirement, salt tolerance) for
+ microbial isolates from genome sequence alone. The long-term aim is to lower the cost of culturing
+ "microbial dark matter": the >99% of microbial diversity that has not yet been grown in pure culture.
+
+ ## Status
+
+ v0: scaffolding the data pipeline + a non-deep-learning baseline. No trained model yet.
+
+ ## Approach
+
+ ```
+ BacDive (phenotype labels) ──┐
+                              ├──> joined table (strain, genome_accession, phenotypes)
+ GTDB / NCBI (genomes) ───────┘
+
+
+                                   feature extraction
+                                   (genome statistics, codon usage,
+                                   proteome-level amino acid stats)
+
+
+                                   XGBoost multi-task baseline
+                                   (group K-fold by family)
+
+
+                                   eval report (MAE, F1, importances)
+ ```
+
+ The genome→phenotype features used here have well-established correlations with the target
+ properties (e.g. proteome amino acid composition correlates with optimal growth temperature),
+ so even a tabular model has a real signal to learn from. The point of v0 is to establish
+ a baseline to beat before investing in transformer-based approaches.
+
+ ## Setup
+
+ ```bash
+ # Requires Python 3.11 and uv (https://docs.astral.sh/uv/)
+ uv sync --all-extras
+ ```
+
+ ## Running the pipeline
+
+ ```bash
+ # 1. Pull strain metadata + phenotype labels from BacDive
+ #    (requires BACDIVE_USER and BACDIVE_PASSWORD env vars; register at bacdive.dsmz.de)
+ uv run python scripts/01_fetch_bacdive.py --limit 1000
+
+ # 2. Download genomes for strains that have an accession
+ uv run python scripts/02_fetch_genomes.py
+
+ # 3. Extract genome-level features (CDS prediction + amino acid stats)
+ uv run python scripts/03_extract_features.py
+
+ # 4. Train multi-task XGBoost baseline
+ uv run python scripts/04_train_baseline.py
+
+ # 5. Render eval report (planned; this commit only adds scripts 01 to 04)
+ uv run python scripts/05_eval.py
+ ```
+
+ ## Layout
+
+ ```
+ src/microbe_model/
+   config.py            # paths, constants
+   data/
+     bacdive.py         # BacDive REST API client
+     ncbi.py            # NCBI genome fetcher (Datasets API)
+   features/
+     genome.py          # gene prediction + tabular feature extraction
+   train/
+     baseline.py        # multi-task XGBoost + group K-fold eval
+ scripts/               # runnable entry points (numbered by pipeline order)
+ tests/                 # smoke tests on small fixtures
+ data/                  # (gitignored) cached API responses, genomes, parquet tables
+ ```
+
+ ## What this is *not* yet
+
+ - Not a foundation model. No transformer. No genome language model.
+ - Not a platform. There is no upload UI or active-learning loop.
+ - Not validated against held-out organisms. The eval scaffolding exists; the data does not.
+
+ These are deliberate v0 boundaries. See the project notes for the longer-term plan.
+
+ ## Environment variables
+
+ Copy `.env.example` to `.env` and fill in:
+
+ - `BACDIVE_USER`, `BACDIVE_PASSWORD`: required for BacDive API access (free registration).
+ - `NCBI_API_KEY`: optional, raises NCBI rate limit from 3 req/s to 10 req/s.
pyproject.toml ADDED
@@ -0,0 +1,42 @@
+ [project]
+ name = "microbe-model"
+ version = "0.0.1"
+ description = "Predict cultivation conditions for uncultured microbes from genome sequence."
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "biopython>=1.83",
+     "pyrodigal>=3.5",
+     "numpy>=1.26",
+     "pandas>=2.2",
+     "pyarrow>=15",
+     "scikit-learn>=1.4",
+     "xgboost>=2.0",
+     "requests>=2.32",
+     "tqdm>=4.66",
+     "python-dotenv>=1.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0",
+     "ruff>=0.4",
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/microbe_model"]
+
+ [tool.ruff]
+ line-length = 100
+ target-version = "py311"
+
+ [tool.ruff.lint]
+ select = ["E", "F", "W", "I", "UP", "B", "SIM"]
+ ignore = ["E501"]
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
scripts/01_fetch_bacdive.py ADDED
@@ -0,0 +1,46 @@
+ """Pull strain metadata + phenotype labels from BacDive.
+
+ Writes one JSON per strain to data/bacdive/, plus a consolidated parquet table at
+ data/bacdive_phenotypes.parquet.
+
+ Usage:
+     uv run python scripts/01_fetch_bacdive.py --limit 1000
+ """
+ from __future__ import annotations
+
+ import argparse
+
+ import pandas as pd
+ from tqdm import tqdm
+
+ from microbe_model import config
+ from microbe_model.data.bacdive import (
+     BacDiveClient,
+     extract_phenotypes,
+     fetch_with_cache,
+ )
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     # The CLI always passes an int; only the Python API (iter_strain_ids) accepts None for "all".
+     parser.add_argument("--limit", type=int, default=1000, help="Max strains to fetch (default: 1000).")
+     args = parser.parse_args()
+
+     client = BacDiveClient()
+     rows = []
+     for bacdive_id in tqdm(client.iter_strain_ids(limit=args.limit), desc="BacDive", unit="strain"):
+         record = fetch_with_cache(client, bacdive_id)
+         rows.append(extract_phenotypes(record))
+
+     df = pd.DataFrame(rows)
+     out = config.DATA / "bacdive_phenotypes.parquet"
+     df.to_parquet(out, index=False)
+     print(f"\nWrote {len(df)} rows to {out}")
+     print("Coverage of prediction targets:")
+     for col in ("optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"):
+         print(f"  {col}: {df[col].notna().sum()} / {len(df)}")
+     print(f"  genome_accession: {df['genome_accession'].notna().sum()} / {len(df)}")
+
+
+ if __name__ == "__main__":
+     main()
scripts/02_fetch_genomes.py ADDED
@@ -0,0 +1,40 @@
+ """Download genome FASTAs for every BacDive strain that has an accession.
+
+ Skips strains already cached on disk. Run after 01_fetch_bacdive.py.
+ """
+ from __future__ import annotations
+
+ import pandas as pd
+ from tqdm import tqdm
+
+ from microbe_model import config
+ from microbe_model.data.ncbi import GenomeNotFound, fetch_genome
+
+
+ def main() -> None:
+     table = config.DATA / "bacdive_phenotypes.parquet"
+     if not table.exists():
+         raise SystemExit(f"Missing {table}. Run scripts/01_fetch_bacdive.py first.")
+
+     df = pd.read_parquet(table)
+     accessions = df["genome_accession"].dropna().unique().tolist()
+     print(f"{len(accessions)} unique genome accessions to fetch.")
+
+     failed: list[tuple[str, str]] = []
+     for acc in tqdm(accessions, desc="NCBI", unit="genome"):
+         try:
+             fetch_genome(acc)
+         except GenomeNotFound:
+             failed.append((acc, "not_found"))
+         except Exception as exc:  # noqa: BLE001 (log and continue, don't kill the batch)
+             failed.append((acc, type(exc).__name__))
+
+     print(f"\nDownloaded: {len(accessions) - len(failed)} / {len(accessions)}")
+     if failed:
+         log = config.DATA / "genome_fetch_failures.tsv"
+         pd.DataFrame(failed, columns=["accession", "error"]).to_csv(log, sep="\t", index=False)
+         print(f"Failures logged to {log}")
+
+
+ if __name__ == "__main__":
+     main()
scripts/03_extract_features.py ADDED
@@ -0,0 +1,46 @@
+ """Extract tabular genome features for every cached genome.
+
+ Reads the BacDive phenotype table + the cached FASTAs in data/genomes/, runs Pyrodigal +
+ amino-acid-composition feature extraction, and writes data/features.parquet (one row per strain).
+ """
+ from __future__ import annotations
+
+ import pandas as pd
+ from tqdm import tqdm
+
+ from microbe_model import config
+ from microbe_model.data.ncbi import genome_path
+ from microbe_model.features.genome import extract_features
+
+
+ def main() -> None:
+     pheno_path = config.DATA / "bacdive_phenotypes.parquet"
+     if not pheno_path.exists():
+         raise SystemExit(f"Missing {pheno_path}. Run 01 then 02 first.")
+     pheno = pd.read_parquet(pheno_path)
+
+     rows = []
+     for _, row in tqdm(pheno.iterrows(), total=len(pheno), desc="features"):
+         accession = row["genome_accession"]
+         # Missing accessions can come back as None or NaN; NaN is truthy, so check both.
+         if pd.isna(accession) or not accession:
+             continue
+         path = genome_path(accession)
+         if not path.exists():
+             continue
+         try:
+             feats = extract_features(path)
+         except Exception as exc:  # noqa: BLE001 (a bad FASTA shouldn't kill the run)
+             print(f"  skip {accession}: {type(exc).__name__}: {exc}")
+             continue
+         feats["bacdive_id"] = row["bacdive_id"]
+         feats["genome_accession"] = accession
+         rows.append(feats)
+
+     feats_df = pd.DataFrame(rows)
+     out = config.DATA / "features.parquet"
+     feats_df.to_parquet(out, index=False)
+     print(f"\nWrote {len(feats_df)} rows to {out}")
+
+
+ if __name__ == "__main__":
+     main()
scripts/04_train_baseline.py ADDED
@@ -0,0 +1,45 @@
+ """Train the multi-task XGBoost baseline.
+
+ Joins phenotypes + features, derives a `family` column from `species` for group K-fold,
+ and writes per-target metrics to artifacts/baseline_results.json.
+ """
+ from __future__ import annotations
+
+ import pandas as pd
+
+ from microbe_model import config
+ from microbe_model.train.baseline import save_results, train_all
+
+
+ def derive_family(species: str | None) -> str:
+     """Crude family proxy: the genus, i.e. the first word of the binomial.
+
+     Grouping by genus is finer than grouping by family, so related genera can still end up
+     on both sides of a split. Replace with a GTDB lookup later.
+     """
+     if not species:
+         return "__unknown__"
+     return str(species).split()[0]
+
+
+ def main() -> None:
+     pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
+     feats = pd.read_parquet(config.DATA / "features.parquet")
+     df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
+     df["family"] = df["species"].apply(derive_family)
+
+     feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
+     print(f"Training on {len(df)} strains × {len(feature_cols)} features.")
+     print(f"Group counts (top 10): {df['family'].value_counts().head(10).to_dict()}")
+
+     results = train_all(df, feature_cols)
+
+     out = config.ARTIFACTS / "baseline_results.json"
+     save_results(results, out)
+     print(f"\nWrote results to {out}\n")
+     for target, r in results.items():
+         if r.folds:
+             metric = r.folds[0].metric_name
+             print(f"  {target:25s} {metric:10s} = {r.mean():.4f} (n_folds={len(r.folds)})")
+         else:
+             print(f"  {target:25s} skipped (insufficient data)")
+
+
+ if __name__ == "__main__":
+     main()
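
The grouping idea in `scripts/04_train_baseline.py` can be illustrated without scikit-learn. The sketch below is a simplified stand-in for `GroupKFold`, not its actual algorithm: the helper names `derive_group` and `assign_group_folds` and the toy species list are assumptions made for illustration. It takes the first word of the binomial as the group and assigns whole groups to folds, so no group appears on both sides of a split.

```python
from collections import defaultdict
from typing import Optional


def derive_group(species: Optional[str]) -> str:
    """First word of the binomial (the genus), or a sentinel when missing."""
    if not species:
        return "__unknown__"
    return species.split()[0]


def assign_group_folds(groups: list, n_splits: int) -> list:
    """Return one fold index per sample; every group lands in exactly one fold.

    Greedy balancing (hypothetical, for illustration): largest groups first,
    each placed into the fold that currently holds the fewest samples.
    """
    members = defaultdict(list)
    for i, g in enumerate(groups):
        members[g].append(i)
    fold_sizes = [0] * n_splits
    fold_of = [0] * len(groups)
    for g in sorted(members, key=lambda k: (-len(members[k]), k)):
        target = fold_sizes.index(min(fold_sizes))
        for i in members[g]:
            fold_of[i] = target
        fold_sizes[target] += len(members[g])
    return fold_of


# Toy data: two genera with two strains each, one singleton, one missing name.
species = [
    "Escherichia coli", "Escherichia fergusonii", "Bacillus subtilis",
    "Bacillus cereus", "Thermus aquaticus", None,
]
groups = [derive_group(s) for s in species]
folds = assign_group_folds(groups, n_splits=2)
```

Both *Escherichia* strains share a fold, as do both *Bacillus* strains, which is the property that keeps near-identical genomes from leaking between train and test.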
src/microbe_model/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.0.1"
src/microbe_model/config.py ADDED
@@ -0,0 +1,31 @@
+ """Project paths and shared constants."""
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ ROOT = Path(__file__).resolve().parents[2]
+ DATA = ROOT / "data"
+ ARTIFACTS = ROOT / "artifacts"
+
+ BACDIVE_DIR = DATA / "bacdive"
+ GENOME_DIR = DATA / "genomes"
+ FEATURE_DIR = DATA / "features"
+
+ for _d in (DATA, ARTIFACTS, BACDIVE_DIR, GENOME_DIR, FEATURE_DIR):
+     _d.mkdir(parents=True, exist_ok=True)
+
+ BACDIVE_USER = os.environ.get("BACDIVE_USER")
+ BACDIVE_PASSWORD = os.environ.get("BACDIVE_PASSWORD")
+ NCBI_API_KEY = os.environ.get("NCBI_API_KEY")
+
+ PHENOTYPE_TARGETS = {
+     "optimal_temperature_c": "regression",
+     "optimal_ph": "regression",
+     "oxygen_requirement": "classification",
+     "salt_tolerance_pct": "regression",
+ }
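
`PHENOTYPE_TARGETS` tags each target as regression or classification, and the baseline scores them with MAE and macro-F1 respectively. As a rough illustration of what those two numbers mean, here is a pure-Python sketch; the helper names, the `METRIC_FOR_TASK` mapping, and the toy labels are assumptions for this example, and the repository itself uses scikit-learn's implementations.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between true and predicted values."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in truth or prediction."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Copied from config.py so the sketch is self-contained.
PHENOTYPE_TARGETS = {
    "optimal_temperature_c": "regression",
    "optimal_ph": "regression",
    "oxygen_requirement": "classification",
    "salt_tolerance_pct": "regression",
}
# Hypothetical dispatch table: task type -> metric function.
METRIC_FOR_TASK = {"regression": mean_absolute_error, "classification": macro_f1}

# Made-up values, purely to exercise the two metrics.
mae = METRIC_FOR_TASK[PHENOTYPE_TARGETS["optimal_temperature_c"]](
    [37.0, 55.0, 30.0], [35.0, 60.0, 30.0]
)
f1 = METRIC_FOR_TASK[PHENOTYPE_TARGETS["oxygen_requirement"]](
    ["aerobe", "anaerobe", "aerobe", "anaerobe"],
    ["aerobe", "aerobe", "aerobe", "anaerobe"],
)
```

Macro averaging weights rare classes as heavily as common ones, which matters when one oxygen class dominates the labeled strains.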
src/microbe_model/data/__init__.py ADDED
File without changes
src/microbe_model/data/bacdive.py ADDED
@@ -0,0 +1,182 @@
+ """BacDive REST API client.
+
+ BacDive (https://bacdive.dsmz.de/) is the largest curated database of bacterial phenotypes.
+ Free registration is required; credentials are read from BACDIVE_USER / BACDIVE_PASSWORD.
+
+ This client does the minimum needed for v0:
+ - log in and obtain an OAuth token
+ - paginate through the strain catalog
+ - fetch full records by BacDive ID
+ - extract the phenotype targets we predict (T_opt, pH_opt, oxygen, salt)
+ """
+ from __future__ import annotations
+
+ import json
+ import time
+ from collections.abc import Iterator
+ from pathlib import Path
+ from typing import Any
+
+ import requests
+
+ from microbe_model import config
+
+ BASE_URL = "https://api.bacdive.dsmz.de"
+ TOKEN_URL = "https://sso.dsmz.de/auth/realms/dsmz/protocol/openid-connect/token"
+
+
+ class BacDiveAuthError(RuntimeError):
+     pass
+
+
+ class BacDiveClient:
+     def __init__(self, user: str | None = None, password: str | None = None) -> None:
+         self.user = user or config.BACDIVE_USER
+         self.password = password or config.BACDIVE_PASSWORD
+         if not self.user or not self.password:
+             raise BacDiveAuthError(
+                 "Set BACDIVE_USER and BACDIVE_PASSWORD in .env (register at bacdive.dsmz.de)."
+             )
+         self._token: str | None = None
+         self._token_expires_at: float = 0.0
+         self._session = requests.Session()
+
+     def _refresh_token(self) -> None:
+         resp = self._session.post(
+             TOKEN_URL,
+             data={
+                 "grant_type": "password",
+                 "client_id": "api.bacdive.public",
+                 "username": self.user,
+                 "password": self.password,
+             },
+             timeout=30,
+         )
+         if resp.status_code != 200:
+             raise BacDiveAuthError(f"BacDive auth failed: {resp.status_code} {resp.text}")
+         body = resp.json()
+         self._token = body["access_token"]
+         # Refresh 30 s before the reported expiry so a request never goes out with a stale token.
+         self._token_expires_at = time.time() + body.get("expires_in", 300) - 30
+
+     def _headers(self) -> dict[str, str]:
+         if self._token is None or time.time() >= self._token_expires_at:
+             self._refresh_token()
+         return {"Authorization": f"Bearer {self._token}", "Accept": "application/json"}
+
+     def _get(self, path: str, params: dict | None = None) -> dict[str, Any]:
+         url = f"{BASE_URL}{path}"
+         # Up to three attempts, backing off 1 s, 2 s, 4 s when the API answers 429.
+         for attempt in range(3):
+             resp = self._session.get(url, headers=self._headers(), params=params, timeout=60)
+             if resp.status_code == 429:
+                 time.sleep(2 ** attempt)
+                 continue
+             resp.raise_for_status()
+             return resp.json()
+         # Still rate-limited after all retries: surface the 429 as an HTTPError.
+         resp.raise_for_status()
+         return {}
+
+     def iter_strain_ids(self, limit: int | None = None) -> Iterator[int]:
+         """Page through the BacDive catalog and yield strain IDs."""
+         page_url: str | None = "/fetch/"
+         seen = 0
+         while page_url:
+             body = self._get(page_url)
+             for record in body.get("results", []):
+                 yield int(record["id"])
+                 seen += 1
+                 if limit is not None and seen >= limit:
+                     return
+             next_url = body.get("next")
+             if not next_url:
+                 return
+             page_url = next_url.replace(BASE_URL, "")
+
+     def fetch_record(self, bacdive_id: int) -> dict[str, Any]:
+         body = self._get(f"/fetch/{bacdive_id}")
+         results = body.get("results") or {}
+         if isinstance(results, list):
+             return results[0] if results else {}
+         if isinstance(results, dict) and str(bacdive_id) in results:
+             return results[str(bacdive_id)]
+         return results
+
+
+ def extract_phenotypes(record: dict[str, Any]) -> dict[str, Any]:
+     """Pull the v0 prediction targets out of a BacDive record.
+
+     BacDive's record schema is deeply nested and field names vary across record versions.
+     We tolerate missing fields: anything we can't find becomes None and is dropped at training time.
+     """
+     out: dict[str, Any] = {
+         "bacdive_id": record.get("General", {}).get("BacDive-ID"),
+         "species": record.get("Name and taxonomic classification", {}).get("species"),
+         "ncbi_taxon_id": record.get("General", {}).get("NCBI tax id"),
+         "optimal_temperature_c": None,
+         "optimal_ph": None,
+         "oxygen_requirement": None,
+         "salt_tolerance_pct": None,
+         "genome_accession": None,
+     }
+
+     culture = record.get("Culture and growth conditions", {})
+     temps = _as_list(culture.get("culture temp"))
+     for t in temps:
+         # `or ""` guards against an explicit null in the "type" field.
+         if isinstance(t, dict) and str(t.get("type") or "").lower() in {"optimum", "optimal"}:
+             out["optimal_temperature_c"] = _to_float(t.get("temperature"))
+             break
+
+     phs = _as_list(culture.get("culture pH"))
+     for p in phs:
+         if isinstance(p, dict) and str(p.get("type") or "").lower() in {"optimum", "optimal"}:
+             out["optimal_ph"] = _to_float(p.get("pH"))
+             break
+
+     physio = record.get("Physiology and metabolism", {})
+     oxygen = _as_list(physio.get("oxygen tolerance"))
+     if oxygen and isinstance(oxygen[0], dict):
+         out["oxygen_requirement"] = oxygen[0].get("oxygen tolerance")
+
+     salt = _as_list(physio.get("halophily"))
+     for s in salt:
+         if isinstance(s, dict) and "concentration" in s:
+             out["salt_tolerance_pct"] = _to_float(s.get("concentration"))
+             break
+
+     seq = record.get("Sequence information", {})
+     genomes = _as_list(seq.get("genome sequence"))
+     for g in genomes:
+         if isinstance(g, dict) and g.get("accession"):
+             out["genome_accession"] = g["accession"]
+             break
+
+     return out
+
+
+ def _as_list(x: Any) -> list:
+     if x is None:
+         return []
+     if isinstance(x, list):
+         return x
+     return [x]
+
+
+ def _to_float(x: Any) -> float | None:
+     """Parse the first whitespace-separated token as a float; anything unparseable becomes None."""
+     if x is None:
+         return None
+     try:
+         return float(str(x).split()[0])
+     except (ValueError, AttributeError, IndexError):
+         # IndexError covers empty or whitespace-only strings, where split() returns [].
+         return None
+
+
+ def cache_path(bacdive_id: int) -> Path:
+     return config.BACDIVE_DIR / f"{bacdive_id}.json"
+
+
+ def fetch_with_cache(client: BacDiveClient, bacdive_id: int) -> dict[str, Any]:
+     path = cache_path(bacdive_id)
+     if path.exists():
+         return json.loads(path.read_text())
+     record = client.fetch_record(bacdive_id)
+     path.write_text(json.dumps(record))
+     return record
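
`_to_float` keeps only the first whitespace-separated token, so a value written as a range such as `"30-37"` fails to parse and becomes `None`, which silently drops that label. A minimal sketch of a range-aware alternative follows. The helper name `parse_numeric`, the midpoint rule, and the assumption that ranges arrive as `low-high` strings are all illustrative; none of this is confirmed against the BacDive schema.

```python
import re
from typing import Any, Optional

_NUMBER = r"[-+]?\d+(?:\.\d+)?"
_RANGE = re.compile(rf"^\s*({_NUMBER})\s*-\s*({_NUMBER})")
_SINGLE = re.compile(rf"^\s*({_NUMBER})")


def parse_numeric(x: Any) -> Optional[float]:
    """Parse '37', '7.5 %' or '30-37' into a float; ranges collapse to their midpoint."""
    if x is None:
        return None
    text = str(x)
    m = _RANGE.match(text)
    if m:
        low, high = float(m.group(1)), float(m.group(2))
        return (low + high) / 2
    m = _SINGLE.match(text)
    return float(m.group(1)) if m else None


examples = {value: parse_numeric(value) for value in ["37", "30-37", "7.5 %", "", "n/a"]}
```

Whether the midpoint is a sensible stand-in for an optimum is a modelling decision; keeping the low and high ends as separate columns would be an equally reasonable choice.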
src/microbe_model/data/ncbi.py ADDED
@@ -0,0 +1,61 @@
+ """NCBI genome fetcher.
+
+ Uses the NCBI Datasets v2 REST API to download a single nucleotide FASTA per accession.
+ This API doesn't require auth, but providing NCBI_API_KEY raises the rate limit from
+ 3 req/s to 10 req/s.
+ """
+ from __future__ import annotations
+
+ import gzip
+ import io
+ import time
+ import zipfile
+ from pathlib import Path
+
+ import requests
+
+ from microbe_model import config
+
+ DATASETS_BASE = "https://api.ncbi.nlm.nih.gov/datasets/v2"
+ # Minimum delay between requests: ~10 req/s with an API key, ~3 req/s without.
+ RATE_LIMIT_S = 0.1 if config.NCBI_API_KEY else 0.34
+
+
+ class GenomeNotFound(RuntimeError):
+     pass
+
+
+ def genome_path(accession: str) -> Path:
+     return config.GENOME_DIR / f"{accession}.fna.gz"
+
+
+ def fetch_genome(accession: str, *, force: bool = False) -> Path:
+     """Download a genome FASTA for the given assembly accession (e.g. GCF_000005845.2).
+
+     The Datasets API returns a zip; we extract the FASTA, gzip it, and write to disk.
+     Idempotent: returns immediately if the file is already cached.
+     """
+     out = genome_path(accession)
+     if out.exists() and not force:
+         return out
+
+     url = f"{DATASETS_BASE}/genome/accession/{accession}/download"
+     params = {"include_annotation_type": "GENOME_FASTA"}
+     headers = {"Accept": "application/zip"}
+     if config.NCBI_API_KEY:
+         headers["api-key"] = config.NCBI_API_KEY
+
+     time.sleep(RATE_LIMIT_S)
+     # The whole archive is read into memory below, so streaming the response buys nothing.
+     resp = requests.get(url, params=params, headers=headers, timeout=120)
+     if resp.status_code == 404:
+         raise GenomeNotFound(accession)
+     resp.raise_for_status()
+
+     buf = io.BytesIO(resp.content)
+     with zipfile.ZipFile(buf) as zf:
+         fasta_names = [n for n in zf.namelist() if n.endswith(".fna")]
+         if not fasta_names:
+             raise GenomeNotFound(f"{accession} (no .fna in archive)")
+         # Only the first FASTA in the archive is kept.
+         with zf.open(fasta_names[0]) as src, gzip.open(out, "wb") as dst:
+             for chunk in iter(lambda: src.read(1 << 16), b""):
+                 dst.write(chunk)
+     return out
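
`fetch_genome` treats any existing file as a valid cache entry, so an interrupted write could leave a truncated `.fna.gz` that later runs would reuse without complaint. One common guard, sketched here under the assumption that the temporary file and the final path live on the same filesystem, is to write to a temporary file and rename it into place only after the write succeeds. The helper name `write_gzip_atomic` is hypothetical and not part of the commit.

```python
import gzip
import io
import os
import tempfile
from pathlib import Path
from typing import BinaryIO


def write_gzip_atomic(src: BinaryIO, out: Path, chunk_size: int = 1 << 16) -> Path:
    """Gzip `src` into `out`, exposing the final path only once the write is complete."""
    out.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=out.parent, suffix=".part")
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        with gzip.open(tmp, "wb") as dst:
            for chunk in iter(lambda: src.read(chunk_size), b""):
                dst.write(chunk)
        os.replace(tmp, out)  # rename within one directory, so readers never see a partial file
    except BaseException:
        tmp.unlink(missing_ok=True)  # clean up the partial temp file on any failure
        raise
    return out


# Usage with an in-memory stand-in for the FASTA stream inside the zip archive.
cache_dir = Path(tempfile.mkdtemp())
written = write_gzip_atomic(io.BytesIO(b">contig1\nACGT\n"), cache_dir / "demo.fna.gz")
with gzip.open(written, "rb") as handle:
    roundtrip = handle.read()
```

With this pattern the `out.exists()` cache check stays trustworthy, because a file at the final path implies a completed write.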
src/microbe_model/features/__init__.py ADDED
File without changes
src/microbe_model/features/genome.py ADDED
@@ -0,0 +1,143 @@
+ """Tabular feature extraction from a microbial genome FASTA.
+
+ These features are deliberately simple and biologically motivated:
+ - genome size, GC content, coding density
+ - predicted gene count and mean CDS length
+ - proteome-level amino acid composition
+ - aromatic, charged, and IVYWREL fractions (correlate with growth temperature)
+ - mean isoelectric point and hydrophobicity
+
+ The amino-acid-composition signals have well-established correlations with optimal growth
+ temperature and pH (Zeldovich 2007; Tekaia 2002), so they give XGBoost real signal to learn from
+ without any deep model.
+ """
+ from __future__ import annotations
+
+ import gzip
+ from collections import Counter
+ from collections.abc import Iterable
+ from pathlib import Path
+
+ import numpy as np
+ import pyrodigal
+ from Bio import SeqIO
+
+ AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
+ AA_AROMATIC = set("FWY")
+ AA_CHARGED_POS = set("KRH")
+ AA_CHARGED_NEG = set("DE")
+ AA_IVYWREL = set("IVYWREL")  # thermophile signature (Zeldovich 2007)
+
+ # Kyte-Doolittle hydrophobicity
+ HYDROPHOBICITY = {
+     "A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5, "F": 2.8, "G": -0.4, "H": -3.2,
+     "I": 4.5, "K": -3.9, "L": 3.8, "M": 1.9, "N": -3.5, "P": -1.6, "Q": -3.5,
+     "R": -4.5, "S": -0.8, "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
+ }
+
+ # pKa values for isoelectric point estimation (Lehninger)
+ PKA_NTERM = 9.69
+ PKA_CTERM = 2.34
+ PKA_SIDE = {"D": 3.65, "E": 4.25, "C": 8.33, "Y": 10.07, "H": 6.00, "K": 10.53, "R": 12.48}
+
+
+ def read_fasta_records(path: Path) -> Iterable[tuple[str, str]]:
+     opener = gzip.open if str(path).endswith(".gz") else open
+     with opener(path, "rt") as handle:
+         for record in SeqIO.parse(handle, "fasta"):
+             yield record.id, str(record.seq).upper()
+
+
+ def predict_proteins(contigs: Iterable[tuple[str, str]]) -> tuple[list[str], int]:
+     """Run Pyrodigal in meta mode and return predicted protein sequences + total nucleotides scanned."""
+     finder = pyrodigal.GeneFinder(meta=True)
+     proteins: list[str] = []
+     total_nt = 0
+     for _name, seq in contigs:
+         total_nt += len(seq)
+         # Pyrodigal accepts bytes; uppercase string works too in recent versions
+         genes = finder.find_genes(seq.encode("ascii"))
+         for gene in genes:
+             proteins.append(gene.translate().rstrip("*"))
+     return proteins, total_nt
+
+
+ def aa_composition(proteins: list[str]) -> dict[str, float]:
+     counts: Counter[str] = Counter()
+     total = 0
+     for p in proteins:
+         counts.update(p)
+         total += len(p)
+     if total == 0:
+         return {f"aa_frac_{a}": 0.0 for a in AA_ALPHABET}
+     return {f"aa_frac_{a}": counts.get(a, 0) / total for a in AA_ALPHABET}
+
+
+ def _isoelectric_point(seq: str) -> float:
+     """Bisection over pH to find the point where net charge is zero."""
+     if not seq:
+         return 7.0
+     counts = Counter(seq)
+     lo, hi = 0.0, 14.0
+     for _ in range(50):
+         ph = (lo + hi) / 2
+         net = (
+             1 / (1 + 10 ** (ph - PKA_NTERM))
+             - 1 / (1 + 10 ** (PKA_CTERM - ph))
+             + counts.get("K", 0) / (1 + 10 ** (ph - PKA_SIDE["K"]))
+             + counts.get("R", 0) / (1 + 10 ** (ph - PKA_SIDE["R"]))
+             + counts.get("H", 0) / (1 + 10 ** (ph - PKA_SIDE["H"]))
+             - counts.get("D", 0) / (1 + 10 ** (PKA_SIDE["D"] - ph))
+             - counts.get("E", 0) / (1 + 10 ** (PKA_SIDE["E"] - ph))
+             - counts.get("C", 0) / (1 + 10 ** (PKA_SIDE["C"] - ph))
+             - counts.get("Y", 0) / (1 + 10 ** (PKA_SIDE["Y"] - ph))
+         )
+         if net > 0:
+             lo = ph
+         else:
+             hi = ph
+     return (lo + hi) / 2
+
+
+ def extract_features(fasta_path: Path) -> dict[str, float]:
+     contigs = list(read_fasta_records(fasta_path))
+     nt_total = sum(len(s) for _, s in contigs)
+     gc = sum(s.count("G") + s.count("C") for _, s in contigs)
+     gc_frac = gc / nt_total if nt_total else 0.0
+
+     proteins, _ = predict_proteins(contigs)
+     aa_total = sum(len(p) for p in proteins)
+     coding_density = (3 * aa_total) / nt_total if nt_total else 0.0
+
+     composition = aa_composition(proteins)
+
+     aromatic = sum(composition[f"aa_frac_{a}"] for a in AA_AROMATIC)
+     pos_charged = sum(composition[f"aa_frac_{a}"] for a in AA_CHARGED_POS)
+     neg_charged = sum(composition[f"aa_frac_{a}"] for a in AA_CHARGED_NEG)
+     ivywrel = sum(composition[f"aa_frac_{a}"] for a in AA_IVYWREL)
+
+     hydrophobicity = (
+         sum(composition[f"aa_frac_{a}"] * HYDROPHOBICITY[a] for a in AA_ALPHABET)
+         if proteins else 0.0
+     )
+
+     pi_values = [_isoelectric_point(p) for p in proteins[:1000]]  # cap at 1k proteins for speed
+     mean_pi = float(np.mean(pi_values)) if pi_values else 7.0
+
+     cds_lengths = [len(p) for p in proteins]
+     return {
+         "genome_size_nt": float(nt_total),
+         "n_contigs": float(len(contigs)),
+         "gc_content": gc_frac,
+         "n_predicted_cds": float(len(proteins)),
+         "coding_density": coding_density,
+         "mean_cds_aa_length": float(np.mean(cds_lengths)) if cds_lengths else 0.0,
+         "median_cds_aa_length": float(np.median(cds_lengths)) if cds_lengths else 0.0,
+         "aromatic_frac": aromatic,
+         "pos_charged_frac": pos_charged,
+         "neg_charged_frac": neg_charged,
+         "ivywrel_frac": ivywrel,
+         "mean_hydrophobicity": hydrophobicity,
+         "mean_isoelectric_point": mean_pi,
+         **composition,
+     }
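
To make the composition features concrete, here is a small standalone sketch that computes the IVYWREL fraction and a bisection-based isoelectric point for toy peptides. It follows the same formulas and pKa values as `genome.py` but is simplified to the standard library with no gene prediction; the function names and the example peptides are illustrative assumptions.

```python
from collections import Counter

IVYWREL = set("IVYWREL")
PKA_NTERM, PKA_CTERM = 9.69, 2.34
PKA_POS = {"K": 10.53, "R": 12.48, "H": 6.00}               # groups that carry +1 when protonated
PKA_NEG = {"D": 3.65, "E": 4.25, "C": 8.33, "Y": 10.07}     # groups that carry -1 when deprotonated


def ivywrel_fraction(proteins):
    """Share of residues across all proteins that are I, V, Y, W, R, E or L."""
    total = sum(len(p) for p in proteins)
    if total == 0:
        return 0.0
    hits = sum(1 for p in proteins for aa in p if aa in IVYWREL)
    return hits / total


def net_charge(seq, ph):
    """Henderson-Hasselbalch net charge of a peptide at the given pH."""
    counts = Counter(seq)
    charge = 1 / (1 + 10 ** (ph - PKA_NTERM)) - 1 / (1 + 10 ** (PKA_CTERM - ph))
    for aa, pka in PKA_POS.items():
        charge += counts[aa] / (1 + 10 ** (ph - pka))
    for aa, pka in PKA_NEG.items():
        charge -= counts[aa] / (1 + 10 ** (pka - ph))
    return charge


def isoelectric_point(seq, iterations=50):
    """Bisection on pH in [0, 14]; net charge decreases monotonically with pH."""
    lo, hi = 0.0, 14.0
    for _ in range(iterations):
        mid = (lo + hi) / 2
        if net_charge(seq, mid) > 0:
            lo = mid  # still positively charged, so the pI lies at higher pH
        else:
            hi = mid
    return (lo + hi) / 2


frac = ivywrel_fraction(["MKVLE", "AAGG"])   # V, L and E count: 3 of 9 residues
pi_basic = isoelectric_point("KKKK")         # lysine-rich peptide, basic pI
pi_acidic = isoelectric_point("DDDD")        # aspartate-rich peptide, acidic pI
```

Bisection works here because every term in the net charge falls as pH rises, so the charge crosses zero exactly once inside the search interval.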
src/microbe_model/train/__init__.py ADDED
File without changes
src/microbe_model/train/baseline.py ADDED
@@ -0,0 +1,144 @@
1
+ """Multi-task XGBoost baseline.
2
+
3
+ One model per phenotype target, evaluated with group K-fold by taxonomic family to prevent
4
+ leakage from closely-related strains. This is the v0 "what's the floor on tabular performance"
5
+ sanity check before we invest in transformers.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from dataclasses import dataclass, field
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import xgboost as xgb
16
+ from sklearn.metrics import f1_score, mean_absolute_error
17
+ from sklearn.model_selection import GroupKFold
18
+ from sklearn.preprocessing import LabelEncoder
19
+
20
+ from microbe_model import config
21
+
22
+
23
+ @dataclass
24
+ class FoldResult:
25
+ target: str
26
+ task: str
27
+ metric_name: str
28
+ value: float
29
+ n_train: int
30
+ n_test: int
31
+
32
+
33
+ @dataclass
34
+ class TargetResult:
35
+ target: str
36
+ task: str
37
+ folds: list[FoldResult] = field(default_factory=list)
38
+ importances: dict[str, float] = field(default_factory=dict)
39
+
40
+ def mean(self) -> float:
41
+ return float(np.mean([f.value for f in self.folds])) if self.folds else float("nan")
42
+
43
+
44
+ def _select_xy(df: pd.DataFrame, target: str, feature_cols: list[str]) -> tuple[pd.DataFrame, pd.Series]:
45
+ mask = df[target].notna()
46
+ return df.loc[mask, feature_cols], df.loc[mask, target]
47
+
48
+
+ def train_target(
+     df: pd.DataFrame,
+     target: str,
+     task: str,
+     feature_cols: list[str],
+     group_col: str = "family",
+     n_splits: int = 5,
+ ) -> TargetResult:
+     """Cross-validate one XGBoost model for `target`, grouping folds by `group_col`."""
+     X, y = _select_xy(df, target, feature_cols)
+     groups = df.loc[X.index, group_col].fillna("__unknown__")
+     n_unique_groups = groups.nunique()
+     # GroupKFold needs at least two groups; with fewer there is no leakage-free split.
+     if len(X) < n_splits * 2 or n_unique_groups < 2:
+         return TargetResult(target=target, task=task)
+
+     if task == "classification":
+         encoder = LabelEncoder()
+         y_enc = encoder.fit_transform(y.astype(str))
+     else:
+         y_enc = y.to_numpy(dtype=float)
+
+     # Never ask for more folds than there are groups.
+     splits = min(n_splits, n_unique_groups)
+     kfold = GroupKFold(n_splits=splits)
+
+     result = TargetResult(target=target, task=task)
+     importance_acc = np.zeros(len(feature_cols), dtype=float)
+     fold_count = 0
+
+     for tr_idx, te_idx in kfold.split(X, y_enc, groups):
+         if task == "classification":
+             # A fold's training set can miss some classes, and XGBClassifier expects
+             # contiguous labels 0..k-1, so re-encode on the training labels only.
+             fold_encoder = LabelEncoder().fit(y_enc[tr_idx])
+             if len(fold_encoder.classes_) < 2:
+                 continue
+             model = xgb.XGBClassifier(
+                 n_estimators=300,
+                 max_depth=5,
+                 learning_rate=0.05,
+                 tree_method="hist",
+                 n_jobs=-1,
+                 eval_metric="mlogloss",
+             )
+             model.fit(X.iloc[tr_idx], fold_encoder.transform(y_enc[tr_idx]))
+             # Map predictions back to the global label space before scoring.
+             preds = fold_encoder.inverse_transform(model.predict(X.iloc[te_idx]))
+             score = f1_score(y_enc[te_idx], preds, average="macro")
+             metric = "f1_macro"
+         else:
+             model = xgb.XGBRegressor(
+                 n_estimators=500,
+                 max_depth=5,
+                 learning_rate=0.05,
+                 tree_method="hist",
+                 n_jobs=-1,
+             )
+             model.fit(X.iloc[tr_idx], y_enc[tr_idx])
+             preds = model.predict(X.iloc[te_idx])
+             score = mean_absolute_error(y_enc[te_idx], preds)
+             metric = "mae"
+
+         result.folds.append(FoldResult(
+             target=target,
+             task=task,
+             metric_name=metric,
+             value=float(score),
+             n_train=int(len(tr_idx)),
+             n_test=int(len(te_idx)),
+         ))
+         importance_acc += model.feature_importances_
+         fold_count += 1
+
+     # Average feature importances over the folds that actually trained.
+     if fold_count:
+         importance_acc /= fold_count
+         result.importances = dict(zip(feature_cols, importance_acc.tolist(), strict=True))
+     return result
+
+
+ def train_all(df: pd.DataFrame, feature_cols: list[str]) -> dict[str, TargetResult]:
+     results: dict[str, TargetResult] = {}
+     for target, task in config.PHENOTYPE_TARGETS.items():
+         if target not in df.columns:
+             continue
+         results[target] = train_target(df, target, task, feature_cols)
+     return results
+
+
+ def save_results(results: dict[str, TargetResult], path: Path) -> None:
+     payload = {
+         target: {
+             "task": r.task,
+             "mean_metric": r.mean(),
+             "folds": [f.__dict__ for f in r.folds],
+             "top_features": dict(
+                 sorted(r.importances.items(), key=lambda kv: kv[1], reverse=True)[:20]
+             ),
+         }
+         for target, r in results.items()
+     }
+     path.write_text(json.dumps(payload, indent=2))
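The module docstring above says folds are grouped by taxonomic family so that closely related strains never sit on both sides of a split. The idea can be sketched without any dependencies. This is a simplified, stdlib-only stand-in, not scikit-learn's `GroupKFold`, which the module actually uses; the helper name `grouped_folds` and the toy family labels are assumptions made for illustration.

```python
from collections import Counter


def grouped_folds(groups: list[str], n_splits: int) -> list[tuple[list[int], list[int]]]:
    """Return (train_idx, test_idx) pairs where no group spans both sides.

    Hypothetical helper for illustration only; not part of microbe_model.
    """
    sizes = Counter(groups)
    if n_splits < 2 or n_splits > len(sizes):
        raise ValueError("need 2 <= n_splits <= number of distinct groups")

    # Greedy balancing: place the largest groups first, always into the fold
    # that currently holds the fewest samples.
    fold_load = [0] * n_splits
    fold_of_group: dict[str, int] = {}
    for group, size in sorted(sizes.items(), key=lambda kv: (-kv[1], kv[0])):
        lightest = fold_load.index(min(fold_load))
        fold_of_group[group] = lightest
        fold_load[lightest] += size

    folds = []
    for k in range(n_splits):
        test_idx = [i for i, g in enumerate(groups) if fold_of_group[g] == k]
        train_idx = [i for i, g in enumerate(groups) if fold_of_group[g] != k]
        folds.append((train_idx, test_idx))
    return folds


# Toy data: six strains spread over three families.
families = ["Bacillaceae", "Bacillaceae", "Bacillaceae",
            "Vibrionaceae", "Vibrionaceae", "Halomonadaceae"]
folds = grouped_folds(families, n_splits=3)
for train_idx, test_idx in folds:
    train_fams = {families[i] for i in train_idx}
    test_fams = {families[i] for i in test_idx}
    # The whole point: a family is either trained on or tested on, never both.
    assert train_fams.isdisjoint(test_fams)
```

Because whole families move together, the number of folds cannot exceed the number of distinct families. This is why a target whose labelled rows all belong to one family cannot be evaluated this way.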
tests/__init__.py ADDED
File without changes
tests/test_features.py ADDED
@@ -0,0 +1,43 @@
+ """Smoke test the feature extractor on a tiny synthetic genome."""
+ from __future__ import annotations
+
+ import gzip
+ from pathlib import Path
+
+ from microbe_model.features.genome import extract_features
+
+
+ def _write_fake_genome(path: Path) -> None:
+     """Write a tiny FASTA with two contigs of synthetic GC-balanced sequence."""
+     contigs = [
+         (">contig_1\n" + ("ATGCGTACGTAGCTAGCTAGCATGCGTACG" * 200) + "\n"),
+         (">contig_2\n" + ("CGTACGATCGATCGTACGTAGCTACGATGC" * 200) + "\n"),
+     ]
+     with gzip.open(path, "wt") as fh:
+         fh.write("".join(contigs))
+
+
+ def test_extract_features_runs(tmp_path: Path) -> None:
+     fasta = tmp_path / "fake.fna.gz"
+     _write_fake_genome(fasta)
+
+     feats = extract_features(fasta)
+
+     assert feats["genome_size_nt"] > 0
+     assert 0 <= feats["gc_content"] <= 1
+     assert feats["n_contigs"] == 2
+     assert feats["n_predicted_cds"] >= 0  # synthetic seq may have no real ORFs
+
+     # Amino acid fractions should sum to ~1 if any proteins were found, else 0.
+     aa_total = sum(v for k, v in feats.items() if k.startswith("aa_frac_"))
+     assert aa_total == 0.0 or abs(aa_total - 1.0) < 1e-6
+
+
+ def test_isoelectric_point_in_range() -> None:
+     from microbe_model.features.genome import _isoelectric_point
+
+     assert 0 <= _isoelectric_point("AAAAA") <= 14
+     assert 0 <= _isoelectric_point("DDDDD") <= 14
+     assert 0 <= _isoelectric_point("KKKKK") <= 14
+     # Acidic protein should have lower pI than basic
+     assert _isoelectric_point("DDDDD") < _isoelectric_point("KKKKK")
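The last test asserts that an acidic peptide has a lower pI than a basic one. A rough sketch of why that ordering holds is below. It assumes a Henderson-Hasselbalch charge model solved by bisection; the real `_isoelectric_point` in `genome.py` may work differently. The pKa values are approximate (published scales differ), and `net_charge` and `isoelectric_point_sketch` are hypothetical names used only here.

```python
# Approximate side-chain and terminal pKa values (assumed; published scales differ).
POSITIVE_PKA = {"K": 10.8, "R": 12.5, "H": 6.5}
NEGATIVE_PKA = {"D": 3.9, "E": 4.1, "C": 8.5, "Y": 10.1}
N_TERM_PKA = 8.6
C_TERM_PKA = 3.6


def net_charge(seq: str, ph: float) -> float:
    """Net charge of a peptide at a given pH under Henderson-Hasselbalch."""
    # Free amino terminus contributes positive charge, free carboxyl negative.
    charge = 1.0 / (1.0 + 10 ** (ph - N_TERM_PKA))
    charge -= 1.0 / (1.0 + 10 ** (C_TERM_PKA - ph))
    for aa in seq:
        if aa in POSITIVE_PKA:
            charge += 1.0 / (1.0 + 10 ** (ph - POSITIVE_PKA[aa]))
        elif aa in NEGATIVE_PKA:
            charge -= 1.0 / (1.0 + 10 ** (NEGATIVE_PKA[aa] - ph))
    return charge


def isoelectric_point_sketch(seq: str, tol: float = 1e-4) -> float:
    """Bisect on pH in [0, 14] for the point where net charge crosses zero."""
    lo, hi = 0.0, 14.0
    # Net charge falls monotonically as pH rises, so bisection is safe.
    while hi - lo > tol:
        mid = (lo + hi) / 2.0
        if net_charge(seq, mid) > 0:
            lo = mid  # still positive: the zero crossing is at higher pH
        else:
            hi = mid
    return (lo + hi) / 2.0


acidic = isoelectric_point_sketch("DDDDD")
neutral = isoelectric_point_sketch("AAAAA")
basic = isoelectric_point_sketch("KKKKK")
assert acidic < neutral < basic
```

Aspartate side chains lose their protons at low pH, so the peptide reaches zero net charge early. Lysine side chains stay protonated until high pH, which pushes the crossing point up. Under this model the ordering in the test follows directly.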
uv.lock ADDED
The diff for this file is too large to render. See raw diff