"""Per-marker ESM-2 t30 embedding service — runs on a Cerebrium L4 GPU container.

embed_genome(bacdive_id, accession) returns either

    {"ok": True, "row": {"pme_<category>_<dim>": float, ...}}

or

    {"ok": False, "reason": ..., "bacdive_id": ..., "accession": ...}

The unified-marker HMM library is baked into the image. Each replica loads
ESM-2 + HMMs once at startup, then serves multiple genomes from the warm
container.
"""

from __future__ import annotations

import io
import os
import time
import zipfile
from typing import Any

import numpy as np
import pyhmmer
import pyhmmer.easel
import pyhmmer.plan7
import pyrodigal
import requests
import torch
from transformers import AutoModel, AutoTokenizer

DATASETS_URL = "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/accession/{acc}/download"
# Version suffixes tried, in order, when the accession carries none.
VERSION_FALLBACKS = (".1", ".2", ".3", ".4")
# Responses smaller than this many bytes are treated as an empty archive.
EMPTY_ZIP_BYTES = 2_000
EVALUE_THRESHOLD = 1e-5
# Download attempts per candidate accession before giving up on it.
MAX_FETCH_ATTEMPTS = 3
# HTTP statuses that trigger a backoff-and-retry instead of an immediate error.
RETRYABLE_STATUS = (429, 502, 503)

MARKER_TO_CATEGORY: dict[str, str] = {
    "Hsp70_DnaK": "temperature",
    "Hsp90": "temperature",
    "Cpn60_GroEL": "temperature",
    "Hsp20": "temperature",
    "CSD_cold_shock": "temperature",
    "TGS_thermosome": "temperature",
    "ATP_synth_alphabeta": "ph",
    "ATP_synth_alphabeta_C": "ph",
    "ATP_synth_F0_B": "ph",
    "NhaA_Na_H_exch": "ph",
    "NhaB_Na_H_exch": "ph",
    "Pyridoxal_decarbox": "ph",
    "MotA_TolQ_ExbB": "ph",
    "V_ATPase_subH_N": "ph",
    "COX1_aerobic": "oxygen",
    "COX2_TM_aerobic": "oxygen",
    "COX2_periplasm_aero": "oxygen",
    "Cyt_CBB3_microaero": "oxygen",
    "Rieske_2Fe2S": "oxygen",
    "Catalase": "oxygen",
    "SOD_FeMn": "oxygen",
    "SOD_CuZn": "oxygen",
    "FeFe_hyd_anaerobic": "oxygen",
    "NiFe_hyd_anaerobic": "oxygen",
    "FAD_binding_FrdA": "oxygen",
    "Fer4_FeS_4Fe4S": "oxygen",
    "KdpD_osmosensor": "salt",
    "TrkH_K_channel": "salt",
    "BCCT_compatible": "salt",
    "BPD_transp_1": "salt",
    "EctC_ectoine_synth": "salt",
    "Bact_rhodopsin": "salt",
    "TP_methylase_B12": "vitamin",
    "Peripla_BP_2": "vitamin",
    "THF_DHG_CYH_folate": "vitamin",
    "FolB_folate": "vitamin",
    "PdxJ_pyridoxine": "vitamin",
    "DHBP_riboflavin": "vitamin",
    "NifH_nitrogenase": "nitrogen",
    "NifDK_nitrogenase": "nitrogen",
    "NIR_SIR_ferredoxin": "nitrogen",
    "RuBisCO_large_form1": "carbon",
    "RuBisCO_small_form1": "carbon",
    "Alpha_amylase": "carbon",
    "Cellulase_GH5": "carbon",
    "CBM_cellulose": "carbon",
    "Molybdopterin_OR": "special",
    "UvrD_helicase_C": "special",
}

# Order here fixes the column order of the output row.
CATEGORIES = ["temperature", "ph", "oxygen", "salt", "vitamin", "nitrogen", "carbon", "special"]

# --- One-time startup: model, tokenizer and HMM library are loaded at import. ---
_model_name = os.environ.get("ESM2_MODEL", "facebook/esm2_t30_150M_UR50D")
_batch_size = int(os.environ.get("ESM2_BATCH_SIZE", "16"))
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only on GPU; CPU stays in float32.
_dtype = torch.float16 if _device.type == "cuda" else torch.float32

print(f"[boot] loading {_model_name} on {_device} ({_dtype})", flush=True)
_tokenizer = AutoTokenizer.from_pretrained(_model_name)
_model = AutoModel.from_pretrained(_model_name, dtype=_dtype)
_model.to(_device)
_model.train(False)
_embed_dim = _model.config.hidden_size

_alphabet = pyhmmer.easel.Alphabet.amino()
with pyhmmer.plan7.HMMFile("/cortex/app/markers.hmm") as _fh:
    _hmms = list(_fh)

_ncbi_key = os.environ.get("NCBI_API_KEY")
print(f"[boot] loaded {len(_hmms)} marker HMMs, embed_dim={_embed_dim}, "
      f"ncbi_key={'yes' if _ncbi_key else 'no'}", flush=True)


def _has_version(acc: str) -> bool:
    """Return True when *acc* ends in ``.<digits>`` (an explicit version suffix)."""
    return "." in acc and acc.rsplit(".", 1)[-1].isdigit()


def _candidates(acc: str) -> list[str]:
    """Return the accessions to try: *acc* itself, or *acc* plus each fallback version."""
    return [acc] if _has_version(acc) else [acc + v for v in VERSION_FALLBACKS]


def _download_zip(
    cand: str,
    headers: dict[str, str],
    params: dict[str, str],
    rate: float,
) -> bytes | None:
    """Download the dataset archive for one accession, with retries.

    Returns the raw zip bytes, or None when the accession is missing (404),
    the payload is smaller than ``EMPTY_ZIP_BYTES``, or every attempt failed.
    """
    for attempt in range(MAX_FETCH_ATTEMPTS):
        last_attempt = attempt == MAX_FETCH_ATTEMPTS - 1
        try:
            # Fixed pause before every request to pace calls to the API.
            time.sleep(rate)
            resp = requests.get(
                DATASETS_URL.format(acc=cand),
                params=params,
                headers=headers,
                timeout=120,
            )
            if resp.status_code == 404:
                return None
            if resp.status_code in RETRYABLE_STATUS:
                # No point backing off when there is no attempt left.
                if not last_attempt:
                    time.sleep(2 ** attempt)
                continue
            resp.raise_for_status()
        except requests.RequestException:
            if last_attempt:
                return None
            time.sleep(2 ** attempt)
            continue
        if len(resp.content) < EMPTY_ZIP_BYTES:
            return None
        return resp.content
    return None


def _extract_fna(zip_bytes: bytes) -> bytes | None:
    """Return the contents of the first ``.fna`` member of the archive.

    Returns None when the archive is not a valid zip or has no ``.fna`` member.
    """
    try:
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
            fna = [n for n in zf.namelist() if n.endswith(".fna")]
            if not fna:
                return None
            with zf.open(fna[0]) as src:
                return src.read()
    except zipfile.BadZipFile:
        return None


def _fetch_fasta(acc: str) -> list[tuple[str, str]] | None:
    """Fetch the genome FASTA for *acc* and parse it into contigs.

    Each candidate accession (see ``_candidates``) is tried in order; the first
    one that yields a readable ``.fna`` file wins.

    Returns:
        A list of ``(contig_name, sequence)`` tuples, or None when no candidate
        produced a usable archive.
    """
    # Shorter pause between requests when an API key is configured.
    rate = 0.1 if _ncbi_key else 0.34
    headers = {"Accept": "application/zip"}
    if _ncbi_key:
        headers["api-key"] = _ncbi_key
    params = {"include_annotation_type": "GENOME_FASTA"}

    for cand in _candidates(acc):
        zip_bytes = _download_zip(cand, headers, params, rate)
        if zip_bytes is None:
            continue
        raw = _extract_fna(zip_bytes)
        if raw is None:
            continue
        return _parse_fasta(raw)
    return None


def _parse_fasta(raw: bytes) -> list[tuple[str, str]]:
    """Parse FASTA bytes into ``(name, sequence)`` tuples.

    The name is the first whitespace-delimited token of the header line (an
    empty string for a bare ``>`` header). Sequences are upper-cased. Any
    sequence text that precedes the first header is discarded.
    """
    contigs: list[tuple[str, str]] = []
    cur: str | None = None
    chunks: list[str] = []
    for line in raw.splitlines():
        if not line:
            continue
        if line.startswith(b">"):
            if cur is not None:
                contigs.append((cur, "".join(chunks).upper()))
            tokens = line[1:].decode("ascii", errors="replace").split()
            # A header with no text after ">" must not crash the parser.
            cur = tokens[0] if tokens else ""
            chunks = []
        else:
            chunks.append(line.decode("ascii", errors="replace"))
    if cur is not None:
        contigs.append((cur, "".join(chunks).upper()))
    return contigs


def _predict_proteins(contigs: list[tuple[str, str]]) -> list[str]:
    """Predict genes on every contig and return their protein translations.

    With at least 20,000 nt in total, a gene finder is trained on the
    concatenated contigs; otherwise (or if training fails) pyrodigal's
    metagenomic mode is used. Trailing ``*`` characters are stripped from
    each translation.
    """
    encoded = [(n, s.encode("ascii")) for n, s in contigs]
    total_nt = sum(len(s) for _, s in encoded)

    # NOTE(review): 20_000 is presumably the minimum pyrodigal needs to train — confirm.
    if total_nt >= 20_000:
        finder = pyrodigal.GeneFinder(meta=False)
        try:
            # Contigs are joined with a spacer before training; presumably this
            # keeps genes from spanning contig boundaries — confirm.
            finder.train(b"TTAATTAATTAA".join(s for _, s in encoded))
        except Exception:
            # Deliberate best-effort: fall back to metagenomic mode.
            finder = pyrodigal.GeneFinder(meta=True)
    else:
        finder = pyrodigal.GeneFinder(meta=True)

    return [
        gene.translate().rstrip("*")
        for _, s in encoded
        for gene in finder.find_genes(s)
    ]


def _embed_proteins(proteins: list[str]) -> np.ndarray:
    """Embed proteins with the ESM-2 model and mean-pool over tokens.

    Sequences are tokenized in batches of ``_batch_size`` and truncated to
    1024 tokens. Pooling is an attention-mask-weighted mean of the last hidden
    state.

    Returns:
        A float32 array of shape ``(len(proteins), _embed_dim)``.
    """
    if not proteins:
        return np.zeros((0, _embed_dim), dtype=np.float32)

    out: list[np.ndarray] = []
    for i in range(0, len(proteins), _batch_size):
        batch = proteins[i : i + _batch_size]
        enc = _tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        )
        enc = {k: v.to(_device) for k, v in enc.items()}
        with torch.inference_mode():
            outs = _model(**enc)
        last_hidden = outs.last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
        # clamp guards against division by zero for an all-padding row.
        pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        out.append(pooled.float().cpu().numpy())
    return np.concatenate(out, axis=0)


def _scan_markers(proteins: list[str]) -> dict[str, list[int]]:
    """Search every marker HMM against *proteins*.

    Returns:
        A dict keyed by every marker in ``MARKER_TO_CATEGORY`` whose value is
        the list of indices into *proteins* that hit that marker with an
        E-value at or below ``EVALUE_THRESHOLD``.
    """
    seqs = []
    for i, prot in enumerate(proteins):
        if not prot:
            continue
        # The protein's index is encoded in the sequence name so hits can be
        # mapped back to positions in `proteins`.
        ts = pyhmmer.easel.TextSequence(name=f"p{i}".encode(), sequence=prot)
        seqs.append(ts.digitize(_alphabet))

    result: dict[str, list[int]] = {name: [] for name in MARKER_TO_CATEGORY}
    if not seqs:
        return result

    for top_hits in pyhmmer.hmmer.hmmsearch(_hmms, seqs, E=EVALUE_THRESHOLD):
        raw = top_hits.query.name
        marker = raw.decode() if isinstance(raw, bytes) else raw
        if marker not in result:
            # HMMs in the library that are not mapped to a category are ignored.
            continue
        for hit in top_hits:
            if hit.evalue > EVALUE_THRESHOLD:
                continue
            name = hit.name.decode() if isinstance(hit.name, bytes) else hit.name
            if name.startswith("p"):
                try:
                    result[marker].append(int(name[1:]))
                except ValueError:
                    # Name does not follow the "p<index>" scheme; skip the hit.
                    pass
    return result


def _failure(reason: str, bacdive_id: int, accession: str) -> dict[str, Any]:
    """Build the ``{"ok": False, ...}`` response for a failed genome."""
    return {
        "ok": False,
        "reason": reason,
        "bacdive_id": bacdive_id,
        "accession": accession,
    }


def embed_genome(bacdive_id: int, accession: str) -> dict[str, Any]:
    """Fetch a genome, find marker proteins, and build its embedding row.

    Args:
        bacdive_id: Identifier echoed back in the result (as ``int`` in the row).
        accession: Genome accession, with or without a version suffix.

    Returns:
        On success, ``{"ok": True, "row": row}`` where ``row`` holds
        ``bacdive_id``, ``genome_accession``, ``pme_marker_proteins_total`` and,
        for every category in ``CATEGORIES``, ``pme_<cat>_n`` (number of
        distinct hit proteins) followed by ``pme_<cat>_<d>`` for each embedding
        dimension (the mean embedding of that category's hit proteins, or 0.0
        when the category has no hits).

        On failure, ``{"ok": False, "reason": ..., "bacdive_id": ...,
        "accession": ...}``. ``reason`` is ``"fetch_empty"``, ``"no_proteins"``,
        or ``"<ExceptionType>: <message>"``. This function does not raise for
        ordinary errors.
    """
    try:
        contigs = _fetch_fasta(accession)
        if not contigs:
            return _failure("fetch_empty", bacdive_id, accession)

        proteins = _predict_proteins(contigs)
        if not proteins:
            return _failure("no_proteins", bacdive_id, accession)

        marker_idx = _scan_markers(proteins)
        # Each hit protein is embedded once even if it matches several markers.
        hit_indices = sorted({i for ids in marker_idx.values() for i in ids})

        row: dict[str, Any] = {
            "bacdive_id": int(bacdive_id),
            "genome_accession": accession,
            "pme_marker_proteins_total": len(hit_indices),
        }

        # With no hits this returns an empty (0, dim) matrix without touching
        # the model, and every category below falls through to the zero fill.
        hit_matrix = _embed_proteins([proteins[i] for i in hit_indices])
        # Maps an index into `proteins` to its row in `hit_matrix`.
        gi_to_ri = {gi: ri for ri, gi in enumerate(hit_indices)}

        for cat in CATEGORIES:
            idxs = {
                gi
                for marker, gis in marker_idx.items()
                if MARKER_TO_CATEGORY.get(marker) == cat
                for gi in gis
            }
            row[f"pme_{cat}_n"] = len(idxs)

            # Sorted so the averaging order is reproducible.
            rows = sorted(gi_to_ri[gi] for gi in idxs if gi in gi_to_ri)
            if rows:
                cat_mean = hit_matrix[rows].mean(axis=0).astype(np.float32)
                for d, v in enumerate(cat_mean):
                    row[f"pme_{cat}_{d}"] = float(v)
            else:
                for d in range(_embed_dim):
                    row[f"pme_{cat}_{d}"] = 0.0

        return {"ok": True, "row": row}
    except Exception as exc:
        # Top-level service boundary: report the failure instead of raising.
        return _failure(f"{type(exc).__name__}: {exc}", bacdive_id, accession)