# microbe-model/scripts/40_extract_marker_sequences_local.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Extract HMM-gated marker protein sequences locally.
This is the local CPU fallback for ``scripts/36_extract_marker_sequences.py`` when
Modal is unavailable. It emits the same JSONL schema expected by
``scripts/39_predict_hybrid.py``:
{
"bacdive_id": 1000000000,
"genome_accession": "GCA_...",
"by_category": {"oxygen": ["M..."], ...},
"category_counts": {"oxygen": 3, ...}
}
Example for the 5,000 uncultured UI genomes:
PYTHONPATH=src uv run --python 3.11 python scripts/40_extract_marker_sequences_local.py \
--input-path data/gtdb_candidates.parquet \
--id-col "" \
--accession-col genome_accession \
--fetch-accession-col ncbi_assembly_accession_versioned \
--out-path data/uncultured_marker_sequences.jsonl \
--workers 6
"""
from __future__ import annotations
import argparse
import io
import json
import os
import time
import zipfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Any
import pandas as pd
DATASETS_URL = "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"
VERSION_FALLBACKS = (".1", ".2", ".3", ".4")
EMPTY_ZIP_BYTES = 2_000
EVALUE_THRESHOLD = 1e-5
MAX_PROTEIN_LEN = 1022
MARKER_TO_CATEGORY: dict[str, str] = {
"Hsp70_DnaK": "temperature",
"Hsp90": "temperature",
"Cpn60_GroEL": "temperature",
"Hsp20": "temperature",
"CSD_cold_shock": "temperature",
"TGS_thermosome": "temperature",
"ATP_synth_alphabeta": "ph",
"ATP_synth_alphabeta_C": "ph",
"ATP_synth_F0_B": "ph",
"NhaA_Na_H_exch": "ph",
"NhaB_Na_H_exch": "ph",
"Pyridoxal_decarbox": "ph",
"MotA_TolQ_ExbB": "ph",
"V_ATPase_subH_N": "ph",
"COX1_aerobic": "oxygen",
"COX2_TM_aerobic": "oxygen",
"COX2_periplasm_aero": "oxygen",
"Cyt_CBB3_microaero": "oxygen",
"Rieske_2Fe2S": "oxygen",
"Catalase": "oxygen",
"SOD_FeMn": "oxygen",
"SOD_CuZn": "oxygen",
"FeFe_hyd_anaerobic": "oxygen",
"NiFe_hyd_anaerobic": "oxygen",
"FAD_binding_FrdA": "oxygen",
"Fer4_FeS_4Fe4S": "oxygen",
"KdpD_osmosensor": "salt",
"TrkH_K_channel": "salt",
"BCCT_compatible": "salt",
"BPD_transp_1": "salt",
"EctC_ectoine_synth": "salt",
"Bact_rhodopsin": "salt",
"TP_methylase_B12": "vitamin",
"Peripla_BP_2": "vitamin",
"THF_DHG_CYH_folate": "vitamin",
"FolB_folate": "vitamin",
"PdxJ_pyridoxine": "vitamin",
"DHBP_riboflavin": "vitamin",
"NifH_nitrogenase": "nitrogen",
"NifDK_nitrogenase": "nitrogen",
"NIR_SIR_ferredoxin": "nitrogen",
"RuBisCO_large_form1": "carbon",
"RuBisCO_small_form1": "carbon",
"Alpha_amylase": "carbon",
"Cellulase_GH5": "carbon",
"CBM_cellulose": "carbon",
"Molybdopterin_OR": "special",
"UvrD_helicase_C": "special",
}
CATEGORIES = ["temperature", "ph", "oxygen", "salt", "vitamin", "nitrogen", "carbon", "special"]
_HMM_CACHE: list[Any] | None = None
_ALPHABET: Any | None = None
def _has_version(accession: str) -> bool:
return "." in accession and accession.rsplit(".", 1)[-1].isdigit()
def _candidate_accessions(accession: str) -> list[str]:
    """List the accessions to try when downloading *accession*.

    A versioned accession is returned alone; an unversioned one is expanded
    with every suffix in ``VERSION_FALLBACKS``, in that order.
    """
    if not _has_version(accession):
        return [f"{accession}{suffix}" for suffix in VERSION_FALLBACKS]
    return [accession]
def _parse_fasta(raw: bytes) -> list[tuple[str, str]]:
contigs: list[tuple[str, str]] = []
name = None
chunks: list[str] = []
for line in raw.decode("utf-8", errors="ignore").splitlines():
if line.startswith(">"):
if name is not None:
contigs.append((name, "".join(chunks).upper()))
name = line[1:].split()[0]
chunks = []
else:
chunks.append(line.strip())
if name is not None:
contigs.append((name, "".join(chunks).upper()))
return contigs
def _fetch_fasta_bytes(accession: str) -> list[tuple[str, str]] | None:
    """Download a genome FASTA for *accession* and return its parsed contigs.

    Despite the name, the return value is the output of ``_parse_fasta`` (a list
    of ``(name, sequence)`` tuples), not raw bytes. Each accession produced by
    ``_candidate_accessions`` is tried in order, with up to three HTTP attempts
    per candidate.

    Returns:
        The parsed contigs from the first ``.fna`` member of the first usable
        ZIP response, or ``None`` when no candidate yields one.
    """
    import requests
    headers = {"Accept": "application/zip"}
    # An optional API key is read from the environment and sent as a header.
    ncbi_key = os.environ.get("NCBI_API_KEY")
    if ncbi_key:
        headers["api-key"] = ncbi_key
    for cand in _candidate_accessions(accession):
        for attempt in range(3):
            try:
                # Pause before every request; shorter when an API key is set
                # (presumably to stay under NCBI rate limits - confirm).
                time.sleep(0.1 if ncbi_key else 0.34)
                resp = requests.get(
                    DATASETS_URL.format(acc=cand),
                    params={"include_annotation_type": "GENOME_FASTA"},
                    headers=headers,
                    timeout=120,
                )
                if resp.status_code == 404:
                    # Not found: stop retrying this candidate, move to the next.
                    break
                if resp.status_code in (429, 502, 503):
                    # Throttled or temporarily unavailable: back off 1s, 2s, 4s and retry.
                    time.sleep(2**attempt)
                    continue
                resp.raise_for_status()
            except requests.RequestException:
                # Any request error (including other HTTP error statuses) is
                # retried with backoff; the third failure abandons this candidate.
                if attempt == 2:
                    break
                time.sleep(2**attempt)
                continue
            if len(resp.content) < EMPTY_ZIP_BYTES:
                # Payload too small to be useful: abandon this candidate.
                break
            try:
                with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
                    fasta_names = [name for name in zf.namelist() if name.endswith(".fna")]
                    if not fasta_names:
                        break
                    # Only the first ``.fna`` member of the archive is parsed.
                    with zf.open(fasta_names[0]) as src:
                        return _parse_fasta(src.read())
            except zipfile.BadZipFile:
                # Response was not a valid ZIP archive: abandon this candidate.
                break
    return None
def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Predict genes on *contigs* with pyrodigal and return the translated proteins.

    Inputs totalling fewer than 100,000 characters are searched in metagenomic
    mode without training; larger inputs train the gene finder on all contigs
    first. Trailing ``*`` characters are stripped from every translation.
    Proteins are returned in contig order, then gene order within each contig.
    """
    import pyrodigal

    if not contigs:
        return []

    genome_size = sum(len(sequence) for _name, sequence in contigs)
    use_meta = genome_size < 100_000
    gene_finder = pyrodigal.GeneFinder(meta=use_meta)
    if not use_meta:
        # Single-genome mode needs a training pass over every contig.
        gene_finder.train(*(sequence.encode() for _name, sequence in contigs))

    return [
        gene.translate().rstrip("*")
        for _name, sequence in contigs
        for gene in gene_finder.find_genes(sequence.encode())
    ]
def _scan_for_markers(proteins: list[str], hmm_path: Path) -> dict[str, list[int]]:
    """Search *proteins* against the marker HMMs and report which proteins hit.

    Args:
        proteins: Protein sequences; empty strings are skipped but still
            consume an index, so returned indices refer to this list.
        hmm_path: HMM file to load. It is read only on the first call in a
            process; later calls reuse the cached profiles and ignore this path.

    Returns:
        A dict with one entry per key of ``MARKER_TO_CATEGORY`` mapping the
        marker name to the indices of proteins whose hit E-value is at most
        ``EVALUE_THRESHOLD``. Queries whose name is not a known marker are
        ignored.
    """
    import pyhmmer
    import pyhmmer.easel
    import pyhmmer.plan7

    global _ALPHABET, _HMM_CACHE

    hits_by_marker: dict[str, list[int]] = {marker: [] for marker in MARKER_TO_CATEGORY}
    if not proteins:
        return hits_by_marker

    # Lazily build the per-process alphabet and HMM profile caches.
    if _ALPHABET is None:
        _ALPHABET = pyhmmer.easel.Alphabet.amino()
    if _HMM_CACHE is None:
        with pyhmmer.plan7.HMMFile(str(hmm_path)) as handle:
            _HMM_CACHE = list(handle)

    # Sequence names encode the protein's position as "p<index>" so hits can be
    # mapped back to the input list.
    digital = [
        pyhmmer.easel.TextSequence(name=f"p{position}".encode(), sequence=protein).digitize(_ALPHABET)
        for position, protein in enumerate(proteins)
        if protein
    ]
    if not digital:
        return hits_by_marker

    for hits in pyhmmer.hmmer.hmmsearch(_HMM_CACHE, digital, E=EVALUE_THRESHOLD):
        query_name = hits.query.name
        if isinstance(query_name, bytes):
            query_name = query_name.decode()
        if query_name not in hits_by_marker:
            continue
        bucket = hits_by_marker[query_name]
        for hit in hits:
            if hit.evalue > EVALUE_THRESHOLD:
                continue
            target = hit.name.decode() if isinstance(hit.name, bytes) else hit.name
            if not target.startswith("p"):
                continue
            try:
                bucket.append(int(target[1:]))
            except ValueError:
                # Name did not follow the "p<index>" convention; ignore the hit.
                pass
    return hits_by_marker
def _extract_one(task: tuple[int, str, str, int, str]) -> dict[str, Any] | None:
    """Process one genome: download, predict proteins, scan markers, build a record.

    Args:
        task: ``(record_id, genome_accession, fetch_accession, max_per_cat,
            hmm_path_str)``. ``fetch_accession`` is the one downloaded;
            ``genome_accession`` is only echoed into the output.

    Returns:
        ``None`` when no contigs could be fetched or no proteins were predicted.
        Otherwise a dict with ``bacdive_id``, ``genome_accession``,
        ``by_category`` (per category, up to ``max_per_cat`` marker-hit proteins,
        shortest first, each truncated to ``MAX_PROTEIN_LEN``) and
        ``category_counts``.
    """
    record_id, genome_accession, fetch_accession, max_per_cat, hmm_path_str = task

    contigs = _fetch_fasta_bytes(fetch_accession)
    if not contigs:
        return None
    proteins = _predict_proteins(contigs)
    if not proteins:
        return None

    hits = _scan_for_markers(proteins, Path(hmm_path_str))

    # Pool the hit indices of all markers belonging to the same category; a
    # protein hit by several markers of one category is counted once.
    indices_by_cat: dict[str, set[int]] = {cat: set() for cat in CATEGORIES}
    for marker, protein_ids in hits.items():
        category = MARKER_TO_CATEGORY.get(marker)
        if category in indices_by_cat:
            indices_by_cat[category].update(protein_ids)

    by_category: dict[str, list[str]] = {}
    for cat in CATEGORIES:
        shortest_first = sorted(indices_by_cat[cat], key=lambda idx: len(proteins[idx]))
        by_category[cat] = [
            proteins[idx][:MAX_PROTEIN_LEN] for idx in shortest_first[:max_per_cat]
        ]

    category_counts = {cat: len(sequences) for cat, sequences in by_category.items()}
    return {
        "bacdive_id": int(record_id),
        "genome_accession": genome_accession,
        "by_category": by_category,
        "category_counts": category_counts,
    }
def _load_done(path: Path) -> tuple[set[int], set[str]]:
done_ids: set[int] = set()
done_accessions: set[str] = set()
if not path.exists():
return done_ids, done_accessions
with open(path) as fh:
for line in fh:
try:
row = json.loads(line)
done_ids.add(int(row["bacdive_id"]))
done_accessions.add(str(row["genome_accession"]))
except Exception:
continue
return done_ids, done_accessions
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the local extraction script.

    Returns:
        Namespace with ``input_path``, ``out_path``, ``id_col``,
        ``accession_col``, ``fetch_accession_col``, ``hmm_path``, ``limit``,
        ``workers`` and ``max_per_cat``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--input-path", {"type": Path, "default": Path("data/gtdb_candidates.parquet")}),
        ("--out-path", {"type": Path, "default": Path("data/uncultured_marker_sequences.jsonl")}),
        ("--id-col", {"default": ""}),
        ("--accession-col", {"default": "genome_accession"}),
        ("--fetch-accession-col", {"default": "ncbi_assembly_accession_versioned"}),
        ("--hmm-path", {"type": Path, "default": Path("data/markers/unified/unified_markers.hmm")}),
        ("--limit", {"type": int, "default": 0}),
        ("--workers", {"type": int, "default": 4}),
        ("--max-per-cat", {"type": int, "default": 16}),
    )
    parser = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Run the local marker-sequence extraction end to end.

    Reads the input parquet, resolves a record id and accession for every row
    with a non-null accession, skips rows already present in the output JSONL
    (matched by id or by accession), and processes the rest in a process pool.
    Each successful result is appended to the output file as one JSON line and
    flushed immediately so an interrupted run can be resumed.

    A task counts as failed when its worker returns ``None`` or raises. Raised
    exceptions are reported on stderr with the genome accession, so systematic
    problems (for example a missing HMM file or dependency) are visible instead
    of appearing only as a failure count.

    Raises:
        SystemExit: If the accession column, or a non-empty fetch-accession
            column, is missing from the input.
    """
    # Local import, matching the file's function-scope import style; only
    # needed for reporting worker errors.
    import sys

    args = parse_args()
    source = pd.read_parquet(args.input_path)
    if args.accession_col not in source.columns:
        raise SystemExit(f"Missing accession column: {args.accession_col}")
    if args.fetch_accession_col and args.fetch_accession_col not in source.columns:
        raise SystemExit(f"Missing fetch accession column: {args.fetch_accession_col}")

    ready = source[source[args.accession_col].notna()].copy().reset_index(drop=True)
    if args.id_col and args.id_col in ready.columns:
        ready["_record_id"] = ready[args.id_col].astype(int)
    else:
        # No usable id column: synthesize ids from the row position, offset by 1e9.
        ready["_record_id"] = ready.index + 1_000_000_000
    ready["_genome_accession"] = ready[args.accession_col].astype(str)
    if args.fetch_accession_col:
        # Fall back to the main accession where the fetch accession is null.
        ready["_fetch_accession"] = ready[args.fetch_accession_col].fillna(ready[args.accession_col]).astype(str)
    else:
        ready["_fetch_accession"] = ready["_genome_accession"]

    done_ids, done_accessions = _load_done(args.out_path)
    # A row is pending only if neither its id nor its accession was seen before.
    pending = ready[
        ~ready["_record_id"].isin(done_ids)
        & ~ready["_genome_accession"].isin(done_accessions)
    ]
    if args.limit:
        pending = pending.head(args.limit)
    tasks = [
        (
            int(row["_record_id"]),
            str(row["_genome_accession"]),
            str(row["_fetch_accession"]),
            args.max_per_cat,
            str(args.hmm_path),
        )
        for row in pending[["_record_id", "_genome_accession", "_fetch_accession"]].to_dict("records")
    ]

    args.out_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Marker-sequence local extract: {len(tasks):,} pending ({len(done_accessions):,} cached)")
    print(f"input_path={args.input_path}")
    print(f"out_path={args.out_path}")
    print(f"workers={args.workers} max_per_cat={args.max_per_cat}")
    if not tasks:
        return

    n_ok = 0
    n_fail = 0
    with open(args.out_path, "a") as log, ProcessPoolExecutor(max_workers=args.workers) as pool:
        futures = {pool.submit(_extract_one, task): task for task in tasks}
        for completed, future in enumerate(as_completed(futures), start=1):
            try:
                result = future.result()
            except Exception as exc:
                # Keep the run going, but surface the cause instead of
                # silently folding it into the failure count.
                failed_accession = futures[future][1]
                print(
                    f"  worker error for {failed_accession}: {type(exc).__name__}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
                result = None
            if result is None:
                n_fail += 1
            else:
                log.write(json.dumps(result) + "\n")
                # Flush per record so completed work survives an interruption.
                log.flush()
                n_ok += 1
            if completed % 25 == 0 or completed == len(tasks):
                print(f" {completed:,}/{len(tasks):,} complete ok={n_ok:,} fail={n_fail:,}", flush=True)
    print(f"Finished. {n_ok:,} succeeded, {n_fail:,} failed.")
# Run the extraction only when this file is executed as a script.
if __name__ == "__main__":
    main()