# microbe-model/scripts/36_extract_marker_sequences.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""Extract HMM-gated protein sequences per genome for LoRA fine-tuning.
This is a sibling to scripts/modal_per_marker_embed.py — same fetch+pyrodigal+pyhmmer
pipeline — but instead of mean-pooling ESM-2 embeddings, it emits the raw protein
sequences themselves, grouped by phenotype category. Those sequences become the input
to scripts/37_train_lora.py for end-to-end LoRA fine-tuning.
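Per-genome output (one JSONL line). `by_category` always has a key for every category
(possibly with an empty list); each list keeps at most --max-per-cat sequences, each truncated
to 1022 residues, and `category_counts` holds the list lengths after that cap. `bacdive_id`
is the --id-col value, or a synthetic id (position among the selected rows + 1,000,000,000)
when --id-col is empty or not a column of the input: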
{
"bacdive_id": 482,
"genome_accession": "GCF_000005845.2",
"by_category": {
"oxygen": ["MLDF...", "MFKK...", ...],
"temperature": ["MAKH...", ...],
...
},
"category_counts": {"oxygen": 12, "temperature": 8, ...}
}
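CPU-only (no ESM-2 inference). Work is spread over up to 16 concurrent Modal containers
(max_containers=16); these presumably egress from different IPs, which keeps each one under
NCBI's per-IP rate limit (3 req/s without an API key, 10 req/s with one — the key comes from
the `ncbi-key` Modal secret). Rough estimate: ~22K genomes in ~30-60 minutes of wall time for
~$2-5 of Modal compute.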
Usage:
modal run scripts/36_extract_marker_sequences.py --limit 50
modal run scripts/36_extract_marker_sequences.py --max-per-cat 16
modal run scripts/36_extract_marker_sequences.py \
--input-path data/gtdb_candidates.parquet \
--id-col "" \
--accession-col genome_accession \
--fetch-accession-col ncbi_assembly_accession_versioned \
--require-label 0 \
--out-path data/uncultured_marker_sequences.jsonl
"""
from __future__ import annotations
import json
from pathlib import Path
import modal
# Worker image: slim Debian + Python 3.11 with the CPU-only gene-calling (pyrodigal) and
# HMM-search (pyhmmer) stack, requests for NCBI downloads, and the unified marker HMM
# database copied into the container at /root/markers.hmm (read in MarkerSeqExtractor.setup).
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "pyrodigal>=3.5",
    "pyhmmer>=0.12",
    "requests>=2.32",
).add_local_file(
    "data/markers/unified/unified_markers.hmm",
    "/root/markers.hmm",
)
# `image` is the app-wide default image for the functions and classes registered below.
app = modal.App("microbe-extract-marker-seqs", image=image)
# NCBI Datasets v2 endpoint returning a zip archive for a single assembly accession.
DATASETS_URL = "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"
# Version suffixes tried, in order, for accessions given without one.
VERSION_FALLBACKS = (".1", ".2", ".3", ".4")
# Download responses shorter than this many bytes are treated as empty / not found.
EMPTY_ZIP_BYTES = 2_000

# Marker HMM names grouped by phenotype category. This table is the single source of truth:
# MARKER_TO_CATEGORY and CATEGORIES (including its order) are both derived from it, so they
# cannot drift apart. Names are compared with each searched HMM's name; hits from HMMs not
# listed here are ignored.
_MARKERS_BY_CATEGORY: dict[str, tuple[str, ...]] = {
    "temperature": (
        "Hsp70_DnaK", "Hsp90", "Cpn60_GroEL",
        "Hsp20", "CSD_cold_shock", "TGS_thermosome",
    ),
    "ph": (
        "ATP_synth_alphabeta", "ATP_synth_alphabeta_C", "ATP_synth_F0_B",
        "NhaA_Na_H_exch", "NhaB_Na_H_exch", "Pyridoxal_decarbox",
        "MotA_TolQ_ExbB", "V_ATPase_subH_N",
    ),
    "oxygen": (
        "COX1_aerobic", "COX2_TM_aerobic", "COX2_periplasm_aero",
        "Cyt_CBB3_microaero", "Rieske_2Fe2S", "Catalase",
        "SOD_FeMn", "SOD_CuZn", "FeFe_hyd_anaerobic",
        "NiFe_hyd_anaerobic", "FAD_binding_FrdA", "Fer4_FeS_4Fe4S",
    ),
    "salt": (
        "KdpD_osmosensor", "TrkH_K_channel", "BCCT_compatible",
        "BPD_transp_1", "EctC_ectoine_synth", "Bact_rhodopsin",
    ),
    "vitamin": (
        "TP_methylase_B12", "Peripla_BP_2", "THF_DHG_CYH_folate",
        "FolB_folate", "PdxJ_pyridoxine", "DHBP_riboflavin",
    ),
    "nitrogen": (
        "NifH_nitrogenase", "NifDK_nitrogenase",
        "NIR_SIR_ferredoxin",
    ),
    "carbon": (
        "RuBisCO_large_form1", "RuBisCO_small_form1",
        "Alpha_amylase", "Cellulase_GH5", "CBM_cellulose",
    ),
    "special": (
        "Molybdopterin_OR", "UvrD_helicase_C",
    ),
}
# marker HMM name -> category, in table order.
MARKER_TO_CATEGORY: dict[str, str] = {
    marker: category
    for category, markers in _MARKERS_BY_CATEGORY.items()
    for marker in markers
}
# A dict literal would silently keep only the last assignment of a duplicated marker;
# fail loudly at import time instead.
if len(MARKER_TO_CATEGORY) != sum(len(m) for m in _MARKERS_BY_CATEGORY.values()):
    raise ValueError("a marker HMM name is listed more than once in _MARKERS_BY_CATEGORY")
CATEGORIES = list(_MARKERS_BY_CATEGORY)
# HMM hits with an E-value above this are discarded.
EVALUE_THRESHOLD = 1e-5
MAX_PROTEIN_LEN = 1022  # ESM-2 context window minus special tokens
def _has_version(accession: str) -> bool:
return "." in accession and accession.rsplit(".", 1)[-1].isdigit()
def _candidate_accessions(accession: str) -> list[str]:
    """List the accession strings to request from NCBI, in the order to try them.

    A versioned accession is returned unchanged; an unversioned one is expanded with each
    suffix in ``VERSION_FALLBACKS``.
    """
    if not _has_version(accession):
        return [f"{accession}{suffix}" for suffix in VERSION_FALLBACKS]
    return [accession]
def _fetch_fasta_bytes(accession: str, ncbi_key: str | None) -> list[tuple[str, str]] | None:
    """Download an assembly's genomic FASTA from NCBI Datasets and parse it into contigs.

    Despite its name, this returns the parsed ``(contig_name, sequence)`` pairs produced by
    ``_parse_fasta``, not raw bytes. Candidates come from ``_candidate_accessions``: a
    versioned accession is tried as-is, an unversioned one with each suffix in
    ``VERSION_FALLBACKS`` in turn. The first candidate whose download is a valid zip of at
    least ``EMPTY_ZIP_BYTES`` bytes containing a ``.fna`` member wins; only the first
    ``.fna`` member is read.

    Network errors and retryable HTTP statuses (429 and 5xx server/gateway errors) are
    retried with exponential backoff before falling through to the next candidate, so a
    transient NCBI failure is not mistaken for "this accession version does not exist".

    Args:
        accession: Assembly accession, with or without a ``.N`` version suffix.
        ncbi_key: Optional NCBI API key, sent as the ``api-key`` header.

    Returns:
        Parsed contigs, or None if no candidate yielded a FASTA.
    """
    import io
    import time
    import zipfile

    import requests

    headers = {"api-key": ncbi_key} if ncbi_key else {}
    retryable_statuses = {429, 500, 502, 503, 504}
    max_attempts = 4  # sleeps of 1 s, 2 s, 4 s between attempts

    for cand in _candidate_accessions(accession):
        url = DATASETS_URL.format(acc=cand)
        resp = None
        for attempt in range(max_attempts):
            try:
                resp = requests.get(
                    url,
                    params={"include_annotation_type": "GENOME_FASTA"},
                    headers=headers,
                    timeout=120,
                )
            except requests.RequestException:
                resp = None
            else:
                # Definitive answers (200, 404, ...) end the retry loop immediately.
                if resp.status_code not in retryable_statuses:
                    break
            if attempt + 1 < max_attempts:
                time.sleep(2**attempt)
        if resp is None or resp.status_code != 200 or len(resp.content) < EMPTY_ZIP_BYTES:
            continue
        try:
            with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
                fasta_names = [n for n in zf.namelist() if n.endswith(".fna")]
                if not fasta_names:
                    continue
                with zf.open(fasta_names[0]) as src:
                    raw = src.read()
        except zipfile.BadZipFile:
            continue
        return _parse_fasta(raw)
    return None
def _parse_fasta(raw: bytes) -> list[tuple[str, str]]:
contigs: list[tuple[str, str]] = []
name = None
chunks: list[str] = []
for line in raw.decode("utf-8", errors="ignore").splitlines():
if line.startswith(">"):
if name is not None:
contigs.append((name, "".join(chunks)))
name = line[1:].split()[0]
chunks = []
else:
chunks.append(line.strip())
if name is not None:
contigs.append((name, "".join(chunks)))
return contigs
def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Call genes with pyrodigal and return their translated protein sequences.

    Genomes totalling fewer than 100,000 bases run with ``meta=True`` (no training step);
    larger genomes first train a single-genome model on all contigs. A trailing ``*``
    (stop) is stripped from each translation. Proteins are returned in contig order, then
    in the order pyrodigal reports genes within each contig.
    """
    import pyrodigal

    if not contigs:
        return []
    use_meta = sum(len(seq) for _, seq in contigs) < 100_000
    finder = pyrodigal.GeneFinder(meta=use_meta)
    encoded = [seq.encode() for _, seq in contigs]
    if not use_meta:
        finder.train(*encoded)
    return [
        gene.translate().rstrip("*")
        for contig in encoded
        for gene in finder.find_genes(contig)
    ]
@app.cls(
    cpu=2.0,
    memory=2048,
    timeout=3600 * 4,
    secrets=[modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"])],
    max_containers=16,
    scaledown_window=60,
)
class MarkerSeqExtractor:
    """Modal worker that turns one genome accession into HMM-gated marker protein sequences.

    ``setup`` runs once per container and loads the marker HMMs. Each ``extract_genome``
    call then:

    1. downloads the genome;
    2. calls genes with pyrodigal;
    3. searches every protein against every marker HMM with pyhmmer;
    4. groups the hit proteins by phenotype category (see ``MARKER_TO_CATEGORY``).
    """

    @modal.enter()
    def setup(self):
        """Load the marker HMM database and read runtime configuration from the environment.

        ``NCBI_API_KEY`` comes from the ``ncbi-key`` secret. ``MAX_PER_CATEGORY`` (default
        16) caps how many sequences are kept per category; ``main`` injects it through an
        extra secret.
        """
        import os

        import pyhmmer
        import pyhmmer.easel
        # Imported explicitly because hmmsearch is used from this submodule; do not rely on
        # pyhmmer/__init__ having loaded it as a side effect.
        import pyhmmer.hmmer
        import pyhmmer.plan7

        self.pyhmmer = pyhmmer
        self.alphabet = pyhmmer.easel.Alphabet.amino()
        with pyhmmer.plan7.HMMFile("/root/markers.hmm") as fh:
            self.hmms = list(fh)
        print(f"[setup] loaded {len(self.hmms)} marker HMMs", flush=True)
        self.ncbi_key = os.environ.get("NCBI_API_KEY")
        self.max_per_cat = int(os.environ.get("MAX_PER_CATEGORY", "16"))

    def _scan_for_markers(self, proteins: list[str]) -> dict[str, list[int]]:
        """Search all proteins against every loaded marker HMM.

        Args:
            proteins: Translated proteins; an entry's list position is its protein index.
                Empty strings are skipped.

        Returns:
            One key per marker in ``MARKER_TO_CATEGORY`` (possibly an empty list), mapping
            to indices into *proteins* of hits with E-value <= ``EVALUE_THRESHOLD``. Hits
            from HMMs not listed in ``MARKER_TO_CATEGORY`` are ignored.
        """
        easel = self.pyhmmer.easel
        # Name each sequence "p<index>" so hits can be mapped back to positions in `proteins`.
        digitized = [
            easel.TextSequence(name=f"p{i}".encode(), sequence=prot).digitize(self.alphabet)
            for i, prot in enumerate(proteins)
            if prot
        ]
        result: dict[str, list[int]] = {name: [] for name in MARKER_TO_CATEGORY}
        if not digitized:
            return result
        # Give hmmsearch a DigitalSequenceBlock (its documented target type), built once.
        targets = easel.DigitalSequenceBlock(self.alphabet, digitized)
        for top_hits in self.pyhmmer.hmmer.hmmsearch(self.hmms, targets, E=EVALUE_THRESHOLD):
            # Names may be bytes or str depending on the pyhmmer version.
            raw = top_hits.query.name
            marker = raw.decode() if isinstance(raw, bytes) else raw
            if marker not in result:
                continue
            for hit in top_hits:
                # Defensive re-check of the cutoff; presumably TopHits can also hold hits that
                # were not reported under E — confirm against the pyhmmer docs.
                if hit.evalue > EVALUE_THRESHOLD:
                    continue
                hit_name = hit.name.decode() if isinstance(hit.name, bytes) else hit.name
                if hit_name.startswith("p"):
                    try:
                        result[marker].append(int(hit_name[1:]))
                    except ValueError:
                        pass
        return result

    @modal.method()
    def extract_genome(
        self,
        record_id: int,
        genome_accession: str,
        fetch_accession: str | None = None,
    ) -> dict | None:
        """Extract per-category marker protein sequences for one genome.

        Args:
            record_id: Identifier written to the output's ``bacdive_id`` field.
            genome_accession: Accession recorded in the output; also the one downloaded
                when *fetch_accession* is not given.
            fetch_accession: Optional alternative accession to download instead.

        Returns:
            The JSONL record (``bacdive_id``, ``genome_accession``, ``by_category``,
            ``category_counts``), or None when nothing could be downloaded or predicted, or
            any step raised. Exceptions are logged and swallowed, not re-raised.
        """
        try:
            contigs = _fetch_fasta_bytes(fetch_accession or genome_accession, self.ncbi_key)
            if not contigs:
                return None
            proteins = _predict_proteins(contigs)
            if not proteins:
                return None
            marker_to_idx = self._scan_for_markers(proteins)

            # Pool hit indices per category in one pass. A protein hit by several markers of
            # the same category is kept once.
            cat_to_idx: dict[str, set[int]] = {cat: set() for cat in CATEGORIES}
            for marker, idxs in marker_to_idx.items():
                cat = MARKER_TO_CATEGORY.get(marker)
                if cat in cat_to_idx:
                    cat_to_idx[cat].update(idxs)

            by_category: dict[str, list[str]] = {}
            for cat in CATEGORIES:
                # Keep the K shortest hits. Ties are broken by protein index (gene order), so
                # the selection is explicit rather than an artifact of set iteration order.
                # Shortest-first also minimizes truncation to MAX_PROTEIN_LEN.
                ranked = sorted(cat_to_idx[cat], key=lambda i: (len(proteins[i]), i))
                kept = ranked[: self.max_per_cat]
                by_category[cat] = [proteins[i][:MAX_PROTEIN_LEN] for i in kept]
            return {
                "bacdive_id": int(record_id),
                "genome_accession": genome_accession,
                "by_category": by_category,
                "category_counts": {c: len(by_category[c]) for c in CATEGORIES},
            }
        except Exception as exc:
            print(f"  skip {genome_accession}: {type(exc).__name__}: {exc}", flush=True)
            return None
def _select_records(source, input_path, id_col, accession_col, fetch_accession_col, require_label):
    """Filter the input table to extractable rows and attach normalized task columns.

    Adds three columns:
    - ``_marker_seq_id``: from *id_col*, or position among the selected rows
      + 1,000,000,000 when *id_col* is empty or absent;
    - ``_genome_accession``;
    - ``_fetch_accession``: from *fetch_accession_col*, falling back to the genome
      accession where that column is absent or null.

    Raises:
        ValueError: If *accession_col* is missing, or *require_label* is set but none of
            the known label columns is present.
    """
    if accession_col not in source.columns:
        raise ValueError(f"{input_path} is missing accession column: {accession_col}")
    ready = source[source[accession_col].notna()].copy()
    if require_label:
        label_cols = ["optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"]
        present_label_cols = [col for col in label_cols if col in ready.columns]
        if not present_label_cols:
            raise ValueError(
                f"require_label=1 but {input_path} has none of these columns: {label_cols}"
            )
        # Keep rows that have at least one phenotype label.
        ready = ready[ready[present_label_cols].notna().any(axis=1)].copy()
    if id_col and id_col in ready.columns:
        ready["_marker_seq_id"] = ready[id_col].astype(int)
    else:
        # Synthetic ids; the large offset presumably keeps them clear of real BacDive ids.
        ready = ready.reset_index(drop=True)
        ready["_marker_seq_id"] = ready.index + 1_000_000_000
    ready["_genome_accession"] = ready[accession_col].astype(str)
    if fetch_accession_col and fetch_accession_col in ready.columns:
        ready["_fetch_accession"] = ready[fetch_accession_col].fillna(ready[accession_col]).astype(str)
    else:
        ready["_fetch_accession"] = ready["_genome_accession"]
    return ready


def _load_done(out: Path) -> tuple[set[int], set[str]]:
    """Collect record ids and genome accessions already written to the JSONL *out*.

    Unparseable lines (e.g. a partial line left by an interrupted run) are skipped.
    """
    done: set[int] = set()
    done_accessions: set[str] = set()
    if not out.exists():
        return done, done_accessions
    with open(out, encoding="utf-8") as fh:
        for line in fh:
            try:
                row = json.loads(line)
                done.add(int(row["bacdive_id"]))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                continue
            if row.get("genome_accession"):
                done_accessions.add(str(row["genome_accession"]))
    return done, done_accessions


def _ends_mid_line(path: Path) -> bool:
    """Return True if *path* exists, is non-empty, and does not end with a newline.

    That typically means an earlier run was interrupted mid-write. Appending without
    writing a newline first would glue the next record onto the partial line and lose it.
    """
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path, "rb") as fh:
        fh.seek(-1, 2)  # whence=2: offset relative to end of file
        return fh.read(1) != b"\n"


@app.local_entrypoint()
def main(
    out_path: str = "data/marker_sequences.jsonl",
    input_path: str = "data/bacdive_phenotypes.parquet",
    id_col: str = "bacdive_id",
    accession_col: str = "genome_accession",
    fetch_accession_col: str = "",
    require_label: int = 1,
    limit: int = 0,
    max_per_cat: int = 16,
):
    """Dispatch genomes to Modal containers; stream sequences to local JSONL.

    Resumable: rows whose id or genome accession already appears in *out_path* are
    skipped, and new records are appended. *limit* (if non-zero) caps the number of
    pending genomes dispatched. *max_per_cat* is forwarded to the workers as the
    ``MAX_PER_CATEGORY`` environment variable.
    """
    import pandas as pd

    ready = _select_records(
        pd.read_parquet(input_path),
        input_path,
        id_col,
        accession_col,
        fetch_accession_col,
        require_label,
    )

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    done, done_accessions = _load_done(out)

    pending = ready[
        ~ready["_marker_seq_id"].isin(done)
        & ~ready["_genome_accession"].isin(done_accessions)
    ]
    if limit:
        pending = pending.head(limit)
    tasks = list(zip(
        pending["_marker_seq_id"],
        pending["_genome_accession"],
        pending["_fetch_accession"],
        strict=True,
    ))
    print(f"Marker-sequence extract: {len(tasks):,} genomes pending ({len(done):,} cached)")
    print(f"input_path={input_path}")
    print(f"accession_col={accession_col} fetch_accession_col={fetch_accession_col or accession_col}")
    print(f"max_per_cat={max_per_cat}")
    if not tasks:
        return

    # Worker configuration travels as an environment variable via an ad-hoc secret.
    config_secret = modal.Secret.from_dict({"MAX_PER_CATEGORY": str(max_per_cat)})
    extractor = MarkerSeqExtractor.with_options(
        secrets=[
            modal.Secret.from_name("ncbi-key", required_keys=["NCBI_API_KEY"]),
            config_secret,
        ],
    )()

    needs_newline = _ends_mid_line(out)
    n_ok = n_fail = 0
    with open(out, "a") as log:
        if needs_newline:
            log.write("\n")
        for result in extractor.extract_genome.starmap(tasks, return_exceptions=True):
            if isinstance(result, Exception):
                # Container-level failures (timeouts, crashes) surface here. Report them
                # instead of silently counting.
                n_fail += 1
                print(f"  remote error: {type(result).__name__}: {result}", flush=True)
                continue
            if result is None:
                n_fail += 1
                continue
            log.write(json.dumps(result) + "\n")
            log.flush()  # keep the file resumable if this process is killed
            n_ok += 1
            if n_ok % 100 == 0:
                print(f"  {n_ok:,} ok / {n_fail:,} fail")
    print(f"\nFinished. {n_ok:,} succeeded, {n_fail:,} failed.")
    print(f"Streamed to {out}")