Spaces:
Running
Scaffold v0: BacDive + NCBI ingestion, genome feature extractor, XGBoost baseline
Browse filesSets up the data + training pipeline for predicting cultivation conditions
(optimal T, pH, oxygen requirement, salt tolerance) from genome sequence:
- src/microbe_model/data/bacdive.py — BacDive REST client + phenotype extraction
- src/microbe_model/data/ncbi.py — NCBI Datasets v2 genome fetcher
- src/microbe_model/features/genome.py — pyrodigal CDS prediction + amino-acid
composition features (IVYWREL, hydrophobicity, isoelectric point — all
biologically motivated for the targets we predict)
- src/microbe_model/train/baseline.py — multi-task XGBoost with group K-fold
by family to prevent leakage from closely related strains
- scripts/01..04 — runnable pipeline entry points
- tests/test_features.py — smoke test on synthetic FASTA, passes
No trained model yet. Real BacDive ingestion needs BACDIVE_USER/PASSWORD env vars.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- .env.example +7 -0
- .gitattributes +2 -0
- .gitignore +39 -0
- .python-version +1 -0
- README.md +93 -0
- pyproject.toml +42 -0
- scripts/01_fetch_bacdive.py +46 -0
- scripts/02_fetch_genomes.py +40 -0
- scripts/03_extract_features.py +46 -0
- scripts/04_train_baseline.py +45 -0
- src/microbe_model/__init__.py +1 -0
- src/microbe_model/config.py +31 -0
- src/microbe_model/data/__init__.py +0 -0
- src/microbe_model/data/bacdive.py +182 -0
- src/microbe_model/data/ncbi.py +61 -0
- src/microbe_model/features/__init__.py +0 -0
- src/microbe_model/features/genome.py +143 -0
- src/microbe_model/train/__init__.py +0 -0
- src/microbe_model/train/baseline.py +144 -0
- tests/__init__.py +0 -0
- tests/test_features.py +43 -0
- uv.lock +0 -0
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BacDive API credentials — register at https://bacdive.dsmz.de/
|
| 2 |
+
BACDIVE_USER=
|
| 3 |
+
BACDIVE_PASSWORD=
|
| 4 |
+
|
| 5 |
+
# NCBI API key — optional, raises rate limit from 3 req/s to 10 req/s
|
| 6 |
+
# Get one at https://www.ncbi.nlm.nih.gov/account/settings/
|
| 7 |
+
NCBI_API_KEY=
|
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.ubj filter=lfs diff=lfs merge=lfs -text
|
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.egg-info/
|
| 6 |
+
.pytest_cache/
|
| 7 |
+
.ruff_cache/
|
| 8 |
+
.mypy_cache/
|
| 9 |
+
|
| 10 |
+
# Virtual env
|
| 11 |
+
.venv/
|
| 12 |
+
venv/
|
| 13 |
+
.env
|
| 14 |
+
.env.local
|
| 15 |
+
|
| 16 |
+
# Data / artifacts (large files — kept out of git)
|
| 17 |
+
/data/
|
| 18 |
+
/artifacts/
|
| 19 |
+
/models/
|
| 20 |
+
*.parquet
|
| 21 |
+
*.fna
|
| 22 |
+
*.fna.gz
|
| 23 |
+
*.faa
|
| 24 |
+
*.gbff
|
| 25 |
+
*.gbff.gz
|
| 26 |
+
|
| 27 |
+
# Notebooks
|
| 28 |
+
.ipynb_checkpoints/
|
| 29 |
+
notebooks/scratch/
|
| 30 |
+
|
| 31 |
+
# Editor
|
| 32 |
+
.vscode/
|
| 33 |
+
.idea/
|
| 34 |
+
*.swp
|
| 35 |
+
.DS_Store
|
| 36 |
+
|
| 37 |
+
# Agent / tool state
|
| 38 |
+
.claude/
|
| 39 |
+
.letta/
|
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# microbe-model
|
| 2 |
+
|
| 3 |
+
Predict cultivation conditions (optimal temperature, pH, oxygen requirement, salt tolerance) for
|
| 4 |
+
microbial isolates from genome sequence alone. The long-term aim is to lower the cost of culturing
|
| 5 |
+
"microbial dark matter" — the >99% of microbial diversity that has not yet been grown in pure culture.
|
| 6 |
+
|
| 7 |
+
## Status
|
| 8 |
+
|
| 9 |
+
v0 — scaffolding the data pipeline + a non-deep-learning baseline. No trained model yet.
|
| 10 |
+
|
| 11 |
+
## Approach
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
BacDive (phenotype labels) ──┐
|
| 15 |
+
├──> joined table (strain, genome_accession, phenotypes)
|
| 16 |
+
GTDB / NCBI (genomes) ───────┘
|
| 17 |
+
│
|
| 18 |
+
▼
|
| 19 |
+
feature extraction
|
| 20 |
+
(genome statistics, codon usage,
|
| 21 |
+
proteome-level amino acid stats)
|
| 22 |
+
│
|
| 23 |
+
▼
|
| 24 |
+
XGBoost multi-task baseline
|
| 25 |
+
(group K-fold by family)
|
| 26 |
+
│
|
| 27 |
+
▼
|
| 28 |
+
eval report (MAE, F1, importances)
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
The genome→phenotype features used here have well-established correlations with the target
|
| 32 |
+
properties (e.g. proteome amino acid composition correlates with optimal growth temperature),
|
| 33 |
+
so even a tabular model has a real signal to learn from. The point of the v0 is to establish
|
| 34 |
+
a ceiling before investing in transformer-based approaches.
|
| 35 |
+
|
| 36 |
+
## Setup
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# Requires Python 3.11 and uv (https://docs.astral.sh/uv/)
|
| 40 |
+
uv sync --all-extras
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Running the pipeline
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# 1. Pull strain metadata + phenotype labels from BacDive
|
| 47 |
+
# (requires BACDIVE_USER and BACDIVE_PASSWORD env vars — register at bacdive.dsmz.de)
|
| 48 |
+
uv run python scripts/01_fetch_bacdive.py --limit 1000
|
| 49 |
+
|
| 50 |
+
# 2. Download genomes for strains that have an accession
|
| 51 |
+
uv run python scripts/02_fetch_genomes.py
|
| 52 |
+
|
| 53 |
+
# 3. Extract genome-level features (CDS prediction + amino acid stats)
|
| 54 |
+
uv run python scripts/03_extract_features.py
|
| 55 |
+
|
| 56 |
+
# 4. Train multi-task XGBoost baseline
|
| 57 |
+
uv run python scripts/04_train_baseline.py
|
| 58 |
+
|
| 59 |
+
# 5. Render eval report
|
| 60 |
+
uv run python scripts/05_eval.py
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Layout
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
src/microbe_model/
|
| 67 |
+
config.py # paths, constants
|
| 68 |
+
data/
|
| 69 |
+
bacdive.py # BacDive REST API client
|
| 70 |
+
ncbi.py # NCBI genome fetcher (Datasets API)
|
| 71 |
+
features/
|
| 72 |
+
genome.py # gene prediction + tabular feature extraction
|
| 73 |
+
train/
|
| 74 |
+
baseline.py # multi-task XGBoost + group K-fold eval
|
| 75 |
+
scripts/ # runnable entry points (numbered by pipeline order)
|
| 76 |
+
tests/ # smoke tests on small fixtures
|
| 77 |
+
data/ # (gitignored) cached API responses, genomes, parquet tables
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## What this is *not* yet
|
| 81 |
+
|
| 82 |
+
- Not a foundation model. No transformer. No genome language model.
|
| 83 |
+
- Not a platform. There is no upload UI or active-learning loop.
|
| 84 |
+
- Not validated against held-out organisms. The eval scaffolding exists; the data does not.
|
| 85 |
+
|
| 86 |
+
These are deliberate v0 boundaries. See the project notes for the longer-term plan.
|
| 87 |
+
|
| 88 |
+
## Environment variables
|
| 89 |
+
|
| 90 |
+
Copy `.env.example` to `.env` and fill in:
|
| 91 |
+
|
| 92 |
+
- `BACDIVE_USER`, `BACDIVE_PASSWORD` — required for BacDive API access (free registration).
|
| 93 |
+
- `NCBI_API_KEY` — optional, raises NCBI rate limit from 3 req/s to 10 req/s.
|
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "microbe-model"
|
| 3 |
+
version = "0.0.1"
|
| 4 |
+
description = "Predict cultivation conditions for uncultured microbes from genome sequence."
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"biopython>=1.83",
|
| 9 |
+
"pyrodigal>=3.5",
|
| 10 |
+
"numpy>=1.26",
|
| 11 |
+
"pandas>=2.2",
|
| 12 |
+
"pyarrow>=15",
|
| 13 |
+
"scikit-learn>=1.4",
|
| 14 |
+
"xgboost>=2.0",
|
| 15 |
+
"requests>=2.32",
|
| 16 |
+
"tqdm>=4.66",
|
| 17 |
+
"python-dotenv>=1.0",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[project.optional-dependencies]
|
| 21 |
+
dev = [
|
| 22 |
+
"pytest>=8.0",
|
| 23 |
+
"ruff>=0.4",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[build-system]
|
| 27 |
+
requires = ["hatchling"]
|
| 28 |
+
build-backend = "hatchling.build"
|
| 29 |
+
|
| 30 |
+
[tool.hatch.build.targets.wheel]
|
| 31 |
+
packages = ["src/microbe_model"]
|
| 32 |
+
|
| 33 |
+
[tool.ruff]
|
| 34 |
+
line-length = 100
|
| 35 |
+
target-version = "py311"
|
| 36 |
+
|
| 37 |
+
[tool.ruff.lint]
|
| 38 |
+
select = ["E", "F", "W", "I", "UP", "B", "SIM"]
|
| 39 |
+
ignore = ["E501"]
|
| 40 |
+
|
| 41 |
+
[tool.pytest.ini_options]
|
| 42 |
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pull strain metadata + phenotype labels from BacDive.
|
| 2 |
+
|
| 3 |
+
Writes one JSON per strain to data/bacdive/, plus a consolidated parquet table at
|
| 4 |
+
data/bacdive_phenotypes.parquet.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
uv run python scripts/01_fetch_bacdive.py --limit 1000
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from microbe_model import config
|
| 17 |
+
from microbe_model.data.bacdive import (
|
| 18 |
+
BacDiveClient,
|
| 19 |
+
extract_phenotypes,
|
| 20 |
+
fetch_with_cache,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main() -> None:
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument("--limit", type=int, default=1000, help="Max strains to fetch (None=all).")
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
client = BacDiveClient()
|
| 30 |
+
rows = []
|
| 31 |
+
for bacdive_id in tqdm(client.iter_strain_ids(limit=args.limit), desc="BacDive", unit="strain"):
|
| 32 |
+
record = fetch_with_cache(client, bacdive_id)
|
| 33 |
+
rows.append(extract_phenotypes(record))
|
| 34 |
+
|
| 35 |
+
df = pd.DataFrame(rows)
|
| 36 |
+
out = config.DATA / "bacdive_phenotypes.parquet"
|
| 37 |
+
df.to_parquet(out, index=False)
|
| 38 |
+
print(f"\nWrote {len(df)} rows to {out}")
|
| 39 |
+
print("Coverage of prediction targets:")
|
| 40 |
+
for col in ("optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"):
|
| 41 |
+
print(f" {col}: {df[col].notna().sum()} / {len(df)}")
|
| 42 |
+
print(f" genome_accession: {df['genome_accession'].notna().sum()} / {len(df)}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download genome FASTAs for every BacDive strain that has an accession.
|
| 2 |
+
|
| 3 |
+
Skips strains already cached on disk. Run after 01_fetch_bacdive.py.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from microbe_model import config
|
| 11 |
+
from microbe_model.data.ncbi import GenomeNotFound, fetch_genome
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main() -> None:
|
| 15 |
+
table = config.DATA / "bacdive_phenotypes.parquet"
|
| 16 |
+
if not table.exists():
|
| 17 |
+
raise SystemExit(f"Missing {table}. Run scripts/01_fetch_bacdive.py first.")
|
| 18 |
+
|
| 19 |
+
df = pd.read_parquet(table)
|
| 20 |
+
accessions = df["genome_accession"].dropna().unique().tolist()
|
| 21 |
+
print(f"{len(accessions)} unique genome accessions to fetch.")
|
| 22 |
+
|
| 23 |
+
failed: list[tuple[str, str]] = []
|
| 24 |
+
for acc in tqdm(accessions, desc="NCBI", unit="genome"):
|
| 25 |
+
try:
|
| 26 |
+
fetch_genome(acc)
|
| 27 |
+
except GenomeNotFound:
|
| 28 |
+
failed.append((acc, "not_found"))
|
| 29 |
+
except Exception as exc: # noqa: BLE001 — log and continue, don't kill the batch
|
| 30 |
+
failed.append((acc, type(exc).__name__))
|
| 31 |
+
|
| 32 |
+
print(f"\nDownloaded: {len(accessions) - len(failed)} / {len(accessions)}")
|
| 33 |
+
if failed:
|
| 34 |
+
log = config.DATA / "genome_fetch_failures.tsv"
|
| 35 |
+
pd.DataFrame(failed, columns=["accession", "error"]).to_csv(log, sep="\t", index=False)
|
| 36 |
+
print(f"Failures logged to {log}")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
main()
|
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Extract tabular genome features for every cached genome.
|
| 2 |
+
|
| 3 |
+
Reads the BacDive phenotype table + the cached FASTAs in data/genomes/, runs Pyrodigal +
|
| 4 |
+
amino-acid-composition feature extraction, and writes data/features.parquet (one row per strain).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from microbe_model import config
|
| 12 |
+
from microbe_model.data.ncbi import genome_path
|
| 13 |
+
from microbe_model.features.genome import extract_features
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main() -> None:
|
| 17 |
+
pheno_path = config.DATA / "bacdive_phenotypes.parquet"
|
| 18 |
+
if not pheno_path.exists():
|
| 19 |
+
raise SystemExit(f"Missing {pheno_path}. Run 01 then 02 first.")
|
| 20 |
+
pheno = pd.read_parquet(pheno_path)
|
| 21 |
+
|
| 22 |
+
rows = []
|
| 23 |
+
for _, row in tqdm(pheno.iterrows(), total=len(pheno), desc="features"):
|
| 24 |
+
accession = row["genome_accession"]
|
| 25 |
+
if not accession:
|
| 26 |
+
continue
|
| 27 |
+
path = genome_path(accession)
|
| 28 |
+
if not path.exists():
|
| 29 |
+
continue
|
| 30 |
+
try:
|
| 31 |
+
feats = extract_features(path)
|
| 32 |
+
except Exception as exc: # noqa: BLE001 — bad FASTA shouldn't kill the run
|
| 33 |
+
print(f" skip {accession}: {type(exc).__name__}: {exc}")
|
| 34 |
+
continue
|
| 35 |
+
feats["bacdive_id"] = row["bacdive_id"]
|
| 36 |
+
feats["genome_accession"] = accession
|
| 37 |
+
rows.append(feats)
|
| 38 |
+
|
| 39 |
+
feats_df = pd.DataFrame(rows)
|
| 40 |
+
out = config.DATA / "features.parquet"
|
| 41 |
+
feats_df.to_parquet(out, index=False)
|
| 42 |
+
print(f"\nWrote {len(feats_df)} rows to {out}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train the multi-task XGBoost baseline.
|
| 2 |
+
|
| 3 |
+
Joins phenotypes + features, derives a `family` column from `species` for group K-fold,
|
| 4 |
+
and writes per-target metrics to artifacts/baseline_results.json.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from microbe_model import config
|
| 11 |
+
from microbe_model.train.baseline import save_results, train_all
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def derive_family(species: str | None) -> str:
|
| 15 |
+
"""Crude family proxy: first word of binomial. Replace with GTDB lookup later."""
|
| 16 |
+
if not species:
|
| 17 |
+
return "__unknown__"
|
| 18 |
+
return str(species).split()[0]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main() -> None:
|
| 22 |
+
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
|
| 23 |
+
feats = pd.read_parquet(config.DATA / "features.parquet")
|
| 24 |
+
df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 25 |
+
df["family"] = df["species"].apply(derive_family)
|
| 26 |
+
|
| 27 |
+
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 28 |
+
print(f"Training on {len(df)} strains × {len(feature_cols)} features.")
|
| 29 |
+
print(f"Group counts (top 10): {df['family'].value_counts().head(10).to_dict()}")
|
| 30 |
+
|
| 31 |
+
results = train_all(df, feature_cols)
|
| 32 |
+
|
| 33 |
+
out = config.ARTIFACTS / "baseline_results.json"
|
| 34 |
+
save_results(results, out)
|
| 35 |
+
print(f"\nWrote results to {out}\n")
|
| 36 |
+
for target, r in results.items():
|
| 37 |
+
if r.folds:
|
| 38 |
+
metric = r.folds[0].metric_name
|
| 39 |
+
print(f" {target:25s} {metric:10s} = {r.mean():.4f} (n_folds={len(r.folds)})")
|
| 40 |
+
else:
|
| 41 |
+
print(f" {target:25s} skipped (insufficient data)")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.0.1"
|
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Project paths and shared constants."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
ROOT = Path(__file__).resolve().parents[2]
|
| 12 |
+
DATA = ROOT / "data"
|
| 13 |
+
ARTIFACTS = ROOT / "artifacts"
|
| 14 |
+
|
| 15 |
+
BACDIVE_DIR = DATA / "bacdive"
|
| 16 |
+
GENOME_DIR = DATA / "genomes"
|
| 17 |
+
FEATURE_DIR = DATA / "features"
|
| 18 |
+
|
| 19 |
+
for _d in (DATA, ARTIFACTS, BACDIVE_DIR, GENOME_DIR, FEATURE_DIR):
|
| 20 |
+
_d.mkdir(parents=True, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
BACDIVE_USER = os.environ.get("BACDIVE_USER")
|
| 23 |
+
BACDIVE_PASSWORD = os.environ.get("BACDIVE_PASSWORD")
|
| 24 |
+
NCBI_API_KEY = os.environ.get("NCBI_API_KEY")
|
| 25 |
+
|
| 26 |
+
PHENOTYPE_TARGETS = {
|
| 27 |
+
"optimal_temperature_c": "regression",
|
| 28 |
+
"optimal_ph": "regression",
|
| 29 |
+
"oxygen_requirement": "classification",
|
| 30 |
+
"salt_tolerance_pct": "regression",
|
| 31 |
+
}
|
|
File without changes
|
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BacDive REST API client.
|
| 2 |
+
|
| 3 |
+
BacDive (https://bacdive.dsmz.de/) is the largest curated database of bacterial phenotypes.
|
| 4 |
+
Free registration is required; credentials are read from BACDIVE_USER / BACDIVE_PASSWORD.
|
| 5 |
+
|
| 6 |
+
This client does the minimum needed for v0:
|
| 7 |
+
- log in and obtain an OAuth token
|
| 8 |
+
- paginate through the strain catalog
|
| 9 |
+
- fetch full records by BacDive ID
|
| 10 |
+
- extract the phenotype targets we predict (T_opt, pH_opt, oxygen, salt)
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
from collections.abc import Iterator
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
|
| 22 |
+
from microbe_model import config
|
| 23 |
+
|
| 24 |
+
BASE_URL = "https://api.bacdive.dsmz.de"
|
| 25 |
+
TOKEN_URL = "https://sso.dsmz.de/auth/realms/dsmz/protocol/openid-connect/token"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BacDiveAuthError(RuntimeError):
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BacDiveClient:
|
| 33 |
+
def __init__(self, user: str | None = None, password: str | None = None) -> None:
|
| 34 |
+
self.user = user or config.BACDIVE_USER
|
| 35 |
+
self.password = password or config.BACDIVE_PASSWORD
|
| 36 |
+
if not self.user or not self.password:
|
| 37 |
+
raise BacDiveAuthError(
|
| 38 |
+
"Set BACDIVE_USER and BACDIVE_PASSWORD in .env (register at bacdive.dsmz.de)."
|
| 39 |
+
)
|
| 40 |
+
self._token: str | None = None
|
| 41 |
+
self._token_expires_at: float = 0.0
|
| 42 |
+
self._session = requests.Session()
|
| 43 |
+
|
| 44 |
+
def _refresh_token(self) -> None:
|
| 45 |
+
resp = self._session.post(
|
| 46 |
+
TOKEN_URL,
|
| 47 |
+
data={
|
| 48 |
+
"grant_type": "password",
|
| 49 |
+
"client_id": "api.bacdive.public",
|
| 50 |
+
"username": self.user,
|
| 51 |
+
"password": self.password,
|
| 52 |
+
},
|
| 53 |
+
timeout=30,
|
| 54 |
+
)
|
| 55 |
+
if resp.status_code != 200:
|
| 56 |
+
raise BacDiveAuthError(f"BacDive auth failed: {resp.status_code} {resp.text}")
|
| 57 |
+
body = resp.json()
|
| 58 |
+
self._token = body["access_token"]
|
| 59 |
+
self._token_expires_at = time.time() + body.get("expires_in", 300) - 30
|
| 60 |
+
|
| 61 |
+
def _headers(self) -> dict[str, str]:
|
| 62 |
+
if self._token is None or time.time() >= self._token_expires_at:
|
| 63 |
+
self._refresh_token()
|
| 64 |
+
return {"Authorization": f"Bearer {self._token}", "Accept": "application/json"}
|
| 65 |
+
|
| 66 |
+
def _get(self, path: str, params: dict | None = None) -> dict[str, Any]:
|
| 67 |
+
url = f"{BASE_URL}{path}"
|
| 68 |
+
for attempt in range(3):
|
| 69 |
+
resp = self._session.get(url, headers=self._headers(), params=params, timeout=60)
|
| 70 |
+
if resp.status_code == 429:
|
| 71 |
+
time.sleep(2 ** attempt)
|
| 72 |
+
continue
|
| 73 |
+
resp.raise_for_status()
|
| 74 |
+
return resp.json()
|
| 75 |
+
resp.raise_for_status()
|
| 76 |
+
return {}
|
| 77 |
+
|
| 78 |
+
def iter_strain_ids(self, limit: int | None = None) -> Iterator[int]:
|
| 79 |
+
"""Page through the BacDive catalog and yield strain IDs."""
|
| 80 |
+
page_url: str | None = "/fetch/"
|
| 81 |
+
seen = 0
|
| 82 |
+
while page_url:
|
| 83 |
+
body = self._get(page_url)
|
| 84 |
+
for record in body.get("results", []):
|
| 85 |
+
yield int(record["id"])
|
| 86 |
+
seen += 1
|
| 87 |
+
if limit is not None and seen >= limit:
|
| 88 |
+
return
|
| 89 |
+
next_url = body.get("next")
|
| 90 |
+
if not next_url:
|
| 91 |
+
return
|
| 92 |
+
page_url = next_url.replace(BASE_URL, "")
|
| 93 |
+
|
| 94 |
+
def fetch_record(self, bacdive_id: int) -> dict[str, Any]:
|
| 95 |
+
body = self._get(f"/fetch/{bacdive_id}")
|
| 96 |
+
results = body.get("results") or {}
|
| 97 |
+
if isinstance(results, list):
|
| 98 |
+
return results[0] if results else {}
|
| 99 |
+
if isinstance(results, dict) and str(bacdive_id) in results:
|
| 100 |
+
return results[str(bacdive_id)]
|
| 101 |
+
return results
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def extract_phenotypes(record: dict[str, Any]) -> dict[str, Any]:
|
| 105 |
+
"""Pull the v0 prediction targets out of a BacDive record.
|
| 106 |
+
|
| 107 |
+
BacDive's record schema is deeply nested and field names vary across record versions.
|
| 108 |
+
We tolerate missing fields — anything we can't find becomes None and is dropped at training time.
|
| 109 |
+
"""
|
| 110 |
+
out: dict[str, Any] = {
|
| 111 |
+
"bacdive_id": record.get("General", {}).get("BacDive-ID"),
|
| 112 |
+
"species": record.get("Name and taxonomic classification", {}).get("species"),
|
| 113 |
+
"ncbi_taxon_id": record.get("General", {}).get("NCBI tax id"),
|
| 114 |
+
"optimal_temperature_c": None,
|
| 115 |
+
"optimal_ph": None,
|
| 116 |
+
"oxygen_requirement": None,
|
| 117 |
+
"salt_tolerance_pct": None,
|
| 118 |
+
"genome_accession": None,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
culture = record.get("Culture and growth conditions", {})
|
| 122 |
+
temps = _as_list(culture.get("culture temp"))
|
| 123 |
+
for t in temps:
|
| 124 |
+
if isinstance(t, dict) and t.get("type", "").lower() in {"optimum", "optimal"}:
|
| 125 |
+
out["optimal_temperature_c"] = _to_float(t.get("temperature"))
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
phs = _as_list(culture.get("culture pH"))
|
| 129 |
+
for p in phs:
|
| 130 |
+
if isinstance(p, dict) and p.get("type", "").lower() in {"optimum", "optimal"}:
|
| 131 |
+
out["optimal_ph"] = _to_float(p.get("pH"))
|
| 132 |
+
break
|
| 133 |
+
|
| 134 |
+
physio = record.get("Physiology and metabolism", {})
|
| 135 |
+
oxygen = _as_list(physio.get("oxygen tolerance"))
|
| 136 |
+
if oxygen and isinstance(oxygen[0], dict):
|
| 137 |
+
out["oxygen_requirement"] = oxygen[0].get("oxygen tolerance")
|
| 138 |
+
|
| 139 |
+
salt = _as_list(physio.get("halophily"))
|
| 140 |
+
for s in salt:
|
| 141 |
+
if isinstance(s, dict) and "concentration" in s:
|
| 142 |
+
out["salt_tolerance_pct"] = _to_float(s.get("concentration"))
|
| 143 |
+
break
|
| 144 |
+
|
| 145 |
+
seq = record.get("Sequence information", {})
|
| 146 |
+
genomes = _as_list(seq.get("genome sequence"))
|
| 147 |
+
for g in genomes:
|
| 148 |
+
if isinstance(g, dict) and g.get("accession"):
|
| 149 |
+
out["genome_accession"] = g["accession"]
|
| 150 |
+
break
|
| 151 |
+
|
| 152 |
+
return out
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _as_list(x: Any) -> list:
|
| 156 |
+
if x is None:
|
| 157 |
+
return []
|
| 158 |
+
if isinstance(x, list):
|
| 159 |
+
return x
|
| 160 |
+
return [x]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _to_float(x: Any) -> float | None:
|
| 164 |
+
if x is None:
|
| 165 |
+
return None
|
| 166 |
+
try:
|
| 167 |
+
return float(str(x).split()[0])
|
| 168 |
+
except (ValueError, AttributeError):
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def cache_path(bacdive_id: int) -> Path:
|
| 173 |
+
return config.BACDIVE_DIR / f"{bacdive_id}.json"
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def fetch_with_cache(client: BacDiveClient, bacdive_id: int) -> dict[str, Any]:
|
| 177 |
+
path = cache_path(bacdive_id)
|
| 178 |
+
if path.exists():
|
| 179 |
+
return json.loads(path.read_text())
|
| 180 |
+
record = client.fetch_record(bacdive_id)
|
| 181 |
+
path.write_text(json.dumps(record))
|
| 182 |
+
return record
|
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""NCBI genome fetcher.
|
| 2 |
+
|
| 3 |
+
Uses the NCBI Datasets v2 REST API to download a single nucleotide FASTA per accession.
|
| 4 |
+
This API doesn't require auth, but providing NCBI_API_KEY raises the rate limit from
|
| 5 |
+
3 req/s to 10 req/s.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import gzip
|
| 10 |
+
import io
|
| 11 |
+
import time
|
| 12 |
+
import zipfile
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
+
|
| 17 |
+
from microbe_model import config
|
| 18 |
+
|
| 19 |
+
DATASETS_BASE = "https://api.ncbi.nlm.nih.gov/datasets/v2"
|
| 20 |
+
RATE_LIMIT_S = 0.1 if config.NCBI_API_KEY else 0.34
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class GenomeNotFound(RuntimeError):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def genome_path(accession: str) -> Path:
|
| 28 |
+
return config.GENOME_DIR / f"{accession}.fna.gz"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def fetch_genome(accession: str, *, force: bool = False) -> Path:
|
| 32 |
+
"""Download a genome FASTA for the given assembly accession (e.g. GCF_000005845.2).
|
| 33 |
+
|
| 34 |
+
The Datasets API returns a zip; we extract the FASTA, gzip it, and write to disk.
|
| 35 |
+
Idempotent — returns immediately if the file is already cached.
|
| 36 |
+
"""
|
| 37 |
+
out = genome_path(accession)
|
| 38 |
+
if out.exists() and not force:
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
url = f"{DATASETS_BASE}/genome/accession/{accession}/download"
|
| 42 |
+
params = {"include_annotation_type": "GENOME_FASTA"}
|
| 43 |
+
headers = {"Accept": "application/zip"}
|
| 44 |
+
if config.NCBI_API_KEY:
|
| 45 |
+
headers["api-key"] = config.NCBI_API_KEY
|
| 46 |
+
|
| 47 |
+
time.sleep(RATE_LIMIT_S)
|
| 48 |
+
resp = requests.get(url, params=params, headers=headers, timeout=120, stream=True)
|
| 49 |
+
if resp.status_code == 404:
|
| 50 |
+
raise GenomeNotFound(accession)
|
| 51 |
+
resp.raise_for_status()
|
| 52 |
+
|
| 53 |
+
buf = io.BytesIO(resp.content)
|
| 54 |
+
with zipfile.ZipFile(buf) as zf:
|
| 55 |
+
fasta_names = [n for n in zf.namelist() if n.endswith(".fna")]
|
| 56 |
+
if not fasta_names:
|
| 57 |
+
raise GenomeNotFound(f"{accession} (no .fna in archive)")
|
| 58 |
+
with zf.open(fasta_names[0]) as src, gzip.open(out, "wb") as dst:
|
| 59 |
+
for chunk in iter(lambda: src.read(1 << 16), b""):
|
| 60 |
+
dst.write(chunk)
|
| 61 |
+
return out
|
|
File without changes
|
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tabular feature extraction from a microbial genome FASTA.
|
| 2 |
+
|
| 3 |
+
These features are deliberately simple and biologically motivated:
|
| 4 |
+
- genome size, GC content, coding density
|
| 5 |
+
- predicted gene count and mean CDS length
|
| 6 |
+
- proteome-level amino acid composition
|
| 7 |
+
- aromatic, charged, and IVYWREL fractions (correlate with growth temperature)
|
| 8 |
+
- mean isoelectric point and hydrophobicity
|
| 9 |
+
|
| 10 |
+
The amino-acid-composition signals have well-established correlations with optimal growth
|
| 11 |
+
temperature and pH (Zeldovich 2007; Tekaia 2002), so they give XGBoost real signal to learn from
|
| 12 |
+
without any deep model.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import gzip
|
| 17 |
+
from collections import Counter
|
| 18 |
+
from collections.abc import Iterable
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pyrodigal
|
| 23 |
+
from Bio import SeqIO
|
| 24 |
+
|
| 25 |
+
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
|
| 26 |
+
AA_AROMATIC = set("FWY")
|
| 27 |
+
AA_CHARGED_POS = set("KRH")
|
| 28 |
+
AA_CHARGED_NEG = set("DE")
|
| 29 |
+
AA_IVYWREL = set("IVYWREL") # thermophile signature (Zeldovich 2007)
|
| 30 |
+
|
| 31 |
+
# Kyte-Doolittle hydrophobicity
|
| 32 |
+
HYDROPHOBICITY = {
|
| 33 |
+
"A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5, "F": 2.8, "G": -0.4, "H": -3.2,
|
| 34 |
+
"I": 4.5, "K": -3.9, "L": 3.8, "M": 1.9, "N": -3.5, "P": -1.6, "Q": -3.5,
|
| 35 |
+
"R": -4.5, "S": -0.8, "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
# pKa values for isoelectric point estimation (Lehninger)
|
| 39 |
+
PKA_NTERM = 9.69
|
| 40 |
+
PKA_CTERM = 2.34
|
| 41 |
+
PKA_SIDE = {"D": 3.65, "E": 4.25, "C": 8.33, "Y": 10.07, "H": 6.00, "K": 10.53, "R": 12.48}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def read_fasta_records(path: Path) -> Iterable[tuple[str, str]]:
|
| 45 |
+
opener = gzip.open if str(path).endswith(".gz") else open
|
| 46 |
+
with opener(path, "rt") as handle:
|
| 47 |
+
for record in SeqIO.parse(handle, "fasta"):
|
| 48 |
+
yield record.id, str(record.seq).upper()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def predict_proteins(contigs: Iterable[tuple[str, str]]) -> tuple[list[str], int]:
|
| 52 |
+
"""Run Pyrodigal in meta mode and return predicted protein sequences + total nucleotides scanned."""
|
| 53 |
+
finder = pyrodigal.GeneFinder(meta=True)
|
| 54 |
+
proteins: list[str] = []
|
| 55 |
+
total_nt = 0
|
| 56 |
+
for _name, seq in contigs:
|
| 57 |
+
total_nt += len(seq)
|
| 58 |
+
# Pyrodigal accepts bytes; uppercase string works too in recent versions
|
| 59 |
+
genes = finder.find_genes(seq.encode("ascii"))
|
| 60 |
+
for gene in genes:
|
| 61 |
+
proteins.append(gene.translate().rstrip("*"))
|
| 62 |
+
return proteins, total_nt
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def aa_composition(proteins: list[str]) -> dict[str, float]:
|
| 66 |
+
counts: Counter[str] = Counter()
|
| 67 |
+
total = 0
|
| 68 |
+
for p in proteins:
|
| 69 |
+
counts.update(p)
|
| 70 |
+
total += len(p)
|
| 71 |
+
if total == 0:
|
| 72 |
+
return {f"aa_frac_{a}": 0.0 for a in AA_ALPHABET}
|
| 73 |
+
return {f"aa_frac_{a}": counts.get(a, 0) / total for a in AA_ALPHABET}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _isoelectric_point(seq: str) -> float:
|
| 77 |
+
"""Bisection over pH to find the point where net charge is zero."""
|
| 78 |
+
if not seq:
|
| 79 |
+
return 7.0
|
| 80 |
+
counts = Counter(seq)
|
| 81 |
+
lo, hi = 0.0, 14.0
|
| 82 |
+
for _ in range(50):
|
| 83 |
+
ph = (lo + hi) / 2
|
| 84 |
+
net = (
|
| 85 |
+
1 / (1 + 10 ** (ph - PKA_NTERM))
|
| 86 |
+
- 1 / (1 + 10 ** (PKA_CTERM - ph))
|
| 87 |
+
+ counts.get("K", 0) / (1 + 10 ** (ph - PKA_SIDE["K"]))
|
| 88 |
+
+ counts.get("R", 0) / (1 + 10 ** (ph - PKA_SIDE["R"]))
|
| 89 |
+
+ counts.get("H", 0) / (1 + 10 ** (ph - PKA_SIDE["H"]))
|
| 90 |
+
- counts.get("D", 0) / (1 + 10 ** (PKA_SIDE["D"] - ph))
|
| 91 |
+
- counts.get("E", 0) / (1 + 10 ** (PKA_SIDE["E"] - ph))
|
| 92 |
+
- counts.get("C", 0) / (1 + 10 ** (PKA_SIDE["C"] - ph))
|
| 93 |
+
- counts.get("Y", 0) / (1 + 10 ** (PKA_SIDE["Y"] - ph))
|
| 94 |
+
)
|
| 95 |
+
if net > 0:
|
| 96 |
+
lo = ph
|
| 97 |
+
else:
|
| 98 |
+
hi = ph
|
| 99 |
+
return (lo + hi) / 2
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def extract_features(fasta_path: Path) -> dict[str, float]:
|
| 103 |
+
contigs = list(read_fasta_records(fasta_path))
|
| 104 |
+
nt_total = sum(len(s) for _, s in contigs)
|
| 105 |
+
gc = sum(s.count("G") + s.count("C") for _, s in contigs)
|
| 106 |
+
gc_frac = gc / nt_total if nt_total else 0.0
|
| 107 |
+
|
| 108 |
+
proteins, _ = predict_proteins(contigs)
|
| 109 |
+
aa_total = sum(len(p) for p in proteins)
|
| 110 |
+
coding_density = (3 * aa_total) / nt_total if nt_total else 0.0
|
| 111 |
+
|
| 112 |
+
composition = aa_composition(proteins)
|
| 113 |
+
|
| 114 |
+
aromatic = sum(composition[f"aa_frac_{a}"] for a in AA_AROMATIC)
|
| 115 |
+
pos_charged = sum(composition[f"aa_frac_{a}"] for a in AA_CHARGED_POS)
|
| 116 |
+
neg_charged = sum(composition[f"aa_frac_{a}"] for a in AA_CHARGED_NEG)
|
| 117 |
+
ivywrel = sum(composition[f"aa_frac_{a}"] for a in AA_IVYWREL)
|
| 118 |
+
|
| 119 |
+
hydrophobicity = (
|
| 120 |
+
sum(composition[f"aa_frac_{a}"] * HYDROPHOBICITY[a] for a in AA_ALPHABET)
|
| 121 |
+
if proteins else 0.0
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
pi_values = [_isoelectric_point(p) for p in proteins[:1000]] # cap at 1k proteins for speed
|
| 125 |
+
mean_pi = float(np.mean(pi_values)) if pi_values else 7.0
|
| 126 |
+
|
| 127 |
+
cds_lengths = [len(p) for p in proteins]
|
| 128 |
+
return {
|
| 129 |
+
"genome_size_nt": float(nt_total),
|
| 130 |
+
"n_contigs": float(len(contigs)),
|
| 131 |
+
"gc_content": gc_frac,
|
| 132 |
+
"n_predicted_cds": float(len(proteins)),
|
| 133 |
+
"coding_density": coding_density,
|
| 134 |
+
"mean_cds_aa_length": float(np.mean(cds_lengths)) if cds_lengths else 0.0,
|
| 135 |
+
"median_cds_aa_length": float(np.median(cds_lengths)) if cds_lengths else 0.0,
|
| 136 |
+
"aromatic_frac": aromatic,
|
| 137 |
+
"pos_charged_frac": pos_charged,
|
| 138 |
+
"neg_charged_frac": neg_charged,
|
| 139 |
+
"ivywrel_frac": ivywrel,
|
| 140 |
+
"mean_hydrophobicity": hydrophobicity,
|
| 141 |
+
"mean_isoelectric_point": mean_pi,
|
| 142 |
+
**composition,
|
| 143 |
+
}
|
|
File without changes
|
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-task XGBoost baseline.
|
| 2 |
+
|
| 3 |
+
One model per phenotype target, evaluated with group K-fold by taxonomic family to prevent
|
| 4 |
+
leakage from closely-related strains. This is the v0 "what's the floor on tabular performance"
|
| 5 |
+
sanity check before we invest in transformers.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import xgboost as xgb
|
| 16 |
+
from sklearn.metrics import f1_score, mean_absolute_error
|
| 17 |
+
from sklearn.model_selection import GroupKFold
|
| 18 |
+
from sklearn.preprocessing import LabelEncoder
|
| 19 |
+
|
| 20 |
+
from microbe_model import config
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class FoldResult:
|
| 25 |
+
target: str
|
| 26 |
+
task: str
|
| 27 |
+
metric_name: str
|
| 28 |
+
value: float
|
| 29 |
+
n_train: int
|
| 30 |
+
n_test: int
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class TargetResult:
|
| 35 |
+
target: str
|
| 36 |
+
task: str
|
| 37 |
+
folds: list[FoldResult] = field(default_factory=list)
|
| 38 |
+
importances: dict[str, float] = field(default_factory=dict)
|
| 39 |
+
|
| 40 |
+
def mean(self) -> float:
|
| 41 |
+
return float(np.mean([f.value for f in self.folds])) if self.folds else float("nan")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _select_xy(df: pd.DataFrame, target: str, feature_cols: list[str]) -> tuple[pd.DataFrame, pd.Series]:
|
| 45 |
+
mask = df[target].notna()
|
| 46 |
+
return df.loc[mask, feature_cols], df.loc[mask, target]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def train_target(
|
| 50 |
+
df: pd.DataFrame,
|
| 51 |
+
target: str,
|
| 52 |
+
task: str,
|
| 53 |
+
feature_cols: list[str],
|
| 54 |
+
group_col: str = "family",
|
| 55 |
+
n_splits: int = 5,
|
| 56 |
+
) -> TargetResult:
|
| 57 |
+
X, y = _select_xy(df, target, feature_cols)
|
| 58 |
+
groups = df.loc[X.index, group_col].fillna("__unknown__")
|
| 59 |
+
if len(X) < n_splits * 2:
|
| 60 |
+
return TargetResult(target=target, task=task)
|
| 61 |
+
|
| 62 |
+
if task == "classification":
|
| 63 |
+
encoder = LabelEncoder()
|
| 64 |
+
y_enc = encoder.fit_transform(y.astype(str))
|
| 65 |
+
else:
|
| 66 |
+
y_enc = y.to_numpy(dtype=float)
|
| 67 |
+
|
| 68 |
+
n_unique_groups = groups.nunique()
|
| 69 |
+
splits = min(n_splits, max(2, n_unique_groups))
|
| 70 |
+
kfold = GroupKFold(n_splits=splits)
|
| 71 |
+
|
| 72 |
+
result = TargetResult(target=target, task=task)
|
| 73 |
+
importance_acc = np.zeros(len(feature_cols), dtype=float)
|
| 74 |
+
fold_count = 0
|
| 75 |
+
|
| 76 |
+
for tr_idx, te_idx in kfold.split(X, y_enc, groups):
|
| 77 |
+
if task == "classification":
|
| 78 |
+
n_classes = len(np.unique(y_enc[tr_idx]))
|
| 79 |
+
if n_classes < 2:
|
| 80 |
+
continue
|
| 81 |
+
model = xgb.XGBClassifier(
|
| 82 |
+
n_estimators=300,
|
| 83 |
+
max_depth=5,
|
| 84 |
+
learning_rate=0.05,
|
| 85 |
+
tree_method="hist",
|
| 86 |
+
n_jobs=-1,
|
| 87 |
+
eval_metric="mlogloss",
|
| 88 |
+
)
|
| 89 |
+
model.fit(X.iloc[tr_idx], y_enc[tr_idx])
|
| 90 |
+
preds = model.predict(X.iloc[te_idx])
|
| 91 |
+
score = f1_score(y_enc[te_idx], preds, average="macro")
|
| 92 |
+
metric = "f1_macro"
|
| 93 |
+
else:
|
| 94 |
+
model = xgb.XGBRegressor(
|
| 95 |
+
n_estimators=500,
|
| 96 |
+
max_depth=5,
|
| 97 |
+
learning_rate=0.05,
|
| 98 |
+
tree_method="hist",
|
| 99 |
+
n_jobs=-1,
|
| 100 |
+
)
|
| 101 |
+
model.fit(X.iloc[tr_idx], y_enc[tr_idx])
|
| 102 |
+
preds = model.predict(X.iloc[te_idx])
|
| 103 |
+
score = mean_absolute_error(y_enc[te_idx], preds)
|
| 104 |
+
metric = "mae"
|
| 105 |
+
|
| 106 |
+
result.folds.append(FoldResult(
|
| 107 |
+
target=target,
|
| 108 |
+
task=task,
|
| 109 |
+
metric_name=metric,
|
| 110 |
+
value=float(score),
|
| 111 |
+
n_train=int(len(tr_idx)),
|
| 112 |
+
n_test=int(len(te_idx)),
|
| 113 |
+
))
|
| 114 |
+
importance_acc += model.feature_importances_
|
| 115 |
+
fold_count += 1
|
| 116 |
+
|
| 117 |
+
if fold_count:
|
| 118 |
+
importance_acc /= fold_count
|
| 119 |
+
result.importances = dict(zip(feature_cols, importance_acc.tolist(), strict=True))
|
| 120 |
+
return result
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def train_all(df: pd.DataFrame, feature_cols: list[str]) -> dict[str, TargetResult]:
|
| 124 |
+
results: dict[str, TargetResult] = {}
|
| 125 |
+
for target, task in config.PHENOTYPE_TARGETS.items():
|
| 126 |
+
if target not in df.columns:
|
| 127 |
+
continue
|
| 128 |
+
results[target] = train_target(df, target, task, feature_cols)
|
| 129 |
+
return results
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def save_results(results: dict[str, TargetResult], path: Path) -> None:
|
| 133 |
+
payload = {
|
| 134 |
+
target: {
|
| 135 |
+
"task": r.task,
|
| 136 |
+
"mean_metric": r.mean(),
|
| 137 |
+
"folds": [f.__dict__ for f in r.folds],
|
| 138 |
+
"top_features": dict(
|
| 139 |
+
sorted(r.importances.items(), key=lambda kv: kv[1], reverse=True)[:20]
|
| 140 |
+
),
|
| 141 |
+
}
|
| 142 |
+
for target, r in results.items()
|
| 143 |
+
}
|
| 144 |
+
path.write_text(json.dumps(payload, indent=2))
|
|
File without changes
|
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test the feature extractor on a tiny synthetic genome."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import gzip
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from microbe_model.features.genome import extract_features
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _write_fake_genome(path: Path) -> None:
|
| 11 |
+
"""Write a tiny FASTA with two contigs of synthetic GC-balanced sequence."""
|
| 12 |
+
contigs = [
|
| 13 |
+
(">contig_1\n" + ("ATGCGTACGTAGCTAGCTAGCATGCGTACG" * 200) + "\n"),
|
| 14 |
+
(">contig_2\n" + ("CGTACGATCGATCGTACGTAGCTACGATGC" * 200) + "\n"),
|
| 15 |
+
]
|
| 16 |
+
with gzip.open(path, "wt") as fh:
|
| 17 |
+
fh.write("".join(contigs))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_extract_features_runs(tmp_path: Path) -> None:
|
| 21 |
+
fasta = tmp_path / "fake.fna.gz"
|
| 22 |
+
_write_fake_genome(fasta)
|
| 23 |
+
|
| 24 |
+
feats = extract_features(fasta)
|
| 25 |
+
|
| 26 |
+
assert feats["genome_size_nt"] > 0
|
| 27 |
+
assert 0 <= feats["gc_content"] <= 1
|
| 28 |
+
assert feats["n_contigs"] == 2
|
| 29 |
+
assert feats["n_predicted_cds"] >= 0 # synthetic seq may have no real ORFs
|
| 30 |
+
|
| 31 |
+
# Amino acid fractions should sum to ~1 if any proteins were found, else 0.
|
| 32 |
+
aa_total = sum(v for k, v in feats.items() if k.startswith("aa_frac_"))
|
| 33 |
+
assert aa_total == 0.0 or abs(aa_total - 1.0) < 1e-6
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_isoelectric_point_in_range() -> None:
|
| 37 |
+
from microbe_model.features.genome import _isoelectric_point
|
| 38 |
+
|
| 39 |
+
assert 0 <= _isoelectric_point("AAAAA") <= 14
|
| 40 |
+
assert 0 <= _isoelectric_point("DDDDD") <= 14
|
| 41 |
+
assert 0 <= _isoelectric_point("KKKKK") <= 14
|
| 42 |
+
# Acidic protein should have lower pI than basic
|
| 43 |
+
assert _isoelectric_point("DDDDD") < _isoelectric_point("KKKKK")
|
|
The diff for this file is too large to render.
See raw diff
|
|
|