Miyu Horiuchi
Fix _derive_salt to pick optimum entries — salt MAE 2.52 → 2.17 (-13.7%)
56b0c4e
"""BacDive REST API client (v2, public).
The BacDive v2 API is fully open as of February 2026 — no registration, no auth.
Documentation: https://api.bacdive.dsmz.de/
We discover strain IDs by scanning the integer ID space in semicolon-batched fetches
of up to 100 IDs per call. Missing IDs are silently dropped server-side, so a blind
scan over [1, MAX_ID] yields every existing record in one pass. At ~150K live IDs
(as of 2026-04), this takes ~30 minutes single-threaded.
"""
from __future__ import annotations
import json
import time
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import requests
from microbe_model import config
BASE_URL = "https://api.bacdive.dsmz.de/v2"
BATCH_SIZE = 100 # max IDs per /fetch/ call (server limit)
DEFAULT_MAX_ID = 200_000 # conservative upper bound; live max is ~160K-180K as of 2026-04
class BacDiveClient:
def __init__(self, *, request_timeout: int = 60, retry_sleep_s: float = 1.0) -> None:
self._session = requests.Session()
self.timeout = request_timeout
self.retry_sleep_s = retry_sleep_s
def _get(self, path: str, params: dict | None = None) -> dict[str, Any]:
url = f"{BASE_URL}{path}"
for attempt in range(3):
resp = self._session.get(url, params=params, timeout=self.timeout)
if resp.status_code == 429:
time.sleep(self.retry_sleep_s * (attempt + 1))
continue
resp.raise_for_status()
return resp.json()
resp.raise_for_status()
return {}
def fetch_batch(self, ids: list[int]) -> dict[int, dict[str, Any]]:
"""Fetch up to BATCH_SIZE strain records in a single call.
Returns a {bacdive_id: record} mapping. Missing IDs are absent from the result.
"""
if not ids:
return {}
if len(ids) > BATCH_SIZE:
raise ValueError(f"Batch size {len(ids)} exceeds server limit {BATCH_SIZE}")
path = f"/fetch/{';'.join(str(i) for i in ids)}"
body = self._get(path)
results = body.get("results")
if isinstance(results, dict):
return {int(k): v for k, v in results.items()}
return {}
def iter_records(
self,
*,
start: int = 1,
end: int = DEFAULT_MAX_ID,
batch_size: int = BATCH_SIZE,
) -> Iterator[tuple[int, dict[str, Any]]]:
"""Scan the BacDive ID range and yield (id, record) for every existing strain."""
for batch_start in range(start, end + 1, batch_size):
batch_end = min(batch_start + batch_size - 1, end)
ids = list(range(batch_start, batch_end + 1))
records = self.fetch_batch(ids)
yield from sorted(records.items())
def cache_path(bacdive_id: int) -> Path:
return config.BACDIVE_DIR / f"{bacdive_id}.json"
def cache_record(bacdive_id: int, record: dict[str, Any]) -> Path:
path = cache_path(bacdive_id)
path.write_text(json.dumps(record))
return path
def load_cached(bacdive_id: int) -> dict[str, Any] | None:
path = cache_path(bacdive_id)
if not path.exists():
return None
return json.loads(path.read_text())
def extract_phenotypes(record: dict[str, Any]) -> dict[str, Any]:
"""Pull the v0 prediction targets out of a BacDive v2 record.
Field locations (verified against live API on 2026-04-26):
- General → BacDive-ID
- Name and taxonomic classification → species, genus, family
- Culture and growth conditions → culture temp[] (type ∈ {growth, optimum, range, no growth})
- Culture and growth conditions → culture pH[] (same shape)
- Physiology and metabolism → oxygen tolerance[]
- Physiology and metabolism → halophily[]
- Sequence information → Genome sequences[].INSDC accession
- Isolation, sampling and environmental information → isolation source categories[].Cat{1,2,3}
"""
general = record.get("General") or {}
taxon = record.get("Name and taxonomic classification") or {}
culture = record.get("Culture and growth conditions") or {}
physio = record.get("Physiology and metabolism") or {}
seq = record.get("Sequence information") or {}
iso = record.get("Isolation, sampling and environmental information") or {}
iso_cats = _collect_isolation_categories(iso.get("isolation source categories"))
out: dict[str, Any] = {
"bacdive_id": general.get("BacDive-ID"),
"species": taxon.get("species"),
"genus": taxon.get("genus"),
"family": (taxon.get("LPSN") or {}).get("family") or taxon.get("family"),
"ncbi_taxon_id": _first_ncbi_tax_id(general.get("NCBI tax id")),
"optimal_temperature_c": _derive_optimum(_as_list(culture.get("culture temp")), "temperature"),
"optimal_ph": _derive_optimum(_as_list(culture.get("culture pH")), "pH"),
"oxygen_requirement": _first_value(_as_list(physio.get("oxygen tolerance")), "oxygen tolerance"),
"salt_tolerance_pct": _derive_salt(physio.get("halophily")),
"genome_accession": _first_genome_accession(seq.get("Genome sequences")),
"isolation_cat1": iso_cats["cat1"],
"isolation_cat2": iso_cats["cat2"],
"isolation_cat3": iso_cats["cat3"],
}
return out
def _collect_isolation_categories(raw: Any) -> dict[str, str | None]:
"""Flatten BacDive's `isolation source categories` into 3 pipe-joined string fields.
A strain commonly has multiple parallel category descriptions (e.g., #Host=Human AND
#Host Body Product=Blood). We collect *all* unique values per level into a sorted,
pipe-joined string so downstream code can split & one-hot. The leading '#' is stripped.
"""
cats: dict[str, set[str]] = {"Cat1": set(), "Cat2": set(), "Cat3": set()}
for entry in _as_list(raw):
if not isinstance(entry, dict):
continue
for level in cats:
value = entry.get(level)
if isinstance(value, str) and value:
cats[level].add(value.lstrip("#").strip())
return {
"cat1": "|".join(sorted(cats["Cat1"])) or None,
"cat2": "|".join(sorted(cats["Cat2"])) or None,
"cat3": "|".join(sorted(cats["Cat3"])) or None,
}
def _as_list(x: Any) -> list:
if x is None:
return []
if isinstance(x, list):
return x
return [x]
def _to_float(x: Any) -> float | None:
if x is None:
return None
s = str(x).strip()
if not s:
return None
if "-" in s and not s.startswith("-"):
# e.g. "5-30" — return midpoint
parts = s.split("-")
try:
lo, hi = float(parts[0]), float(parts[1])
return (lo + hi) / 2
except (ValueError, IndexError):
return None
try:
return float(s.split()[0])
except (ValueError, AttributeError):
return None
def _derive_optimum(entries: list, value_key: str) -> float | None:
"""Find an optimum for a temperature- or pH-like list of {type, value} entries.
Preference order:
1. type == "optimum" (exact)
2. median of "positive growth" entries
3. None
"""
optima = []
growth = []
for entry in entries:
if not isinstance(entry, dict):
continue
etype = (entry.get("type") or "").lower()
value = _to_float(entry.get(value_key))
if value is None:
continue
is_positive = (entry.get("growth") or "").lower() in {"positive", "yes", "+", "true"}
if "optim" in etype:
optima.append(value)
elif etype == "growth" and is_positive:
growth.append(value)
if optima:
return sum(optima) / len(optima)
if growth:
sorted_g = sorted(growth)
n = len(sorted_g)
return sorted_g[n // 2] if n % 2 else (sorted_g[n // 2 - 1] + sorted_g[n // 2]) / 2
return None
def _first_value(entries: list, key: str) -> str | None:
for entry in entries:
if isinstance(entry, dict) and entry.get(key):
return str(entry[key])
return None
def _derive_salt(halophily: Any) -> float | None:
"""Derive optimal NaCl concentration (% w/v) from BacDive halophily entries.
Each entry has shape:
{salt: 'NaCl', growth: 'positive'|'no', tested relation: 'optimum'|'growth', concentration: '3 %'}
Preference order (mirrors _derive_optimum):
1. tested relation == 'optimum' AND growth == 'positive'
2. median of positive-growth concentrations (the strain's tolerated range)
3. None
The previous implementation returned the first parsable value, which often picked
the lowest tested concentration or a no-growth entry — overstating salt sensitivity.
"""
positive_tokens = {"positive", "yes", "+", "true"}
optima: list[float] = []
growth: list[float] = []
for entry in _as_list(halophily):
if not isinstance(entry, dict):
continue
if (entry.get("salt") or "NaCl") != "NaCl":
continue
is_positive = (entry.get("growth") or "").lower() in positive_tokens
relation = (entry.get("tested relation") or "").lower()
value = _to_float(entry.get("concentration") or entry.get("salt concentration"))
if value is None:
continue
if "optim" in relation and is_positive:
optima.append(value)
elif is_positive:
growth.append(value)
if optima:
return sum(optima) / len(optima)
if growth:
sorted_g = sorted(growth)
n = len(sorted_g)
return sorted_g[n // 2] if n % 2 else (sorted_g[n // 2 - 1] + sorted_g[n // 2]) / 2
return None
def _first_genome_accession(genome_entries: Any) -> str | None:
for entry in _as_list(genome_entries):
if isinstance(entry, dict):
for key in ("INSDC accession", "NCBI accession", "accession"):
value = entry.get(key)
if value:
return str(value)
return None
def _first_ncbi_tax_id(tax: Any) -> int | None:
for entry in _as_list(tax):
if isinstance(entry, dict):
value = entry.get("NCBI tax id")
if value is not None:
try:
return int(value)
except (ValueError, TypeError):
continue
return None