mgnify-evo2-probes / code /modal /evo2_inference.py
JG1310's picture
Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
eb69de4 verified
"""
Modal scaffold for self-hosted Evo 2-7B inference.
Mirrors Arc Institute's official Dockerfile (from the ArcInstitute/evo2 repo),
uses cached HF weights in a Modal Volume, and reuses the existing "huggingface"
Modal Secret (same one used by ~/AIMO3-TIR/compute/modal/).
No dependency on NVIDIA NIM / BioNeMo and no NGC API key: weights come from HF,
and the container re-creates Arc's Dockerfile (public nvcr.io/nvidia/pytorch
base image + `pip install evo2`) with Modal's Image API.
ONE-TIME SETUP
1. Authenticate Modal (already done): modal token new
2. HF secret (already done via AIMO3): modal secret create huggingface HF_TOKEN=...
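3. (Optional, reference only) Clone Arc's repo locally: git clone https://github.com/ArcInstitute/evo2.git ~/evo2_repo. The image below re-creates its Dockerfile with Modal's Image API, so this script never reads the clone.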
4. Upload MGnify data (optional, one-time): modal volume create mgnify-data
modal volume put mgnify-data /home/ror25cal/MGnify/data/ /
USAGE
modal run modal/evo2_inference.py::main --seq "ACGT..." --layer 26
"""
import os
from pathlib import Path
import modal
# --- paths / names --------------------------------------------------------
# Modal app name (also shown in the `main` smoke-test message).
APP_NAME = "mgnify-evo2-7b"
VOL_WEIGHTS = "evo2-7b-weights" # HF cache (weights persist here)
VOL_DATA = "mgnify-data" # MGnify FASTAs (optional)
# Default layer for `embed` / `main` (dot-separated Evo2 layer name). The SAE
# pipelines further down hook the whole block output through the
# dash-separated module path "blocks-26" instead.
TARGET_LAYER = "blocks.26.mlp.l3" # adjust as needed
# --- image -----------------------------------------------------------------
# Matches Arc Institute's official Dockerfile (nvcr.io/nvidia/pytorch:25.04-py3
# + pip install evo2), translated to Modal's native Image API because Modal
# doesn't support the `WORKDIR` directive from the raw Dockerfile.
# Used by the GPU functions in this file; the CPU-only helpers build their own
# debian_slim images.
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None, # base image already has python
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2") # pulls flash-attn + vtx + huggingface_hub as transitive deps
)
# --- persistent storage ---------------------------------------------------
# Volumes are looked up by name and created on first use if missing.
weights_vol = modal.Volume.from_name(VOL_WEIGHTS, create_if_missing=True)
data_vol = modal.Volume.from_name(VOL_DATA, create_if_missing=True)
# The single Modal app that every function / local entrypoint below registers on.
app = modal.App(APP_NAME)
@app.function(
    image=image,
    gpu="H100",  # compute cap 9.0 — needed for Evo2's FP8 kernels (A100 is cc 8.0, fails)
    volumes={
        "/root/.cache/huggingface": weights_vol,  # HF will cache evo2_7b here, persists
        "/data": data_vol,  # MGnify FASTAs if you uploaded them
    },
    secrets=[modal.Secret.from_name("huggingface")],  # sets HF_TOKEN env var
    timeout=3600,
)
def embed(sequences: list[tuple[str, str]], layers: list[str] | None = None) -> dict:
    """
    Run an Evo 2-7B forward pass on each (name, sequence) pair and return the
    activations captured at the requested layers.

    The model is loaded once per call. The first call downloads the weights
    into the HF-cache Volume; later calls reuse them.

    Args:
        sequences: list of (name, DNA_string). If a name repeats, the later
            entry overwrites the earlier one in the result.
        layers: Evo2 layer names, e.g. ["blocks.26.mlp.l3"]. None or an empty
            list means [TARGET_LAYER].

    Returns:
        {name: {layer: np.ndarray of shape [seq_len, hidden_dim]}}. The batch
        dimension added before the forward pass is squeezed out again, and the
        values are cast to float32 with `.float()`.
    """
    import numpy as np
    import torch
    from evo2 import Evo2

    layers = layers or [TARGET_LAYER]
    model = Evo2("evo2_7b")  # loads from /root/.cache/huggingface
    out = {}
    # Inference only. Whether or not Evo2's __call__ disables autograd
    # internally, this guarantees no graph is built over the 7B model. It also
    # guarantees the captured activations can be turned into NumPy arrays
    # (which fails for tensors that require grad).
    with torch.no_grad():
        for name, seq in sequences:
            # Arc's canonical API: tokenize → int tensor → batch dim → cuda → model(..., return_embeddings=True)
            input_ids = torch.tensor(
                model.tokenizer.tokenize(seq), dtype=torch.int,
            ).unsqueeze(0).to("cuda:0")
            _, embeddings = model(input_ids, return_embeddings=True, layer_names=layers)
            out[name] = {
                lyr: np.asarray(embeddings[lyr].squeeze(0).float().cpu())
                for lyr in layers
            }
    return out
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def embed_and_sae(sequence: str, topk: int = 64) -> dict:
    """
    Evo 2-7B-262k forward (hook at the whole block-26 output) + Goodfire BatchTopK SAE.

    Follows Arc Institute's reference notebook:
        notebooks/sparse_autoencoder/sparse_autoencoder.ipynb
    Key differences from a simple per-token ReLU SAE:
      - model = evo2_7b_262k (the long-context variant Goodfire's SAE was trained on)
      - layer name = 'blocks-26' (whole block output, not blocks.26.mlp.l3)
      - BatchTopK with K=64 applied at encode — otherwise features are 4-5x too dense

    Model and SAE weights come from _get_models(), so a warm container reuses
    them instead of reloading on every call.

    Args:
        sequence: DNA string to embed.
        topk: how many of the largest SAE latents to report per position.
            This only affects reporting; the BatchTopK sparsity level stays K=64.

    Returns:
        A JSON-serialisable dict with:
          - the sequence length and model/SAE dimensions;
          - the per-position top-`topk` latent values and indices;
          - the number of nonzero latents per position;
          - the L2 norm of the raw block-26 activation at each position.
    """
    import numpy as np
    import torch

    SAE_LAYER = "blocks-26"
    K = 64  # BatchTopK sparsity of the Goodfire checkpoint ("k_64" in its filename)

    # Cached per container: Evo2 7B-262k, the Goodfire layer-26 SAE encoder
    # weights, and the dash-joined submodule lookup table ("blocks-26", ...).
    evo2, W_cached, b_cached, device, module_dict = _get_models()
    target_module = module_dict[SAE_LAYER]

    cache = {}

    def hook_fn(module, inp, out):
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()

    handle = target_module.register_forward_hook(hook_fn)
    try:
        input_ids = torch.tensor(evo2.tokenizer.tokenize(sequence), dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        # The model object is shared across calls; never leave the hook attached.
        handle.remove()
    acts = cache["acts"][0]  # [seq_len, 4096] bf16

    # Match the SAE weights to the activation dtype (a no-op when both are bf16).
    W = W_cached.to(dtype=acts.dtype)  # [d_hidden, d_sae]
    b_enc = b_cached.to(dtype=acts.dtype)

    # Encode with BatchTopK (Arc's reference): keep the K * seq_len largest
    # pre-activations across the WHOLE sequence × features, not K per token.
    pre = torch.relu(acts @ W + b_enc)  # [seq_len, d_sae]
    flat = pre.flatten()
    kept = torch.topk(flat, K * pre.shape[0], dim=-1)
    latents = torch.zeros_like(flat).scatter(-1, kept.indices, kept.values).reshape(pre.shape)

    # Report the `topk` largest latents per position. Positions with fewer
    # active latents get trailing zeros here.
    topk_vals_per_pos, topk_idx_per_pos = latents.topk(topk, dim=1)
    active_per_pos = (latents > 0).sum(dim=1)
    return {
        "seq_len": int(acts.shape[0]),
        "d_model": int(acts.shape[1]),
        "d_sae": int(latents.shape[1]),
        "model_name": "evo2_7b_262k",
        "sae_layer": SAE_LAYER,
        "topk_k": K,
        "topk_values": topk_vals_per_pos.float().cpu().numpy().astype(np.float32).tolist(),
        "topk_indices": topk_idx_per_pos.cpu().numpy().astype(np.int32).tolist(),
        "active_features_per_position": active_per_pos.cpu().numpy().astype(np.int32).tolist(),
        "activation_l2_norm": torch.linalg.norm(acts.float(), dim=1).cpu().numpy().astype(np.float32).tolist(),
    }
# --- Embedding input/output volumes ---
# Per-contig chunked embeddings written by embed_full (mounted at /embeddings).
embeddings_vol = modal.Volume.from_name("mgnify-embeddings", create_if_missing=True)
# Per-region all-32-layer activations written by embed_targeted_jsonl and
# read back by the layer-26 slicer functions.
embeddings_targeted_vol = modal.Volume.from_name("mgnify-embeddings-targeted", create_if_missing=True)
# Input JSONL region records consumed by embed_targeted_jsonl / orchestrate_targeted_full.
jsonl_vol = modal.Volume.from_name("mgnify-targeted-jsonl", create_if_missing=True)
# Per-container model cache filled by _get_models(); None until the first load.
_CACHED_EVO2 = None
_CACHED_SAE_W = None
_CACHED_SAE_BENC = None
_CACHED_DEVICE = None
_CACHED_MODULE_DICT = None


def _get_models():
    """Load once per process, then return the shared Evo 2 + SAE state.

    Caches the following in module globals:
      - Evo2 7B-262k;
      - the Goodfire layer-26 SAE encoder weights;
      - the submodule lookup table.

    Forward hooks are NOT cached. Callers register hooks per call and remove
    them themselves, so no stale hook state survives container reuse.

    Returns:
        (evo2, W_sae, b_enc, device, module_dict), where:
          - evo2 is the loaded Evo2("evo2_7b_262k") wrapper;
          - W_sae / b_enc are the SAE encoder weight and bias, bf16, on `device`;
          - device is the device of the model's first parameter;
          - module_dict maps each submodule of evo2.model by its dash-joined
            path (e.g. "blocks-26").
    """
    global _CACHED_EVO2, _CACHED_SAE_W, _CACHED_SAE_BENC, _CACHED_DEVICE, _CACHED_MODULE_DICT
    if _CACHED_EVO2 is not None:
        return _CACHED_EVO2, _CACHED_SAE_W, _CACHED_SAE_BENC, _CACHED_DEVICE, _CACHED_MODULE_DICT

    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    print("[container] loading Evo2 7B-262k (once per container)")
    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    module_dict = {}

    def recurse(m, prefix=""):
        for n, c in m.named_children():
            module_dict[prefix + n] = c
            recurse(c, prefix + n + "-")

    recurse(evo2.model)

    sae_path = hf_hub_download(
        repo_id="Goodfire/Evo-2-Layer-26-Mixed",
        filename="sae-layer26-mixed-expansion_8-k_64.pt",
    )
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)
    # Strip the "_orig_mod." (torch.compile) and "module." key prefixes the checkpoint was saved with.
    sae_sd = {k.replace("_orig_mod.", "").replace("module.", ""): v for k, v in sae_sd.items()}
    W_sae = sae_sd["W"].to(device=device, dtype=torch.bfloat16)
    b_enc = sae_sd["b_enc"].to(device=device, dtype=torch.bfloat16)

    # Publish only after everything loaded, so a failed load is retried on the
    # next call. _CACHED_EVO2 is the "cache filled" flag, so it is set last.
    _CACHED_SAE_W = W_sae
    _CACHED_SAE_BENC = b_enc
    _CACHED_DEVICE = device
    _CACHED_MODULE_DICT = module_dict
    _CACHED_EVO2 = evo2
    return evo2, W_sae, b_enc, device, module_dict
def _batch_topk_latents(acts, W, b_enc, k):
    """Encode activations with a BatchTopK SAE encoder.

    Computes relu(acts @ W + b_enc), then keeps only the k * n_rows largest
    entries across the whole matrix (not k per row) and zeroes the rest.

    Args:
        acts: [n_rows, d_in] activations.
        W: [d_in, d_sae] encoder weight.
        b_enc: [d_sae] encoder bias.
        k: average number of active latents per row.

    Returns:
        A [n_rows, d_sae] tensor of sparse latents.
    """
    import torch

    pre = torch.relu(acts @ W + b_enc)
    flat = pre.flatten()
    kept = torch.topk(flat, k * pre.shape[0], dim=-1)
    return torch.zeros_like(flat).scatter(-1, kept.indices, kept.values).reshape(pre.shape)


def _mean_pool_windows(acts, pool_size):
    """Mean-pool consecutive rows of `acts` in windows of `pool_size` rows.

    Args:
        acts: [seq_len, hidden] activations.
        pool_size: window length in rows (> 0).

    Returns:
        A float32 [ceil(seq_len / pool_size), hidden] tensor. When seq_len is
        not a multiple of pool_size, the last row averages only the leftover rows.
    """
    import torch

    seq_len = acts.shape[0]
    n_full = seq_len // pool_size
    pieces = []
    if n_full > 0:
        pieces.append(acts[: n_full * pool_size].reshape(n_full, pool_size, -1).float().mean(dim=1))
    if seq_len > n_full * pool_size:
        pieces.append(acts[n_full * pool_size:].float().mean(dim=0, keepdim=True))
    if not pieces:
        return acts.new_zeros((0, acts.shape[-1]), dtype=torch.float32)
    return torch.cat(pieces, dim=0)


@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings": embeddings_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def embed_full(
    mag_id: str,
    contig_id: str,
    sequence: str,
    pool_size: int = 1000,
    chunk_size: int = 64000,  # safe under Vortex FFT int32 indexing limit
    overlap: int = 0,  # bp of overlap between consecutive chunks
) -> dict:
    """
    Embed one contig of any length, chunking as needed.

    For every chunk, Evo2 7B-262k runs once with hooks on all 32 block outputs
    ("blocks-0" … "blocks-31"). The chunk is saved to
    /embeddings/{mag_id}/{contig_id}_{N}.npz, containing:
      - layer_means_bf16: [32, n_windows, hidden] per-layer means over
        `pool_size`-token windows (the last window may be partial), with the
        bf16 bit pattern stored as uint16;
      - sae_topk_indices / sae_topk_values: per-position top-64 latents of the
        Goodfire BatchTopK SAE applied to blocks-26;
      - chunk/contig bookkeeping (offsets, sizes, ids, model name).

    Chunks whose npz already exists are not recomputed, so re-runs resume where
    they stopped. Files are written to a temporary name and renamed into place,
    so an interrupted write is never mistaken for a finished chunk. Evo2/SAE
    weights are reused across calls in the same Modal container via
    _get_models(). Forward hooks are registered per call and always removed.

    Returns:
        A dict with "mag_id", "contig_id", "contig_len", "n_chunks", "chunks"
        and "total_size_mb", plus "all_cached": True when nothing had to be
        computed. Each "chunks" entry has "path", "chunk_idx", "bp_range" and
        "size_mb"; entries for pre-existing files also carry "cached" and
        "skipped" set to True.

    Raises:
        ValueError: if pool_size <= 0, overlap < 0 or overlap >= chunk_size.
    """
    import os

    import numpy as np
    import torch

    K = 64  # BatchTopK sparsity of the Goodfire layer-26 SAE checkpoint
    layer_names = [f"blocks-{i}" for i in range(32)]

    # Validate before touching the GPU: a bad pool_size used to fail only after
    # the first forward pass, and a negative overlap silently skipped bases.
    if pool_size <= 0:
        raise ValueError(f"pool_size ({pool_size}) must be > 0")
    if overlap < 0:
        raise ValueError(f"overlap ({overlap}) must be >= 0")
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")

    contig_len = len(sequence)
    # Chunk starts advance by `step`. The previous chunk ends at start + overlap;
    # once it reaches the contig end, any further chunk would be a pure subset
    # of it, so stop there. With overlap == 0 this keeps every start.
    chunks = [
        (s, min(s + chunk_size, contig_len))
        for s in range(0, contig_len, step)
        if s == 0 or s + overlap < contig_len
    ]
    n_chunks = len(chunks)
    print(f"[{mag_id}/{contig_id}] contig_len={contig_len:,}, {n_chunks} chunks of ~{chunk_size:,} bp (overlap={overlap})")

    out_dir = f"/embeddings/{mag_id}"
    os.makedirs(out_dir, exist_ok=True)
    chunk_paths = [f"{out_dir}/{contig_id}_{ci}.npz" for ci in range(n_chunks)]

    def cached_entry(ci):
        path = chunk_paths[ci]
        cstart, cend = chunks[ci]
        return {
            "path": path, "chunk_idx": ci, "bp_range": [cstart, cend],
            "size_mb": os.path.getsize(path) / 1e6, "cached": True, "skipped": True,
        }

    # Idempotent re-runs: if every chunk is already on the volume, return
    # before loading the model or registering any hooks.
    if all(os.path.exists(p) for p in chunk_paths):
        saved = [cached_entry(ci) for ci in range(n_chunks)]
        print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks already on volume, skipping")
        return {"mag_id": mag_id, "contig_id": contig_id, "contig_len": contig_len, "n_chunks": n_chunks, "chunks": saved, "total_size_mb": sum(s['size_mb'] for s in saved), "all_cached": True}

    evo2, W_sae, b_enc, device, module_dict = _get_models()
    hidden = evo2.model.config.hidden_size

    # The model is cached per container, so hooks must never outlive this call.
    # A leaked hook would keep writing activations into its own stale dict and
    # pin GPU memory across contigs. They are registered inside the `try` so the
    # `finally` removes whatever was attached, even if registration fails midway.
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    saved = []
    handles = []
    try:
        for ln in layer_names:
            handles.append(module_dict[ln].register_forward_hook(make_hook(ln)))

        for ci, (cstart, cend) in enumerate(chunks):
            out_path = chunk_paths[ci]
            if os.path.exists(out_path):
                saved.append(cached_entry(ci))
                continue

            # --- forward pass on this chunk ---
            cache.clear()
            input_ids = torch.tensor(evo2.tokenizer.tokenize(sequence[cstart:cend]), dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(input_ids)
            seq_len = cache["blocks-0"].shape[1]
            print(f" [{ci+1}/{n_chunks}] bp {cstart:,}-{cend:,} seq_len={seq_len}")

            # --- BatchTopK SAE on layer 26 first, so its GPU intermediates are freed before pooling ---
            latents = _batch_topk_latents(cache["blocks-26"][0], W_sae, b_enc, K)
            top_v, top_i = latents.topk(K, dim=1)
            top_i_cpu = top_i.cpu().numpy().astype(np.int32)
            top_v_cpu = top_v.float().cpu().numpy().astype(np.float16)
            del latents, top_v, top_i
            torch.cuda.empty_cache()

            # --- pool_size-bp mean pooling, dropping each layer from the cache as we go ---
            # Stored as bf16 (same dynamic range as fp32, no overflow on late-layer
            # activations). numpy lacks bf16, so the bit pattern is saved as uint16.
            n_windows = -(-seq_len // pool_size)
            layer_means = torch.zeros(len(layer_names), n_windows, hidden, dtype=torch.bfloat16)
            for i, ln in enumerate(layer_names):
                # pop() drops the cache's reference, freeing this layer's GPU activations.
                layer_means[i] = _mean_pool_windows(cache.pop(ln)[0], pool_size).to(torch.bfloat16).cpu()
                torch.cuda.empty_cache()
            # Reinterpret bf16 buffer as uint16 for numpy storage (bit-exact, 2 bytes/value).
            layer_means_uint16 = layer_means.view(torch.uint16).numpy()

            # --- save chunk npz atomically (temp file + rename) ---
            tmp_path = out_path + ".tmp"
            with open(tmp_path, "wb") as fh:
                np.savez_compressed(
                    fh,
                    layer_means_bf16=layer_means_uint16,  # bit-pattern of bf16 in uint16 wrapper
                    layer_means_dtype="bfloat16",  # marker: how to interpret layer_means_bf16
                    layer_names=np.array(layer_names),
                    pool_size=np.int32(pool_size),
                    sae_topk_indices=top_i_cpu,
                    sae_topk_values=top_v_cpu,
                    sae_layer="blocks-26",
                    seq_len=np.int32(seq_len),
                    chunk_start=np.int64(cstart),
                    chunk_end=np.int64(cend),
                    chunk_idx=np.int32(ci),
                    n_chunks=np.int32(n_chunks),
                    chunk_size=np.int32(chunk_size),
                    overlap=np.int32(overlap),
                    contig_id=contig_id,
                    contig_len=np.int64(contig_len),
                    mag_id=mag_id,
                    model_name="evo2_7b_262k",
                )
            os.replace(tmp_path, out_path)
            file_size = os.path.getsize(out_path)
            print(f" saved {file_size/1e6:.1f} MB -> {out_path}")
            saved.append({"path": out_path, "chunk_idx": ci, "bp_range": [cstart, cend], "size_mb": file_size / 1e6})

            # free for next chunk (SAE intermediates already released above)
            del layer_means, layer_means_uint16, input_ids, top_i_cpu, top_v_cpu
            torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()  # always clean up hooks
        cache.clear()
        torch.cuda.empty_cache()

    embeddings_vol.commit()
    print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks saved, committed to volume")
    return {
        "mag_id": mag_id, "contig_id": contig_id,
        "contig_len": contig_len, "n_chunks": n_chunks,
        "chunks": saved,
        "total_size_mb": sum(s["size_mb"] for s in saved),
    }
@app.local_entrypoint()
def full_mag_test(
    mag_dir: str = "/home/ror25cal/MGnify/data/chicken-gut/species_catalogue/MGYG0003076/MGYG000307601/genome",
    mag_id: str = "MGYG000307601",
    contig_id: str = "MGYG000307601_21",  # biggest contig (217 kb)
    chunk_size: int = 64000,
    overlap: int = 0,
):
    """Embed a single contig end-to-end through `embed_full`, then print a summary.

    Reads `{mag_dir}/{mag_id}.fna` and picks the record whose header ID (first
    whitespace-separated token after '>') equals `contig_id`. Each chunk ends
    up on the volume as /embeddings/{mag_id}/{contig_id}_{N}.npz.
    """
    from pathlib import Path

    fasta_text = Path(f"{mag_dir}/{mag_id}.fna").read_text()

    # Minimal FASTA reader: header ID -> concatenated, stripped sequence lines.
    records = {}
    header = None
    pieces = []
    for raw in fasta_text.splitlines():
        if not raw.startswith(">"):
            pieces.append(raw.strip())
            continue
        if header:
            records[header] = "".join(pieces)
        header = raw[1:].split()[0]
        pieces = []
    if header:
        records[header] = "".join(pieces)

    seq = records[contig_id]
    step = chunk_size - overlap
    n_expected = (len(seq) + step - 1) // step
    print(f"[{mag_id}/{contig_id}] {len(seq):,} bp → ~{n_expected} chunks @ {chunk_size:,} bp (overlap={overlap})")

    summary = embed_full.remote(mag_id, contig_id, seq, pool_size=1000, chunk_size=chunk_size, overlap=overlap)

    print("\n=== RESULT ===")
    print(f" contig_len: {summary['contig_len']:,}")
    print(f" n_chunks: {summary['n_chunks']}")
    print(f" total_size: {summary['total_size_mb']:.1f} MB")
    print(" chunks:")
    for entry in summary["chunks"]:
        lo, hi = entry["bp_range"][0], entry["bp_range"][1]
        print(f" [{entry['chunk_idx']}] bp {lo:,}-{hi:,} {entry['size_mb']:.1f} MB → {entry['path']}")
@app.local_entrypoint()
def run_top50_skin_amr(
    csv_path: str = "/home/ror25cal/MGnify/modal/top50_skin_amr.csv",
    skin_dir: str = "/home/ror25cal/MGnify/data/human-skin/species_catalogue",
    chunk_size: int = 64000,
    overlap: int = 0,
    min_contig_len: int = 5000,  # skip tiny fragmented contigs — too short for meaningful gene context
):
    """
    Process the top-50 human-skin MAGs (sorted by AMR-per-Mb density) end-to-end.

    Makes one Modal call per contig, run in parallel with `embed_full.starmap`.
    Container reuse plus the module-level model cache mean the model loads
    about once per worker.
    Output: /embeddings/{mag_id}/{contig_id}_{chunk_idx}.npz on the
    mgnify-embeddings Volume.

    Raises:
        ValueError: if overlap >= chunk_size (chunks would never advance).
    """
    import csv
    from pathlib import Path

    # Fail locally before reading any FASTA or submitting work.
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")

    # Read top-50 MAG list
    with open(csv_path, newline="") as f:
        mag_ids = [row["mag_id"] for row in csv.DictReader(f)]
    print(f"loaded {len(mag_ids)} MAGs from {csv_path}")

    # Build the work list: one tuple per contig (chunking happens inside embed_full)
    work = []  # list of (mag_id, contig_id, sequence, pool_size, chunk_size, overlap)
    total_bp = 0
    for mag_id in mag_ids:
        # Catalogue layout: <skin_dir>/<first 11 chars of MAG id>/<MAG id>/genome/<MAG id>.fna
        fna_path = Path(skin_dir) / mag_id[:11] / mag_id / "genome" / f"{mag_id}.fna"
        if not fna_path.exists():
            print(f" skipping {mag_id} — fna not found at {fna_path}")
            continue
        cur = None
        parts = []
        records = {}
        for line in fna_path.read_text().splitlines():
            if line.startswith(">"):
                if cur:
                    records[cur] = "".join(parts)
                cur = line[1:].split()[0]
                parts = []
            else:
                parts.append(line.strip())
        if cur:
            records[cur] = "".join(parts)
        for cid, seq in records.items():
            if len(seq) < min_contig_len:
                continue
            work.append((mag_id, cid, seq, 1000, chunk_size, overlap))
            total_bp += len(seq)

    # Chunk starts advance by `step` bp (chunk_size - overlap), so a contig needs
    # at most ceil(len / step) chunks. The old estimate divided by chunk_size
    # and undercounted whenever overlap > 0.
    n_chunks_estimate = sum((len(w[2]) + step - 1) // step for w in work)
    # Empirical: ~$0.025 per 64kb chunk on H100 (forward + SAE + npz save)
    cost_estimate = n_chunks_estimate * 0.025
    print(f"\nwork queue: {len(work)} contigs (≥ {min_contig_len:,} bp), {total_bp:,} total bp, ~{n_chunks_estimate} chunks")
    print(f"cost estimate: ~${cost_estimate:.0f}")

    # Submit all calls; Modal's default concurrency handles the parallelism
    print("\nsubmitting to Modal...\n")
    n_done = 0
    n_errors = 0
    n_chunks_done = 0
    total_mb = 0.0
    for result in embed_full.starmap(work, return_exceptions=True):
        n_done += 1
        if isinstance(result, Exception):
            n_errors += 1
            print(f" [{n_done}/{len(work)}] ERROR: {result}")
            continue
        n_chunks_done += result.get("n_chunks", 0)
        total_mb += result.get("total_size_mb", 0.0)
        cached_tag = " (all cached)" if result.get("all_cached") else ""
        print(f" [{n_done}/{len(work)}] {result['mag_id']}/{result['contig_id']}: {result['n_chunks']} chunks, {result['total_size_mb']:.1f} MB{cached_tag}")

    print("\n=== DONE ===")
    print(f" contigs processed: {n_done}/{len(work)}")
    print(f" contigs failed: {n_errors}")
    print(f" total chunks: {n_chunks_done}")
    print(f" total volume size: {total_mb:.1f} MB")
@app.local_entrypoint()
def main(seq: str = "ACGT" * 100, layer: str = TARGET_LAYER):
    """Smoke-test the plain Evo2 `embed` path: one short sequence, one layer, no SAE.

    Prints the shape and mean absolute value of every returned activation array.
    """
    import numpy as np

    print(f"submitting 1 sequence of length {len(seq)} bp to {APP_NAME} @ {layer}")
    by_name = embed.remote([("smoke-test", seq)], layers=[layer])
    rows = (
        (seq_name, layer_name, arr)
        for seq_name, per_layer in by_name.items()
        for layer_name, arr in per_layer.items()
    )
    for seq_name, layer_name, arr in rows:
        mean_abs = np.abs(arr).mean()
        print(f" {seq_name} / {layer_name}: shape={arr.shape} |x|={mean_abs:.3e}")
@app.local_entrypoint()
def crispr_test(
    region_json: str = "/home/ror25cal/MGnify/modal/crispr_test_region.json",
    out_path: str = "/home/ror25cal/MGnify/modal/crispr_test_result.json",
):
    """Run the CRISPR-region sanity test (Evo2 + Goodfire SAE) and save the result locally.

    Reads a region JSON with "sequence", "labels", "mag", "contig",
    "region_start" and "region_end", then runs `embed_and_sae` remotely. It
    prints summary statistics comparing the top-1 SAE activation on labelled
    vs background positions. It then writes the SAE result, the length-aligned
    labels and the remaining region metadata to `out_path` as JSON.
    """
    import json

    import numpy as np

    with open(region_json) as fh:
        region = json.load(fh)
    seq = region["sequence"]
    labels = np.array(region["labels"], dtype=np.int8)
    print(f"region: {region['mag']} / {region['contig']} "
          f"bp {region['region_start']}-{region['region_end']} len={len(seq)}")
    print(f"label distribution: "
          f"{dict(zip(*np.unique(labels, return_counts=True)))} "
          f"(0=bg, 1=CRISPR, 2=DR, 3=spacer, 4=flank)")

    result = embed_and_sae.remote(seq, topk=64)
    print("\ngot back:")
    print(f" seq_len: {result['seq_len']}")
    print(f" d_model (Evo2 hidden): {result['d_model']}")
    print(f" d_sae (Goodfire dict): {result['d_sae']}")
    active = np.array(result["active_features_per_position"])
    print(f" active features/pos: median={int(np.median(active))} max={int(active.max())}")

    # quick sanity: mean top-1 activation in CRISPR-labelled positions vs background
    topk_vals = np.array(result["topk_values"])
    # align lengths (Evo2 may add EOS/pad tokens) before building the masks
    n = min(len(labels), result["seq_len"])
    labels = labels[:n]
    # NOTE(review): labels > 0 also counts 4 (flank) as a "CRISPR" position; confirm that is intended.
    crispr_mask = labels > 0
    bg_mask = labels == 0
    top1 = topk_vals[:n, 0]
    if crispr_mask.any() and bg_mask.any():
        print(f"\n top-1 SAE activation — CRISPR positions: mean={top1[crispr_mask].mean():.3f} N={crispr_mask.sum()}")
        print(f" top-1 SAE activation — background: mean={top1[bg_mask].mean():.3f} N={bg_mask.sum()}")
        print(f" ratio (CRISPR / bg): {top1[crispr_mask].mean() / max(top1[bg_mask].mean(), 1e-9):.2f}x")

    # Save for downstream viz. The context manager guarantees the file is flushed and closed.
    result["labels"] = labels.tolist()
    result["region_meta"] = {k: v for k, v in region.items() if k not in ("sequence", "labels")}
    with open(out_path, "w") as fh:
        json.dump(result, fh)
    print(f"\nsaved result -> {out_path}")
# ============================================================================
# Targeted-region per-token × all-32-layer embed pipeline
# Reads JSONL records from mgnify-targeted-jsonl volume, writes per-region
# npz files (per-token activations across all 32 blocks, bf16-as-uint16) to
# mgnify-embeddings-targeted volume. NO SAE — raw activations only.
# ============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_targeted": embeddings_targeted_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
    max_containers=16,
)
def embed_targeted_jsonl(jsonl_rel_path: str) -> dict:
    """
    Embed every region record in one JSONL file (one MAG × one label).

    Each record is written to /embeddings_targeted/{label}/{mag_id}/{region_id}.npz.
    Records with a falsy "is_positive" go under "MISC" instead of {label}. The
    npz holds per-token outputs of all 32 blocks as a [32, seq_len, hidden]
    bf16 tensor stored as its uint16 bit pattern, plus the record's metadata
    (everything except "sequence") as JSON. Regions whose npz already exists
    are skipped.

    Args:
        jsonl_rel_path: path inside the jsonl volume, e.g. "full/AMR/MGYG000516287.jsonl".

    Returns:
        {"path", "n_done", "n_skipped", "total_mb"}, plus "error": "missing"
        when the JSONL file does not exist.
    """
    import json
    import os

    import numpy as np
    import torch

    src_path = f"/jsonl/{jsonl_rel_path}"
    if not os.path.exists(src_path):
        return {"path": jsonl_rel_path, "error": "missing", "n_done": 0, "n_skipped": 0, "total_mb": 0.0}
    with open(src_path) as fh:
        records = [json.loads(raw) for raw in fh if raw.strip()]
    if not records:
        return {"path": jsonl_rel_path, "n_done": 0, "n_skipped": 0, "total_mb": 0.0}

    evo2, _, _, device, module_dict = _get_models()
    # whole-block outputs "blocks-0" … "blocks-31", NOT blocks-{i}-mlp-l3
    layer_names = [f"blocks-{i}" for i in range(32)]
    captured: dict = {}

    def capture(layer):
        def _on_forward(module, inp, out):
            captured[layer] = (out[0] if isinstance(out, tuple) else out).detach()
        return _on_forward

    hook_handles = [module_dict[layer].register_forward_hook(capture(layer)) for layer in layer_names]

    written = 0
    skipped = 0
    mb_written = 0.0
    try:
        for rec in records:
            folder = rec["label"] if rec["is_positive"] else "MISC"
            mag_id = rec["mag_id"]
            region_id = rec["region_id"]
            target_dir = f"/embeddings_targeted/{folder}/{mag_id}"
            os.makedirs(target_dir, exist_ok=True)
            target = f"{target_dir}/{region_id}.npz"
            if os.path.exists(target):
                skipped += 1
                continue

            captured.clear()
            ids = torch.tensor(evo2.tokenizer.tokenize(rec["sequence"]), dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(ids)
            seq_len = captured["blocks-0"].shape[1]
            hidden = evo2.model.config.hidden_size

            # Copy each layer to CPU as bf16 and release its GPU activation right away.
            per_layer = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
            for idx, layer in enumerate(layer_names):
                per_layer[idx] = captured.pop(layer)[0].to(torch.bfloat16).cpu()
                torch.cuda.empty_cache()
            # numpy has no bf16: keep the raw bit pattern as uint16.
            as_uint16 = per_layer.view(torch.uint16).numpy()

            metadata = {key: val for key, val in rec.items() if key != "sequence"}
            np.savez_compressed(
                target,
                per_token_layer_activations_bf16=as_uint16,
                per_token_layer_activations_dtype="bfloat16",
                layer_names=np.array(layer_names),
                seq_len=np.int32(seq_len),
                hidden_size=np.int32(hidden),
                model_name="evo2_7b_262k",
                metadata_json=np.array(json.dumps(metadata)),
            )
            size_bytes = os.path.getsize(target)
            mb_written += size_bytes / 1e6
            written += 1
            print(f" [{folder}/{mag_id}/{region_id}] seq_len={seq_len} saved {size_bytes/1e6:.1f} MB")
            del per_layer, as_uint16, ids
            torch.cuda.empty_cache()
    finally:
        for handle in hook_handles:
            handle.remove()
        captured.clear()
        torch.cuda.empty_cache()

    embeddings_targeted_vol.commit()
    return {"path": jsonl_rel_path, "n_done": written, "n_skipped": skipped, "total_mb": mb_written}
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_targeted_full() -> dict:
    """
    CPU-only orchestrator for the targeted-region embedding run.

    Walks the JSONL volume under /jsonl/full, fans every per-MAG-per-label
    file out to embed_targeted_jsonl via .map(), and aggregates the results.
    Running on Modal means the run survives local-process exit (use --detach).

    Counted as errors: calls that raised, and results that report an in-band
    "error" (e.g. "missing"), which embed_targeted_jsonl returns instead of raising.

    Returns:
        {"jsonls", "regions_done", "regions_skipped", "errors", "total_mb"}.
    """
    import os

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(root, fname), "/jsonl")
        for root, _, files in os.walk("/jsonl/full")
        for fname in files
        if fname.endswith(".jsonl")
    )
    n_files = len(jsonl_paths)
    print(f"[orchestrator] found {n_files} JSONL files to process")

    n_total_done = 0
    n_total_skipped = 0
    total_mb = 0.0
    errors = 0
    for i, r in enumerate(embed_targeted_jsonl.map(jsonl_paths, return_exceptions=True), start=1):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i}/{n_files}] ERROR: {r}")
        else:
            if r.get("error"):
                errors += 1
                print(f" [{i}/{n_files}] ERROR: {r.get('path')}: {r['error']}")
            n_total_done += r.get("n_done", 0)
            n_total_skipped += r.get("n_skipped", 0)
            total_mb += r.get("total_mb", 0.0)
        # Report progress even when this item failed, so the final line always prints.
        if i % 25 == 0 or i == n_files:
            print(f" [{i}/{n_files}] running totals: done={n_total_done} skipped={n_total_skipped} errors={errors} {total_mb/1024:.1f} GB")
    return {
        "jsonls": n_files,
        "regions_done": n_total_done,
        "regions_skipped": n_total_skipped,
        "errors": errors,
        "total_mb": total_mb,
    }
@app.local_entrypoint()
def run_targeted_full():
    """
    Kick off the full targeted-region embedding run on Modal and print its summary.

    All fan-out happens server-side in orchestrate_targeted_full. Launch with
    `--detach` so the run continues after this local process exits:
    modal run --detach modal/evo2_inference.py::run_targeted_full
    """
    print("[local] submitting orchestrator to Modal — fan-out happens server-side")
    summary = orchestrate_targeted_full.remote()
    print("\n=== DONE ===")
    print(f" JSONL files: {summary['jsonls']}")
    print(f" regions saved: {summary['regions_done']}")
    print(f" regions skipped: {summary['regions_skipped']} (already on volume)")
    print(f" errors: {summary['errors']}")
    gigabytes = summary["total_mb"] / 1024
    print(f" total volume: {gigabytes:.1f} GB")
# ============================================================================
# Layer-26-only slicer — CPU job that reads all-32-layer npz from
# mgnify-embeddings-targeted, extracts the layer-26 slice, and writes a
# ~30x smaller npz to mgnify-embeddings-l26 for cheap teammate sharing.
# ============================================================================
embeddings_l26_vol = modal.Volume.from_name("mgnify-embeddings-l26", create_if_missing=True)


@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_targeted_vol,
        "/out": embeddings_l26_vol,
    },
    timeout=3600,
    max_containers=16,
)
def extract_l26_batch(rel_paths: list[str]) -> dict:
    """Slice the blocks-26 activations out of each all-32-layer npz and write a
    compact per-region npz to the l26 volume.

    Args:
        rel_paths: paths like ['AMR/MGYG.../MGYG..._00123_AMR.npz', ...],
            relative to the volume root. Each output keeps the same relative path.

    Behaviour:
      - The layer is located by name in the input's `layer_names`, not by a
        hard-coded index.
      - Outputs are written to a temporary file and renamed into place, so an
        interrupted write is never skipped as "already done" on a re-run.
      - Missing inputs and per-file failures are logged and counted as errors.
      - Outputs that already exist are skipped.

    Returns:
        {"n_done", "n_skipped", "n_errors", "total_mb_in", "total_mb_out"},
        where the MB totals cover newly sliced files only.
    """
    import contextlib
    import os

    import numpy as np

    SOURCE_LAYER = "blocks-26"

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_in = 0.0
    total_mb_out = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if not os.path.exists(in_path):
            print(f" ERROR on {rel}: input file missing")
            n_errors += 1
            continue
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        tmp_path = out_path + ".tmp"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                layer_idx = [str(name) for name in d["layer_names"]].index(SOURCE_LAYER)
                stack = d["per_token_layer_activations_bf16"]  # uint16 [n_layers, seq_len, hidden]
                layer_slice = stack[layer_idx].copy()  # uint16 [seq_len, hidden]
                passthrough = {
                    "layer_names": d["layer_names"],
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # A file object stops savez from appending ".npz" to the temp name.
            with open(tmp_path, "wb") as fh:
                np.savez_compressed(
                    fh,
                    layer26_activations_bf16=layer_slice,  # bit-pattern of bf16 stored as uint16
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(layer_idx),
                    source_layer_name=SOURCE_LAYER,
                    **passthrough,
                )
            os.replace(tmp_path, out_path)
            total_mb_in += os.path.getsize(in_path) / 1e6
            total_mb_out += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:  # per-file boundary: log, count, keep going
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            with contextlib.suppress(FileNotFoundError):
                os.remove(tmp_path)
    embeddings_l26_vol.commit()
    return {
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_targeted_vol},
    timeout=86400,
)
def orchestrate_l26_extract(batch_size: int = 50) -> dict:
    """List every committed all-32-layer npz on /in, batch them by `batch_size`,
    and fan the batches out to extract_l26_batch.

    A batch whose call raised as a whole (e.g. a timeout) is counted in
    "n_failed_batches". Its files are not included in the other counters.

    Returns:
        {"regions_total", "n_done", "n_skipped", "n_errors", "total_mb_in",
         "total_mb_out", "n_failed_batches"}.

    Raises:
        ValueError: if batch_size <= 0.
    """
    import os

    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be > 0")

    paths = sorted(
        os.path.relpath(os.path.join(root, fname), "/in")
        for root, _, files in os.walk("/in")
        for fname in files
        if fname.endswith(".npz")
    )
    print(f"[orchestrator] found {len(paths)} all-32-layer npz files to slice")

    # batch into chunks for fewer container starts and one commit per batch
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    n_batches = len(batches)
    print(f"[orchestrator] dispatching {n_batches} batches of up to {batch_size}")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    n_failed_batches = 0
    total_mb_in = 0.0
    total_mb_out = 0.0
    for i, r in enumerate(extract_l26_batch.map(batches, return_exceptions=True), start=1):
        if isinstance(r, Exception):
            n_failed_batches += 1
            print(f" [{i}/{n_batches}] BATCH ERROR: {r}")
        else:
            n_done += r.get("n_done", 0)
            n_skipped += r.get("n_skipped", 0)
            n_errors += r.get("n_errors", 0)
            total_mb_in += r.get("total_mb_in", 0.0)
            total_mb_out += r.get("total_mb_out", 0.0)
        # Report progress even after a failed batch, so the final line always prints.
        if i % 5 == 0 or i == n_batches:
            print(f" [{i}/{n_batches}] running totals: done={n_done} skipped={n_skipped} errors={n_errors} failed_batches={n_failed_batches} in={total_mb_in/1024:.1f} GB → out={total_mb_out/1024:.2f} GB")
    return {
        "regions_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
        "n_failed_batches": n_failed_batches,
    }
@app.local_entrypoint()
def run_l26_extract():
    """
    Slice layer 26 out of every committed all-32-layer npz on
    mgnify-embeddings-targeted and write the results to mgnify-embeddings-l26.
    modal run --detach modal/evo2_inference.py::run_l26_extract
    """
    print("[local] submitting layer-26 slicer orchestrator to Modal")
    result = orchestrate_l26_extract.remote()
    mb_in = result["total_mb_in"]
    mb_out = result["total_mb_out"]
    print("\n=== DONE ===")
    print(f" source npz scanned: {result['regions_total']}")
    print(f" newly sliced: {result['n_done']}")
    print(f" skipped (already done): {result['n_skipped']}")
    print(f" errors: {result['n_errors']}")
    # Only present when the orchestrator reports whole-batch failures.
    if "n_failed_batches" in result:
        print(f" failed batches: {result['n_failed_batches']}")
    print(f" input bytes read: {mb_in/1024:.2f} GB")
    print(f" output bytes written: {mb_out/1024:.2f} GB")
    # Guard only against division by zero. The old max(out, 1) floor
    # understated the ratio whenever less than 1 MB was written.
    if mb_out > 0:
        print(f" compression ratio: {mb_in / mb_out:.1f}x smaller")
    else:
        print(" compression ratio: n/a (nothing written this run)")
# ============================================================================
# Modal-side packager: tar/zip the layer-26 volume into a single file on the
# same volume, then we download just that one big file (much faster than
# per-file `modal volume get`, which serializes per-file requests).
# ============================================================================
@app.function(
    image=modal.Image.debian_slim(),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    timeout=3600,
)
def pack_l26_archive(out_name: str = "embeddings_l26.zip") -> dict:
    """Pack every .npz on /vol into one zip archive at /vol/{out_name}.

    Uses store mode, because the npz files are already compressed. The archive
    is built under a temporary name and only renamed over /vol/{out_name} once
    complete:
      - a previous archive stays intact until the new one is finished;
      - a failed run leaves no partial archive behind.
    Directories and files are visited in sorted order, so member order is
    reproducible. The archive itself is never packed into itself.

    Returns:
        {"archive_path", "n_files", "bytes_raw", "bytes_archive", "elapsed_s"}.
    """
    import contextlib
    import os
    import time
    import zipfile

    out_path = f"/vol/{out_name}"
    tmp_path = f"{out_path}.partial"
    n = 0
    bytes_in = 0
    t0 = time.time()
    try:
        with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_STORED) as zf:
            for root, dirs, files in os.walk("/vol"):
                dirs.sort()  # in-place sort steers os.walk's traversal order
                for fname in sorted(files):
                    full = os.path.join(root, fname)
                    if not fname.endswith(".npz") or full in (out_path, tmp_path):
                        continue
                    zf.write(full, os.path.relpath(full, "/vol"))
                    bytes_in += os.path.getsize(full)
                    n += 1
                    if n % 50 == 0:
                        print(f" packed {n} files, {bytes_in/1e9:.2f} GB raw...")
    except BaseException:
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, out_path)
    out_size = os.path.getsize(out_path)
    embeddings_l26_vol.commit()
    return {
        "archive_path": out_path,
        "n_files": n,
        "bytes_raw": bytes_in,
        "bytes_archive": out_size,
        "elapsed_s": time.time() - t0,
    }
@app.local_entrypoint()
def pack_and_report():
    """Build a single zip of the l26 volume on Modal and print how to fetch it.

    The zip is created server-side by ``pack_l26_archive`` (default name
    embeddings_l26.zip), so locally only one ``modal volume get`` is needed.
    """
    print("[local] packing layer-26 npz files into single zip on Modal volume...")
    stats = pack_l26_archive.remote()
    gb = 1e9
    for line in (
        "\n=== PACKED ===",
        f" archive path on volume: {stats['archive_path']}",
        f" files packed: {stats['n_files']}",
        f" raw size: {stats['bytes_raw']/gb:.2f} GB",
        f" archive size: {stats['bytes_archive']/gb:.2f} GB",
        f" elapsed (server-side): {stats['elapsed_s']:.0f} s",
        "\nDownload locally with:",
        " modal volume get mgnify-embeddings-l26 /embeddings_l26.zip /home/ror25cal/MGnify/data/embeddings_l26.zip",
    ):
        print(line)
# ============================================================================
# Modal-side HF upload: stream the layer-26 volume directly to a Hugging Face
# Dataset repo. Uses HF token from the existing 'huggingface' Modal Secret.
# Avoids the local download entirely — Modal egress → HF ingress is fast.
# ============================================================================
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
)
def upload_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True) -> dict:
    """Upload the .npz files on the layer-26 volume to a HF Dataset repo.

    Creates ``{user}/{repo_name}`` (dataset, ``exist_ok``) for the
    authenticated user, then hands /vol to ``HfApi.upload_large_folder``.

    Raises:
        RuntimeError: no HF token could be found in the environment.

    Returns a dict with ``repo_id``, ``repo_url``, ``n_files``,
    ``bytes_total`` and ``elapsed_s``.
    """
    import os
    import time
    from huggingface_hub import HfApi

    def _resolve_token():
        found = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        if found:
            return found
        # The legacy 'huggingface' Modal secret was created with key and value
        # swapped: the env var *name* is the token and its value is "HF_TOKEN".
        return next(
            (name for name, value in os.environ.items()
             if name.startswith("hf_") and value == "HF_TOKEN"),
            None,
        )

    token = _resolve_token()
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    user = api.whoami().get("name")
    repo_id = f"{user}/{repo_name}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: https://huggingface.co/datasets/{repo_id}")

    # Pre-count for the log line / return value.
    sizes = [
        os.path.getsize(os.path.join(dirpath, fname))
        for dirpath, _subdirs, fnames in os.walk("/vol")
        for fname in fnames
        if fname.endswith(".npz")
    ]
    n_files, bytes_total = len(sizes), sum(sizes)
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    started = time.time()
    # NOTE(review): huggingface_hub matches allow_patterns with fnmatch, so
    # "**/*.npz" likely requires a "/" and skips .npz files sitting directly
    # in /vol (which the count above includes) — confirm if that matters.
    api.upload_large_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path="/vol",
        allow_patterns=["**/*.npz"],
        print_report=True,
    )
    elapsed = time.time() - started
    return {
        "repo_id": repo_id,
        "repo_url": f"https://huggingface.co/datasets/{repo_id}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
    }
@app.local_entrypoint()
def push_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True):
    """Upload the layer-26 volume straight to a HF Dataset repo (no local download).

        modal run modal/evo2_inference.py::push_l26_to_hf
        modal run modal/evo2_inference.py::push_l26_to_hf --repo-name foo --no-private
    """
    print(f"[local] launching HF push (repo={repo_name}, private={private})")
    res = upload_l26_to_hf.remote(repo_name=repo_name, private=private)
    size = res["bytes_total"]
    secs = res["elapsed_s"]
    print("\n=== UPLOADED ===")
    print(f" repo: {res['repo_url']}")
    print(f" files: {res['n_files']}")
    print(f" size: {size/1e9:.2f} GB")
    # Throughput guarded against sub-second uploads.
    print(f" elapsed: {secs:.0f} s ({size/1e6/max(secs,1):.1f} MB/s)")
# ============================================================================
# Lean targeted pipeline: 5 layers, no compression, no SAE.
# Replaces embed_targeted_jsonl for cost-sensitive reruns. Saves layers
# 14, 20, 24, 26, 28 as bf16-as-uint16 .npz with NO gzip — np.savez_compressed
# was the dominant cost in the prior run (gzip on bf16 noise = 30s/region of
# CPU while H100 idled). Uncompressed is ~3-5x faster and only ~30% bigger.
# ============================================================================
# Output volume for the lean 5-layer pipelines: embed_targeted_lean writes
# /embeddings_lean/{label}/{mag}/..., embed_vfdb_lean writes under vfdb/.
embeddings_lean_vol = modal.Volume.from_name("mgnify-embeddings-lean", create_if_missing=True)
# Evo 2 block indices captured by the lean pipelines (hooked as f"blocks-{i}").
# 26 must stay in this list: slice_l26_from_lean_batch looks it up by value.
LEAN_LAYERS: list[int] = [14, 20, 24, 26, 28]
def _get_evo2_only():
    """Return ``(evo2, device, module_dict)`` for Evo 2 7B-262k, without the SAE.

    The first call in a container loads the model and stores the triple in
    the module globals ``_CACHED_EVO2_LEAN``, ``_CACHED_DEVICE_LEAN`` and
    ``_CACHED_MODULE_DICT_LEAN``. Later calls return those cached objects.
    ``module_dict`` maps hyphen-joined child paths (e.g. ``"blocks-26"``) to
    submodules so callers can attach forward hooks by name.
    """
    cache_ns = globals()
    # Reuse the cache only when all three entries exist and the model is set.
    if (
        cache_ns.get("_CACHED_EVO2_LEAN") is not None
        and "_CACHED_DEVICE_LEAN" in cache_ns
        and "_CACHED_MODULE_DICT_LEAN" in cache_ns
    ):
        return (
            cache_ns["_CACHED_EVO2_LEAN"],
            cache_ns["_CACHED_DEVICE_LEAN"],
            cache_ns["_CACHED_MODULE_DICT_LEAN"],
        )

    from evo2 import Evo2
    print("[container] loading Evo2 7B-262k (no SAE)")
    model = Evo2("evo2_7b_262k")
    dev = next(model.model.parameters()).device

    by_path = {}

    def _index_children(parent, path_prefix=""):
        # Depth-first: register each child, then descend with a "-" separator.
        for child_name, child in parent.named_children():
            key = path_prefix + child_name
            by_path[key] = child
            _index_children(child, key + "-")

    _index_children(model.model)

    cache_ns["_CACHED_EVO2_LEAN"] = model
    cache_ns["_CACHED_DEVICE_LEAN"] = dev
    cache_ns["_CACHED_MODULE_DICT_LEAN"] = by_path
    return model, dev, by_path
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_targeted_lean(jsonl_rel_paths) -> dict:
    """
    Process all records in a *batch* of JSONL files. Single volume.commit() at end.
    Accepts either str (single JSONL, back-compat) or list[str] (batch).
    Saves /embeddings_lean/{label}/{mag}/{region}.npz per region.

    Args:
        jsonl_rel_paths: path(s) relative to the /jsonl volume mount. Missing
            files are counted in ``n_missing_jsonl`` and skipped.

    Behaviour:
        * Records with ``is_positive`` false go under the "MISC" label folder.
        * A region whose output .npz already exists is skipped, so reruns
          resume where they stopped.
        * Each output is written to a ``.partial`` sibling and then atomically
          renamed into place. A container killed mid-write (timeout,
          preemption) therefore never leaves a truncated .npz that a rerun
          would mistake for a finished region.
        * The volume is committed in ``finally``, so regions completed
          before a per-record exception are persisted before it propagates.

    Returns:
        dict with n_jsonls, n_missing_jsonl, n_done, n_skipped, total_mb,
        model_load_s, commit_s, per_region_s and mean_per_region_s.
    """
    import json
    import os
    import time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start

    # Module names follow _get_evo2_only's hyphen-joined child paths.
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    cache: dict = {}

    def make_hook(name):
        # Keep the detached block output; for tuple outputs keep element 0.
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]
    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    t_commit = None
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue
            for rec in records:
                label_folder = rec["label"] if rec["is_positive"] else "MISC"
                mag_id = rec["mag_id"]
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/{label_folder}/{mag_id}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                # "File exists" means "done"; this is only sound because the
                # write below is atomic.
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = (
                    torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long)
                    .unsqueeze(0)
                    .to(device)
                )
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                    torch.cuda.empty_cache()
                # numpy has no bfloat16: store the raw bits reinterpreted as uint16.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                tmp_path = out_path + ".partial"
                # Passing a file object stops np.savez from appending ".npz".
                with open(tmp_path, "wb") as fh:
                    np.savez(  # NO compression — gzip was the bottleneck
                        fh,
                        per_token_layer_activations_bf16=stack_uint16,
                        per_token_layer_activations_dtype="bfloat16",
                        layer_names=np.array(layer_names),
                        layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                        seq_len=np.int32(seq_len),
                        hidden_size=np.int32(hidden),
                        model_name="evo2_7b_262k",
                        metadata_json=np.array(json.dumps(meta)),
                    )
                os.replace(tmp_path, out_path)
                sz = os.path.getsize(out_path)
                total_mb += sz / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()
        # Commit even when a record raised, so finished regions survive for resume.
        t_commit_start = time.time()
        embeddings_lean_vol.commit()
        t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_lean_full(batch_size: int = 50) -> dict:
    """Fan every JSONL under /jsonl/full out to ``embed_targeted_lean`` workers.

    Relative paths are sorted and chunked into groups of ``batch_size``, so
    each worker call pays the volume-commit cost once per chunk instead of
    once per file. Progress is printed every 5 batches and after the last.

    Returns aggregate counts plus mean per-region and per-commit seconds.
    """
    import os

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(dirpath, fname), "/jsonl")
        for dirpath, _subdirs, fnames in os.walk("/jsonl/full")
        for fname in fnames
        if fname.endswith(".jsonl")
    )
    batches = [
        jsonl_paths[start:start + batch_size]
        for start in range(0, len(jsonl_paths), batch_size)
    ]
    n_batches = len(batches)
    print(f"[orchestrator-lean] {len(jsonl_paths)} JSONLs → {n_batches} batches of up to {batch_size}")

    def _mean(samples):
        return sum(samples) / max(len(samples), 1)

    done = skipped = errors = 0
    mb = 0.0
    region_secs: list[float] = []
    commit_secs: list[float] = []
    results = embed_targeted_lean.map(batches, return_exceptions=True)
    for idx, res in enumerate(results, start=1):
        if isinstance(res, Exception):
            errors += 1
            print(f" [{idx}/{n_batches}] BATCH ERROR: {res}")
            continue
        done += res.get("n_done", 0)
        skipped += res.get("n_skipped", 0)
        mb += res.get("total_mb", 0.0)
        region_secs.extend(res.get("per_region_s") or [])
        if res.get("commit_s") is not None:
            commit_secs.append(res["commit_s"])
        if idx % 5 == 0 or idx == n_batches:
            print(f" [{idx}/{n_batches}] done={done} skipped={skipped} errors={errors} "
                  f"{mb/1024:.1f} GB mean_region_s={_mean(region_secs):.2f} mean_commit_s={_mean(commit_secs):.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": n_batches,
        "regions_done": done,
        "regions_skipped": skipped,
        "errors": errors,
        "total_mb": mb,
        "mean_region_s": _mean(region_secs),
        "mean_commit_s": _mean(commit_secs),
    }
@app.local_entrypoint()
def pilot_lean_batched(n_jsonls: int = 100, batch_size: int = 50):
    """Run the batched lean pipeline on the first N JSONLs and project full-run cost.

    JSONLs are listed from the local targeted_jsonl/full/{AMR,MISC,VIRULENCE,STRESS}
    folders. ``embed_targeted_lean`` reads them from the /jsonl volume under the
    same relative paths; this entrypoint does not upload them. The JSONLs are
    fanned out in batches of ``batch_size``. The summary reports per-region,
    per-commit and cold-start timings plus a projection for the full run.

    Failed batches are listed rather than silently dropped. When no region
    was freshly embedded (all already on the volume, or all batches failed),
    the timing summary degrades gracefully instead of crashing on ``min([])``.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/full"
    paths = []
    for label in ["AMR", "MISC", "VIRULENCE", "STRESS"]:
        d = os.path.join(base, label)
        if os.path.isdir(d):
            for fname in sorted(os.listdir(d)):
                if fname.endswith(".jsonl"):
                    paths.append(f"full/{label}/{fname}")
    paths = paths[:n_jsonls]
    if not paths:
        print(f"[pilot-batched] no JSONLs found under {base}; nothing to run")
        return
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[pilot-batched] {len(paths)} JSONLs in {len(batches)} batches of up to {batch_size}")
    t0 = time.time()
    results = list(embed_targeted_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    ok = [r for r in results if not isinstance(r, Exception)]
    failed = [(i, r) for i, r in enumerate(results) if isinstance(r, Exception)]
    for i, err in failed:
        print(f" batch {i+1}/{len(batches)} FAILED: {err}")

    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    all_region_times = [t for r in ok for t in r.get("per_region_s") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]
    mean_region = sum(all_region_times) / max(len(all_region_times), 1)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)

    def _spread(xs):
        # min/max of an empty list raises; report "no samples" instead.
        return f"min {min(xs):.2f}, max {max(xs):.2f}" if xs else "no samples"

    print("\n=== PILOT-BATCHED SUMMARY ===")
    print(f" records: {n_done}")
    print(f" total volume size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batches ({len(failed)} failed)")
    print(f" per-region inference: {mean_region:.2f} s avg ({_spread(all_region_times)})")
    print(f" per-batch commit: {mean_commit:.2f} s avg ({_spread(commit_times)})")
    print(f" per-call model load: {mean_load:.1f} s avg (cold start)")

    if not all_region_times:
        print("\n PROJECTION skipped: no freshly embedded regions to time "
              "(all already on the volume, or every batch failed)")
        return

    # Project full run with batching
    full_records = 5483
    full_n_jsonls = 2416
    n_workers = 16
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    cold_start_s = mean_load * n_workers  # one cold start per worker
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / n_workers
    cost = (total_compute_s / 3600) * 4.50

    def _pct(part):
        return part / total_compute_s * 100 if total_compute_s else 0.0

    print(f"\n PROJECTION (5483 records, {n_workers}× H100, batch_size={batch_size}):")
    print(f" inference compute: {inference_compute_s:6.0f} s ({_pct(inference_compute_s):.0f}%)")
    print(f" commit compute: {commit_compute_s:6.0f} s ({_pct(commit_compute_s):.0f}%)")
    print(f" cold-start total: {cold_start_s:6.0f} s ({_pct(cold_start_s):.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {total_mb/max(n_done,1) * full_records / 1024:.1f} GB")
@app.local_entrypoint()
def run_lean_full(batch_size: int = 50):
    """Submit the full lean embedding run and print its final tally.

        modal run --detach modal/evo2_inference.py::run_lean_full
        modal run --detach modal/evo2_inference.py::run_lean_full --batch-size 50
    """
    print(f"[local] submitting orchestrator (batch_size={batch_size}) to Modal")
    tally = orchestrate_lean_full.remote(batch_size=batch_size)
    lines = [
        "\n=== DONE ===",
        f" JSONL files: {tally['jsonls']} in {tally['batches']} batches",
        f" regions saved: {tally['regions_done']}",
        f" regions skipped: {tally['regions_skipped']} (already on volume)",
        f" errors: {tally['errors']}",
        f" total volume size: {tally['total_mb']/1024:.1f} GB",
        f" mean per-region: {tally['mean_region_s']:.2f} s",
        f" mean per-batch commit:{tally['mean_commit_s']:.2f} s",
    ]
    print("\n".join(lines))
@app.local_entrypoint()
def pilot_lean(mag_id: str = "MGYG000516287"):
    """Embed one MAG's JSONLs (every label) serially and project full-run cost.

    Each label's JSONL is sent to ``embed_targeted_lean`` as its own remote
    call. A call that raises is reported as SKIPPED. Timings from successful
    calls feed a 16×H100 projection for the full 5483-region run.
    """
    import time

    labels = ["AMR", "VIRULENCE", "STRESS", "MISC"]
    print(f"[pilot] running on MAG {mag_id} across labels (skipping any missing)")
    results = []
    started = time.time()
    for rel in [f"full/{lab}/{mag_id}.jsonl" for lab in labels]:
        try:
            res = embed_targeted_lean.remote(rel)
            print(f" {rel}: done={res['n_done']} skipped={res['n_skipped']} "
                  f"model_load={res.get('model_load_s', 0):.1f}s "
                  f"mean_region={res.get('mean_per_region_s') or 0:.2f}s "
                  f"size={res['total_mb']:.1f} MB")
            results.append(res)
        except Exception as exc:
            print(f" {rel}: SKIPPED ({exc})")
    wall = time.time() - started

    region_secs = [s for res in results for s in res.get("per_region_s") or []]
    n_total = sum(res["n_done"] for res in results)
    bytes_total = sum(res["total_mb"] for res in results)
    per_record_mb = bytes_total / max(n_total, 1)
    print("\n=== PILOT SUMMARY ===")
    print(f" records processed: {n_total}")
    print(f" total volume size: {bytes_total:.1f} MB ({per_record_mb:.1f} MB/record avg)")
    print(f" wall clock (single worker): {wall:.0f} s")
    if not region_secs:
        return

    mean_t = sum(region_secs) / len(region_secs)
    print(f" per-region time: {mean_t:.2f} s avg "
          f"(min {min(region_secs):.2f}, max {max(region_secs):.2f})")
    # Projection: serial-equivalent compute spread over n_workers, plus one
    # ~60 s cold start per worker.
    full_records = 5483
    full_compute_s = full_records * mean_t
    n_workers = 16
    cold_start_s = 60
    wall_proj = cold_start_s + full_compute_s / n_workers
    h100_hours = (full_compute_s / 3600 + (cold_start_s * n_workers) / 3600)
    cost_proj = h100_hours * 4.50
    print(f"\n PROJECTION (5483 regions, {n_workers}× H100):")
    print(f" estimated wall clock: {wall_proj/60:.1f} min")
    print(f" estimated cost: ${cost_proj:.2f}")
    print(f" estimated total size: {bytes_total/max(n_total,1) * full_records / 1024:.1f} GB")
# ============================================================================
# Slice layer 26 out of the 5-layer lean volume → new volume → push to HF.
# Self-contained on Modal: parallel slicer workers + single uploader.
# ============================================================================
# Destination for the layer-26-only copies sliced from the lean volume
# (written by slice_l26_from_lean_batch, uploaded by upload_l26_lean_to_hf).
embeddings_l26_lean_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=True)
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_lean_vol,
    },
    timeout=3600,
    max_containers=16,
)
def slice_l26_from_lean_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 out of each lean npz in ``rel_paths``. Single commit per batch.

    For each relative path, reads ``/in/{rel}`` and picks the layer whose
    ``layer_indices`` entry is 26. It writes ``/out/{rel}`` with that
    layer's activations plus the passthrough scalars/metadata.

    Existing outputs are skipped, so reruns resume. Outputs are written to a
    ``.partial`` temp file and atomically renamed, so an interrupted write
    can never leave a truncated .npz that a rerun would skip. Per-file
    failures (missing input, no layer 26, corrupt file) are counted in
    ``n_errors`` and do not abort the batch.

    Returns dict with n_done, n_skipped, n_errors and total_mb (output MB).
    """
    import os
    from contextlib import suppress
    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            n_errors += 1
            continue
        tmp_path = out_path + ".partial"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                # assumes [n_layers, seq_len, hidden] uint16 (bf16 bits), per the lean writer
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                pos = layer_indices.index(26)  # ValueError if layer 26 absent
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # A file object keeps np.savez from appending ".npz" to the temp name.
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    layer26_activations_bf16=l26,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    **passthrough,
                )
            os.replace(tmp_path, out_path)
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # Best-effort removal of a half-written temp file; it is never
            # mistaken for output (not .npz), so failing to remove it is harmless.
            with suppress(OSError):
                os.remove(tmp_path)
    embeddings_l26_lean_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_lean_slice(batch_size: int = 50) -> dict:
    """List lean npz files, batch them, and fan out to slicer workers.

    Walks the lean volume (mounted at /in) for *.npz, sorts the relative
    paths, splits them into ``batch_size`` chunks and maps
    ``slice_l26_from_lean_batch`` over the chunks.

    A batch whose worker raised (crash, timeout, OOM) is counted in
    ``n_batch_errors``, and all of its files are added to ``n_errors``.
    Previously such batches were only printed, so the final report could
    claim errors=0. Mapping a result back to its batch relies on ``.map``
    yielding outputs in input order, which is Modal's default.

    Returns files_total, n_done, n_skipped, n_errors, n_batch_errors and
    total_mb_out.
    """
    import os

    paths = []
    for root, _, files in os.walk("/in"):
        for fname in files:
            if fname.endswith(".npz"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/in"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-slice] {len(paths)} lean npz → {len(batches)} batches")
    n_done = 0
    n_skipped = 0
    n_errors = 0
    n_batch_errors = 0
    total_mb_out = 0.0
    for i, r in enumerate(slice_l26_from_lean_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            n_batch_errors += 1
            n_errors += len(batches[i])
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {
        "files_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "n_batch_errors": n_batch_errors,
        "total_mb_out": total_mb_out,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_lean_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=21600,
)
def upload_l26_lean_to_hf(repo_name: str = "mgnify-evo2-l26-full", private: bool = True) -> dict:
    """Upload the layer-26-lean volume to a HF Dataset repo with upload_large_folder.

    The repo ``{user}/{repo_name}`` is created if needed (``exist_ok``).

    Raises:
        RuntimeError: no HF token found in the environment.

    Returns a dict with repo_id, repo_url, n_files, bytes_total, elapsed_s.
    """
    import os
    import time
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # The legacy 'huggingface' Modal secret has key/value swapped: the env
        # var *name* is the token and its value is the literal "HF_TOKEN".
        swapped = [k for k, v in os.environ.items() if k.startswith("hf_") and v == "HF_TOKEN"]
        if swapped:
            token = swapped[0]
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    user = api.whoami().get("name")
    repo_id = f"{user}/{repo_name}"
    repo_url = f"https://huggingface.co/datasets/{repo_id}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: {repo_url}")

    npz_sizes = [
        os.path.getsize(os.path.join(dirpath, fname))
        for dirpath, _subdirs, fnames in os.walk("/vol")
        for fname in fnames
        if fname.endswith(".npz")
    ]
    n_files = len(npz_sizes)
    bytes_total = sum(npz_sizes)
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    began = time.time()
    # NOTE(review): "**/*.npz" is matched with fnmatch and likely needs a "/",
    # so top-level .npz files in /vol may be counted above but not uploaded.
    api.upload_large_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path="/vol",
        allow_patterns=["**/*.npz"],
        print_report=True,
    )
    return {
        "repo_id": repo_id,
        "repo_url": repo_url,
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": time.time() - began,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=120,
)
def set_hf_dataset_visibility(repo_name: str, private: bool) -> dict:
    """Set the private/public flag of ``{user}/{repo_name}`` (a HF dataset).

    Uses ``update_repo_settings`` when the installed huggingface_hub has it,
    otherwise falls back to ``update_repo_visibility``. Returns the repo id,
    the visibility reported back by ``repo_info`` and the dataset URL.

    Raises:
        RuntimeError: no HF token found in the environment.
    """
    import os
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Legacy secret with swapped key/value: the env var name is the token.
        token = next(
            (k for k, v in os.environ.items() if k.startswith("hf_") and v == "HF_TOKEN"),
            token,
        )
    if not token:
        raise RuntimeError("HF_TOKEN missing")

    api = HfApi(token=token)
    repo_id = f"{api.whoami()['name']}/{repo_name}"
    setter = (
        api.update_repo_settings
        if hasattr(api, "update_repo_settings")
        else api.update_repo_visibility
    )
    setter(repo_id=repo_id, repo_type="dataset", private=private)
    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    return {
        "repo_id": repo_id,
        "private": getattr(info, "private", None),
        "url": f"https://huggingface.co/datasets/{repo_id}",
    }
@app.local_entrypoint()
def make_l26_dataset_public(repo_name: str = "mgnify-evo2-l26-full"):
    """Make the layer-26 HF dataset public and echo the resulting state."""
    info = set_hf_dataset_visibility.remote(repo_name=repo_name, private=False)
    for label, key in (("repo", "repo_id"), ("private", "private"), ("url", "url")):
        print(f" {label}: {info[key]}")
@app.local_entrypoint()
def push_l26_lean(repo_name: str = "mgnify-evo2-l26-full", batch_size: int = 50, private: bool = True):
    """Two-step job: slice layer 26 into its own volume, then upload it to HF.

        modal run --detach modal/evo2_inference.py::push_l26_lean
        modal run --detach modal/evo2_inference.py::push_l26_lean --no-private
    """
    print("[1/2] slicing layer-26 from lean volume on Modal…")
    sliced = orchestrate_l26_lean_slice.remote(batch_size=batch_size)
    for line in (
        f" scanned {sliced['files_total']} lean npz",
        f" newly sliced: {sliced['n_done']}",
        f" skipped: {sliced['n_skipped']}",
        f" errors: {sliced['n_errors']}",
        f" l26 volume size: {sliced['total_mb_out']/1024:.2f} GB",
    ):
        print(line)

    print("\n[2/2] pushing layer-26 volume to HF Datasets…")
    up = upload_l26_lean_to_hf.remote(repo_name=repo_name, private=private)
    size, secs = up["bytes_total"], up["elapsed_s"]
    for line in (
        "\n=== UPLOADED ===",
        f" repo: {up['repo_url']}",
        f" files: {up['n_files']}",
        f" size: {size/1e9:.2f} GB",
        f" elapsed: {secs:.0f} s ({size/1e6/max(secs,1):.1f} MB/s)",
    ):
        print(line)
# =============================================================
# VFDB virulence pipeline (mirror of embed_targeted_lean)
# =============================================================
# Same Evo2 forward pass + 5-layer extraction as the MGnify lean pipeline.
# Differences:
# - input JSONLs live at /jsonl/vfdb/<species_slug>.jsonl
# - outputs at /embeddings_lean/vfdb/{label_folder}/{group}/{region_id}.npz
# label_folder: "VIRULENCE" (positive) or "negative" (no MGnify-MISC collision)
# group: species_slug (positives) or mag_id (negatives)
# - record schema: positives have no native mag_id; the upload step of
#   scripts/extract_vfdb_negatives.py injects species_slug. The code groups by
#   rec["mag_id"], falling back to rec["species"], then "UNKNOWN".
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_vfdb_lean(jsonl_rel_paths) -> dict:
    """VFDB-targeted version of embed_targeted_lean. Outputs under /embeddings_lean/vfdb/.

    Args:
        jsonl_rel_paths: a single path (str) or list of paths relative to the
            /jsonl volume mount. Missing files are counted in
            ``n_missing_jsonl`` and skipped.

    Each region is saved to
    /embeddings_lean/vfdb/{VIRULENCE|negative}/{group}/{region_id}.npz, where
    group is rec["mag_id"], else rec["species"], else "UNKNOWN".

    Existing outputs are skipped, so reruns resume. Writes go to a
    ``.partial`` temp file that is atomically renamed into place, so an
    interrupted write never leaves a truncated .npz that a rerun would
    skip. The volume is committed in ``finally`` so regions completed before
    a per-record exception persist.

    Returns counts/timings, including ``seq_lens`` of embedded regions.
    """
    import json
    import os
    import time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start

    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    cache: dict = {}

    def make_hook(name):
        # Keep the detached block output; for tuple outputs keep element 0.
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]
    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    seq_lens: list[int] = []
    t_commit = None
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue
            for rec in records:
                # VFDB-aware path layout ("negative" avoids colliding with MGnify's MISC)
                label_folder = "VIRULENCE" if rec["is_positive"] else "negative"
                group = rec.get("mag_id") or rec.get("species") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/vfdb/{label_folder}/{group}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                # "File exists" means "done"; sound only because writes are atomic.
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = (
                    torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long)
                    .unsqueeze(0)
                    .to(device)
                )
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                    torch.cuda.empty_cache()
                # numpy has no bfloat16: store the raw bits reinterpreted as uint16.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                tmp_path = out_path + ".partial"
                # A file object keeps np.savez from appending ".npz" to the temp name.
                with open(tmp_path, "wb") as fh:
                    np.savez(
                        fh,
                        per_token_layer_activations_bf16=stack_uint16,
                        per_token_layer_activations_dtype="bfloat16",
                        layer_names=np.array(layer_names),
                        layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                        seq_len=np.int32(seq_len),
                        hidden_size=np.int32(hidden),
                        model_name="evo2_7b_262k",
                        metadata_json=np.array(json.dumps(meta)),
                    )
                os.replace(tmp_path, out_path)
                sz = os.path.getsize(out_path)
                total_mb += sz / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                seq_lens.append(seq_len)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()
        # Commit even when a record raised, so finished regions survive for resume.
        t_commit_start = time.time()
        embeddings_lean_vol.commit()
        t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "seq_lens": seq_lens,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
@app.local_entrypoint()
def pilot_vfdb_lean(target_records: int = 250, batch_size: int = 2):
    """Run a small VFDB pilot to measure per-region and commit time, then project full cost.

    Species JSONLs are picked greedily, smallest first, until ``target_records``
    is reached. Once at least one file is chosen, files with more than
    2× ``target_records`` records are skipped. The chosen files are uploaded to
    the jsonl volume under vfdb/ and embedded via ``embed_vfdb_lean``.

        modal run modal/evo2_inference.py::pilot_vfdb_lean
        modal run modal/evo2_inference.py::pilot_vfdb_lean --target-records 500 --batch-size 4

    Raises:
        FileNotFoundError: the local VFDB directory is missing or has no .jsonl files.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"VFDB JSONLs not found at {base}; run "
                                "extract_vfdb_negatives.py + the modal-prep step first.")
    species_files = []
    for fname in sorted(os.listdir(base)):
        if not fname.endswith(".jsonl"):
            continue
        path = os.path.join(base, fname)
        with open(path) as f:
            n = sum(1 for line in f if line.strip())
        species_files.append((n, fname))
    if not species_files:
        # Previously an opaque IndexError at species_files[0] below.
        raise FileNotFoundError(f"No .jsonl files in {base}; run "
                                "extract_vfdb_negatives.py + the modal-prep step first.")
    species_files.sort()  # smallest first

    # Greedy pick of small-to-medium files until we hit the target.
    # Skip files that would by themselves blow the budget by >2×.
    chosen = []
    total = 0
    for n, fname in species_files:
        if total >= target_records:
            break
        if n > target_records * 2 and chosen:
            continue
        chosen.append((n, fname))
        total += n
    if not chosen:  # only when target_records <= 0
        chosen = [species_files[0]]
    print(f"[pilot-vfdb] selected {len(chosen)} species:")
    total_records = 0
    for n, fname in chosen:
        print(f" {fname:40s} {n} records")
        total_records += n
    print(f" total pilot records: {total_records}")

    # Upload selected JSONLs to jsonl_vol under vfdb/
    print(f"\n[pilot-vfdb] uploading {len(chosen)} JSONLs to volume mgnify-targeted-jsonl ...")
    upload_t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for _, fname in chosen:
            batch.put_file(os.path.join(base, fname), f"vfdb/{fname}")
    print(f" uploaded in {time.time()-upload_t0:.0f} s")

    rel_paths = [f"vfdb/{fname}" for _, fname in chosen]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[pilot-vfdb] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of up to {batch_size}")
    t0 = time.time()
    results = list(embed_vfdb_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    ok = [r for r in results if not isinstance(r, Exception)]
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f" batch {i+1}/{len(batches)} FAILED: {r}")
    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    region_times = [t for r in ok for t in r.get("per_region_s") or []]
    seq_lens = [s for r in ok for s in r.get("seq_lens") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]
    if not region_times:
        print("ERROR: no records processed in pilot")
        return
    # region_times non-empty implies at least one successful batch, hence
    # commit_times and seq_lens are non-empty too.
    mean_region = sum(region_times) / len(region_times)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)
    mean_seqlen = sum(seq_lens) / len(seq_lens)
    print("\n=== VFDB PILOT RESULTS ===")
    print(f" records processed: {n_done}")
    print(f" total output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.2f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batch(es)")
    print(f" per-region inference: {mean_region:.2f} s avg "
          f"(min {min(region_times):.2f}, max {max(region_times):.2f}, p95 {sorted(region_times)[int(len(region_times)*0.95)]:.2f})")
    print(f" per-region seq len: {mean_seqlen:.0f} bp avg")
    print(f" per-batch commit: {mean_commit:.2f} s avg "
          f"(min {min(commit_times):.2f}, max {max(commit_times):.2f})")
    print(f" per-call model load: {mean_load:.1f} s avg")

    # Cost projection for full VFDB run
    full_records = 14695
    n_workers = 16
    h100_rate = 4.50  # $/hr — same as projection in pilot_lean_batched
    full_n_jsonls = 34
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    active_workers = min(n_workers, full_batches)
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    cold_start_s = mean_load * active_workers
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / active_workers
    cost = (total_compute_s / 3600) * h100_rate
    output_size_gb = (total_mb / max(n_done, 1)) * full_records / 1024

    def _pct(part):
        return part / total_compute_s * 100 if total_compute_s else 0.0

    print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, "
          f"H100=${h100_rate:.2f}/hr):")
    print(f" inference compute: {inference_compute_s:7.0f} s ({_pct(inference_compute_s):.0f}%)")
    print(f" commit compute: {commit_compute_s:7.0f} s ({_pct(commit_compute_s):.0f}%)")
    print(f" cold-start total: {cold_start_s:7.0f} s ({_pct(cold_start_s):.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:5.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {output_size_gb:.1f} GB")
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_vfdb_lean(batch_size: int = 2) -> dict:
    """Fan the VFDB JSONLs on the jsonl volume out to ``embed_vfdb_lean``.

    Every ``*.jsonl`` below /jsonl/vfdb/ is collected as a path relative to
    /jsonl, sorted, and chunked into batches of at most ``batch_size`` paths;
    each batch becomes one ``embed_vfdb_lean`` call. The VFDB counterpart of
    orchestrate_lean_full.

    Returns:
        dict with the JSONL and batch counts, summed regions done/skipped,
        the number of batches that raised, summed output MB, and the mean
        per-region / per-commit seconds (0.0 when no samples were reported).
    """
    import os

    def _avg(samples):
        # Empty-safe mean: an empty sample list yields 0.0.
        return sum(samples) / max(len(samples), 1)

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(dirpath, name), "/jsonl")
        for dirpath, _, names in os.walk("/jsonl/vfdb")
        for name in names
        if name.endswith(".jsonl")
    )
    batches = [
        jsonl_paths[start:start + batch_size]
        for start in range(0, len(jsonl_paths), batch_size)
    ]
    n_batches = len(batches)
    print(f"[orchestrator-vfdb] {len(jsonl_paths)} JSONLs → {n_batches} batches of up to {batch_size}")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_time_samples: list[float] = []
    commit_time_samples: list[float] = []
    results = embed_vfdb_lean.map(batches, return_exceptions=True)
    for idx, result in enumerate(results, start=1):
        if isinstance(result, Exception):
            errors += 1
            print(f" [{idx}/{n_batches}] BATCH ERROR: {result}")
            continue
        n_done += result.get("n_done", 0)
        n_skipped += result.get("n_skipped", 0)
        total_mb += result.get("total_mb", 0.0)
        if result.get("per_region_s"):
            region_time_samples.extend(result["per_region_s"])
        if result.get("commit_s") is not None:
            commit_time_samples.append(result["commit_s"])
        print(f" [{idx}/{n_batches}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.1f} GB mean_region_s={_avg(region_time_samples):.2f} "
              f"mean_commit_s={_avg(commit_time_samples):.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": n_batches,
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": _avg(region_time_samples),
        "mean_commit_s": _avg(commit_time_samples),
    }
@app.local_entrypoint()
def run_vfdb_lean(batch_size: int = 2):
    """Push every local VFDB JSONL to the jsonl volume, then run the lean embed.

    Usage:
        modal run --detach modal/evo2_inference.py::run_vfdb_lean
        modal run --detach modal/evo2_inference.py::run_vfdb_lean --batch-size 4
    """
    import os
    import time

    src_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"VFDB JSONLs not found at {src_dir}")
    names = sorted(name for name in os.listdir(src_dir) if name.endswith(".jsonl"))

    print(f"[run-vfdb] uploading {len(names)} JSONLs to mgnify-targeted-jsonl volume ...")
    started = time.time()
    # force=True overwrites any JSONL already on the volume under the same name.
    with jsonl_vol.batch_upload(force=True) as upload:
        for name in names:
            upload.put_file(os.path.join(src_dir, name), f"vfdb/{name}")
    print(f" uploaded in {time.time()-started:.0f} s")

    print(f"\n[run-vfdb] launching orchestrator (batch_size={batch_size})")
    summary = orchestrate_vfdb_lean.remote(batch_size=batch_size)

    print("\n=== DONE ===")
    print(f" JSONL files: {summary['jsonls']} in {summary['batches']} batches")
    print(f" regions saved: {summary['regions_done']}")
    print(f" regions skipped: {summary['regions_skipped']} (already on volume)")
    print(f" errors: {summary['errors']}")
    print(f" total volume size: {summary['total_mb']/1024:.1f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
    print(f" mean per-batch commit: {summary['mean_commit_s']:.2f} s")
# =============================================================
# VFDB layer-26 slice + HF push
# =============================================================
# Same logic as the MGnify l26 slicer, but reads only /in/vfdb/* and writes to a
# separate output volume so it doesn't mix with the MGnify l26 dataset.
# Dedicated output volume for the VFDB layer-26 slices (created on first lookup if absent).
embeddings_l26_vfdb_vol = modal.Volume.from_name("mgnify-embeddings-l26-vfdb", create_if_missing=True)
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_vfdb_vol,
    },
    timeout=3600,
    max_containers=16,
)
def slice_l26_vfdb_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 from each VFDB lean npz. Same schema as slice_l26_from_lean_batch.

    For each relative path, reads /in/<rel> (a lean npz with a stacked
    ``per_token_layer_activations_bf16`` array indexed by ``layer_indices``),
    takes the slab whose layer index is 26, and writes /out/<rel> containing
    that slab plus the passthrough fields seq_len, hidden_size, model_name
    and metadata_json. Outputs that already exist are skipped, so reruns
    resume where a previous run stopped.

    Each output is first written to a temporary sibling file and moved into
    place only after ``np.savez`` has finished. Because existing outputs are
    skipped, a truncated file left at the final path by a failed write would
    otherwise be treated as "done" forever.

    Args:
        rel_paths: paths relative to both /in and /out.

    Returns:
        dict with n_done, n_skipped, n_errors (missing inputs plus per-file
        failures) and total_mb (size of newly written outputs, 1e6-byte MB).
    """
    import os
    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            print(f" MISSING input {in_path}")
            n_errors += 1
            continue
        tmp_path = out_path + ".partial"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                # Raises ValueError (counted as an error below) if layer 26 wasn't saved.
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                # Read inside the `with` so the arrays are loaded before the npz closes.
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # Pass an open file so numpy doesn't append ".npz" to the temp name.
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    layer26_activations_bf16=l26,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    **passthrough,
                )
            os.replace(tmp_path, out_path)
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # Never leave a half-written file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    embeddings_l26_vfdb_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_vfdb_slice(batch_size: int = 100) -> dict:
    """Walk only /in/vfdb/, batch, fan out to slice_l26_vfdb_batch.

    Every ``*.npz`` below /in/vfdb/ is collected relative to /in, sorted, and
    chunked into batches of ``batch_size`` paths. Progress is printed every
    5 batches and after the last one.

    Returns:
        dict with files_total, n_done, n_skipped, n_errors, total_mb_out and
        n_batch_errors. When a whole batch call raises, the per-file outcome
        of that batch is unknown, so every file in it is added to n_errors.
        This prevents such failures from silently disappearing from the totals.
    """
    import os

    paths = []
    for root, _, files in os.walk("/in/vfdb"):
        for fname in files:
            if fname.endswith(".npz"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/in"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-vfdb] {len(paths)} lean npz → {len(batches)} batches of {batch_size}")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    n_batch_errors = 0
    total_mb_out = 0.0
    # Function.map yields results in input order by default, so result i
    # belongs to batches[i].
    for i, r in enumerate(slice_l26_vfdb_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            n_batch_errors += 1
            n_errors += len(batches[i])
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {
        "files_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_out": total_mb_out,
        "n_batch_errors": n_batch_errors,
    }
def _find_hf_token(environ):
    """Locate a Hugging Face token in ``environ`` (a str -> str mapping).

    The "huggingface" Modal Secret has key/value swapped: the env var NAME is
    the literal token and its value is the string "HF_TOKEN". To handle that,
    both the name and the value of every variable whose name starts with
    "hf_" or "HF_" are checked for a token-shaped string ("hf_" prefix,
    longer than 30 chars), and the first match wins. Otherwise the function
    falls back to ``environ.get("HF_TOKEN")``, which may be None.
    """
    for key, value in environ.items():
        if not key.startswith(("hf_", "HF_")):
            continue
        if key.startswith("hf_") and len(key) > 30:
            return key
        if value.startswith("hf_") and len(value) > 30:
            return value
    return environ.get("HF_TOKEN")


@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=8,
    volumes={"/vol": embeddings_l26_vfdb_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=86400,  # 24h — first try timed out at 6h with single-worker upload
)
def upload_l26_vfdb_to_hf(repo_name: str = "mgnify-evo2-l26-vfdb-virulence", private: bool = False,
                          num_workers: int = 8) -> dict:
    """Push the VFDB layer-26 slice to HF Datasets. Resumes if partial.

    Creates ``<user>/<repo_name>`` as a dataset repo if needed, then uploads
    the whole /vol volume with ``HfApi.upload_large_folder``.

    Raises:
        RuntimeError: if no Hugging Face token can be found in the environment.
            Without this check, ``login(token=None)`` would fall back to an
            interactive prompt, which cannot work in a detached container.

    Returns:
        dict with repo_url, n_files / bytes_total (counting only ``*.npz``
        under /vol), elapsed_s for the upload call, and the private flag.
    """
    import os
    import time
    from huggingface_hub import HfApi, login

    token = _find_hf_token(os.environ)
    if not token:
        raise RuntimeError(
            "No Hugging Face token found in the environment; check the 'huggingface' Modal Secret."
        )
    login(token=token)
    api = HfApi()
    user = api.whoami()["name"]
    full_repo = f"{user}/{repo_name}"
    api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf-push-vfdb] uploading /vol → {full_repo} (private={private}, workers={num_workers})")
    t0 = time.time()
    api.upload_large_folder(
        folder_path="/vol",
        repo_id=full_repo,
        repo_type="dataset",
        num_workers=num_workers,
    )
    elapsed = time.time() - t0

    n_files = 0
    bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for fname in files:
            if fname.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, fname))
    return {
        "repo_url": f"https://huggingface.co/datasets/{full_repo}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
        "private": private,
    }
@app.function(
    image=modal.Image.debian_slim(),
    cpu=1,
    volumes={
        "/embeddings_lean": embeddings_lean_vol,
        "/l26_vfdb": embeddings_l26_vfdb_vol,
    },
    timeout=3600,
)
def wipe_vfdb_outputs() -> dict:
    """Empty /embeddings_lean/vfdb/ and /l26_vfdb/ ahead of a clean re-run.

    Introduced after positive region_ids switched from source_accession to
    vfg_id; leftover files under the old names would otherwise leak into
    the layer-26 slice and the HF push. Each directory's contents are
    deleted, but the directory itself is kept. The matching volume is
    committed either way.

    Returns:
        {label: {"existed": False}} or
        {label: {"existed": True, "files_removed": int, "elapsed_s": float}}.
    """
    import os
    import shutil
    import time

    targets = (
        ("/embeddings_lean/vfdb", "lean_vfdb", embeddings_lean_vol),
        ("/l26_vfdb", "l26_vfdb", embeddings_l26_vfdb_vol),
    )
    summary = {}
    for root, label, volume in targets:
        started = time.time()
        if not os.path.exists(root):
            summary[label] = {"existed": False}
        else:
            file_count = sum(len(names) for _, _, names in os.walk(root))
            # Clear the contents but keep the directory (/l26_vfdb is a mountpoint).
            for entry in os.listdir(root):
                target = os.path.join(root, entry)
                if os.path.isdir(target):
                    shutil.rmtree(target)
                else:
                    os.remove(target)
            summary[label] = {
                "existed": True,
                "files_removed": file_count,
                "elapsed_s": time.time() - started,
            }
        volume.commit()
    return summary
@app.local_entrypoint()
def run_vfdb_full(repo_name: str = "mgnify-evo2-l26-vfdb-virulence",
                  private: bool = False,
                  embed_batch_size: int = 2,
                  slice_batch_size: int = 100,
                  wipe_first: bool = True):
    """End-to-end VFDB run: (optional) wipe, upload JSONLs, embed, slice l26, push to HF.

    Pass --no-wipe-first to keep existing outputs (e.g. when re-running after
    an interrupted push).

        modal run --detach modal/evo2_inference.py::run_vfdb_full
    """
    import os
    import time

    # Step 0: optionally clear stale outputs from earlier runs.
    if wipe_first:
        print("[0/4] wiping stale VFDB outputs (vfdb/ and l26_vfdb/) ...")
        wiped = wipe_vfdb_outputs.remote()
        for label, info in wiped.items():
            if not info.get("existed"):
                print(f" {label}: nothing to remove")
                continue
            print(f" {label}: removed {info['files_removed']} files in {info['elapsed_s']:.0f} s")

    # Step 1: upload the current local VFDB JSONLs.
    src_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"VFDB JSONLs not found at {src_dir}")
    names = sorted(fname for fname in os.listdir(src_dir) if fname.endswith(".jsonl"))
    print(f"\n[1/4] uploading {len(names)} JSONLs to mgnify-targeted-jsonl ...")
    started = time.time()
    with jsonl_vol.batch_upload(force=True) as upload:
        for fname in names:
            upload.put_file(os.path.join(src_dir, fname), f"vfdb/{fname}")
    print(f" uploaded in {time.time()-started:.0f} s")

    # Step 2: lean multi-layer embedding.
    print(f"\n[2/4] running embed_vfdb_lean (batch_size={embed_batch_size}) ...")
    embed = orchestrate_vfdb_lean.remote(batch_size=embed_batch_size)
    print(f" regions saved: {embed['regions_done']} (skipped {embed['regions_skipped']}, errors {embed['errors']})")
    print(f" total volume size: {embed['total_mb']/1024:.1f} GB")

    # Step 3: extract layer 26 into its own volume.
    print(f"\n[3/4] slicing layer 26 (batch_size={slice_batch_size}) ...")
    sliced = orchestrate_l26_vfdb_slice.remote(batch_size=slice_batch_size)
    print(f" files: {sliced['files_total']}, done: {sliced['n_done']}, skipped: {sliced['n_skipped']}, errors: {sliced['n_errors']}")
    print(f" l26 vol size: {sliced['total_mb_out']/1024:.2f} GB")

    # Step 4: publish the layer-26 volume to Hugging Face.
    print(f"\n[4/4] pushing to HF as {repo_name} (private={private}) ...")
    pushed = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private)
    size_bytes = pushed['bytes_total']
    seconds = pushed['elapsed_s']
    print(f"\n=== ALL DONE ===")
    print(f" HF repo: {pushed['repo_url']}")
    print(f" files: {pushed['n_files']}")
    print(f" size: {size_bytes/1e9:.2f} GB")
    print(f" upload: {seconds:.0f} s ({size_bytes/1e6/max(seconds,1):.1f} MB/s)")
# =============================================================
# Qualitative-sample pipeline (~20 records per "true" secondary label)
# =============================================================
# Used by Thread C in THREADS.md. ~860 records across ~61 categories,
# all positives (no matched negatives). Outputs at /embeddings_lean/qual/.
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
    max_containers=16,
)
def embed_qual_lean(jsonl_rel_paths) -> dict:
    """Mirror of embed_vfdb_lean for the qualitative sample. Outputs at
    /embeddings_lean/qual/<label_group>/<category_slug>/<region_id>.npz.

    Args:
        jsonl_rel_paths: a single path or a list of paths relative to /jsonl.
            Each non-blank JSONL line is a record with "region_id" and
            "sequence". "label_group" and "mag_id" (repurposed as the
            category slug) choose the output subdirectory and default to
            "UNKNOWN".

    For each record, one forward pass runs with forward hooks on the
    ``blocks-{i}`` modules for i in LEAN_LAYERS. The hooked outputs are
    stacked into a (n_layers, seq_len, hidden) bfloat16 tensor and saved,
    reinterpreted as uint16, under ``per_token_layer_activations_bf16``.
    Records whose output already exists are skipped.

    Returns:
        dict of counters and timings. ``per_region_s`` (the list of
        per-record seconds) is included so the result has the same shape as
        embed_syngenome_lean / embed_syngenome_l26 and can be aggregated by
        the same orchestrator code.
    """
    import json, os, time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]
    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start

    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    cache: dict = {}

    def make_hook(name):
        # Factory binds `name` per hook; tuple outputs keep only element 0.
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]
    n_done = n_skipped = n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            for rec in records:
                group = rec.get("label_group") or "UNKNOWN"
                slug = rec.get("mag_id") or "UNKNOWN"  # mag_id field repurposed for slug
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/qual/{group}/{slug}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    # Free each GPU activation as soon as it is copied to the CPU stack.
                    del cache[ln]
                    torch.cuda.empty_cache()
                # Reinterpret bf16 bits as uint16 so numpy can store them losslessly.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                np.savez(
                    out_path,
                    per_token_layer_activations_bf16=stack_uint16,
                    per_token_layer_activations_dtype="bfloat16",
                    layer_names=np.array(layer_names),
                    layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(meta)),
                )
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
@app.function(
    image=modal.Image.debian_slim(),
    cpu=1,
    volumes={"/embeddings_lean": embeddings_lean_vol},
    timeout=1800,
)
def rename_qual_to_small() -> dict:
    """Move /embeddings_lean/qual to /embeddings_lean/small on the lean volume.

    Does nothing and returns {"renamed": False, "reason": ...} when the source
    is missing or the destination already exists. Otherwise it moves the
    tree, commits the volume, and reports the file count and elapsed time.
    """
    import os
    import shutil
    import time

    src = "/embeddings_lean/qual"
    dst = "/embeddings_lean/small"

    reason = None
    if not os.path.exists(src):
        reason = f"{src} does not exist"
    elif os.path.exists(dst):
        reason = f"{dst} already exists"
    if reason is not None:
        return {"renamed": False, "reason": reason}

    moved = sum(len(names) for _, _, names in os.walk(src))
    started = time.time()
    shutil.move(src, dst)
    embeddings_lean_vol.commit()
    return {
        "renamed": True,
        "src": src,
        "dst": dst,
        "files_moved": moved,
        "elapsed_s": time.time() - started,
    }
@app.local_entrypoint()
def rename_qual_small():
    """Local entry point: move qual/ to small/ on the lean volume and print the result dict."""
    print(rename_qual_to_small.remote())
# =============================================================
# SynGenome AMR validation pipeline (mirror of embed_vfdb_lean)
# =============================================================
# Inputs: /jsonl/syngenome/<drug_class>.jsonl (built by scripts/sample_syngenome_amr.py)
# Outputs: /embeddings_lean/syngenome/AMR/<drug_class>/<region_id>.npz
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_syngenome_lean(jsonl_rel_paths) -> dict:
    """SynGenome version of embed_vfdb_lean. All records are AMR positives.

    Args:
        jsonl_rel_paths: a single path or a list of paths relative to /jsonl.
            Each non-blank line is a JSON record with "region_id" and
            "sequence". "mag_id" (repurposed as the drug-class slug) selects
            the output subdirectory and defaults to "UNKNOWN".

    For each record, one forward pass runs with forward hooks on the
    ``blocks-{i}`` modules for i in LEAN_LAYERS. The hooked outputs are
    stacked into a (n_layers, seq_len, hidden) bfloat16 tensor and saved,
    with its bits reinterpreted as uint16, to
    /embeddings_lean/syngenome/AMR/<mag_id>/<region_id>.npz. Records whose
    output already exists are skipped. The lean volume is committed once at
    the end.

    Returns:
        dict with counters (n_jsonls, n_missing_jsonl, n_done, n_skipped),
        total_mb written, model_load_s, commit_s, the per_region_s list and
        mean_per_region_s (None when nothing was embedded).
    """
    import json, os, time
    import numpy as np
    import torch
    # Accept a bare path as well as a list of paths.
    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]
    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    # Filled by the forward hooks during each forward pass, keyed by layer name.
    cache: dict = {}
    def make_hook(name):
        # Factory so each hook captures its own `name`.
        def hook(module, inp, out):
            # If the block returns a tuple, keep only its first element.
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook
    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]
    n_done = n_skipped = n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            for rec in records:
                # mag_id carries the drug-class slug for SynGenome records.
                drug_class = rec.get("mag_id") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/syngenome/AMR/{drug_class}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                # Resumable: an existing output means this record was already embedded.
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                # Batch of one: (1, n_tokens) on the model's device.
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    # Release each GPU activation right after copying it to the CPU.
                    del cache[ln]
                    torch.cuda.empty_cache()
                # numpy has no bfloat16; store the raw bits as uint16.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                np.savez(
                    out_path,
                    per_token_layer_activations_bf16=stack_uint16,
                    per_token_layer_activations_dtype="bfloat16",
                    layer_names=np.array(layer_names),
                    layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(meta)),
                )
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        # Always detach hooks so the cached model is left clean.
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()
    # Single commit per call; its duration is reported for cost projection.
    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_lean(batch_size: int = 2) -> dict:
    """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_lean.

    Every ``*.jsonl`` below /jsonl/syngenome/ is collected relative to
    /jsonl, sorted, and chunked into batches of at most ``batch_size``.

    Returns:
        dict with jsonls, batches, regions_done, regions_skipped, errors
        (batches that raised), total_mb, mean_region_s and mean_commit_s.
        The last two are 0.0 when no samples were reported. mean_commit_s
        mirrors orchestrate_vfdb_lean; previously the commit samples were
        collected here but never used.
    """
    import os

    paths = []
    for root, _, files in os.walk("/jsonl/syngenome"):
        for fname in files:
            if fname.endswith(".jsonl"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/jsonl"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-syngenome] {len(paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_times: list[float] = []
    commit_times: list[float] = []
    for i, r in enumerate(embed_syngenome_lean.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_times.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_times.append(r["commit_s"])
        mean_t = sum(region_times) / max(len(region_times), 1)
        mean_c = sum(commit_times) / max(len(commit_times), 1)
        print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f} mean_commit_s={mean_c:.2f}")
    return {"jsonls": len(paths), "batches": len(batches),
            "regions_done": n_done, "regions_skipped": n_skipped, "errors": errors,
            "total_mb": total_mb,
            "mean_region_s": sum(region_times) / max(len(region_times), 1),
            "mean_commit_s": sum(commit_times) / max(len(commit_times), 1)}
# Single-layer (layer 26 only) variant — for SynGenome validation use
# (no SAE work planned; only probes consume this; layer 26 is the informative one)
# Separate output volume for SynGenome layer-26 embeddings (created on first lookup if absent).
embeddings_l26_syngenome_vol = modal.Volume.from_name("mgnify-embeddings-l26-syngenome", create_if_missing=True)
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_l26_syn": embeddings_l26_syngenome_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_syngenome_l26(jsonl_rel_paths) -> dict:
    """Layer-26-only embed for SynGenome records (AMR positives and negatives).

    Same forward pass as the lean variant, but only blocks-26 is hooked and
    saved (~5× less storage).

    Args:
        jsonl_rel_paths: a single path or a list of paths relative to /jsonl.
            Each non-blank line is a JSON record with "region_id" and
            "sequence". "is_positive" (default True) selects the top-level
            folder (AMR/ or negative/), and "mag_id" (default "UNKNOWN")
            selects the second-level folder.

    Outputs go to /embeddings_l26_syn/<AMR|negative>/<mag_id>/<region_id>.npz
    with key ``layer26_activations_bf16`` (bf16 bits stored as uint16).
    Existing outputs are skipped. The volume is committed once at the end.

    Returns:
        dict with counters, total_mb, model_load_s, commit_s, the
        per_region_s list and mean_per_region_s (None if nothing was embedded).
    """
    import json, os, time
    import numpy as np
    import torch
    # Accept a bare path as well as a list of paths.
    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]
    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start
    layer_name = "blocks-26"
    # Filled by the forward hook on blocks-26 during each forward pass.
    cache: dict = {}
    def hook(module, inp, out):
        # If the block returns a tuple, keep only its first element.
        cache[layer_name] = (out[0] if isinstance(out, tuple) else out).detach()
    handle = module_dict[layer_name].register_forward_hook(hook)
    n_done = n_skipped = n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            for rec in records:
                # Top-level folder by label: positives → /AMR/, negatives → /negative/.
                # Drug class / functional-class slug is the second-level grouping (mag_id field).
                top = "AMR" if rec.get("is_positive", True) else "negative"
                group = rec.get("mag_id") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_l26_syn/{top}/{group}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                # Resumable: an existing output means this record was already embedded.
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                # Batch of one: (1, n_tokens) on the model's device.
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_name].shape[1]
                hidden = evo2.model.config.hidden_size
                l26 = cache[layer_name][0].to(torch.bfloat16).cpu()
                # Release the GPU activation as soon as it has been copied to the CPU.
                del cache[layer_name]
                torch.cuda.empty_cache()
                # numpy has no bfloat16; store the raw bits as uint16.
                l26_uint16 = l26.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                np.savez(
                    out_path,
                    layer26_activations_bf16=l26_uint16,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(meta)),
                )
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del l26, l26_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        # Always detach the hook so the cached model is left clean.
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()
    t_commit_start = time.time()
    embeddings_l26_syngenome_vol.commit()
    t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_l26(batch_size: int = 2) -> dict:
    """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_l26.

    Every ``*.jsonl`` below /jsonl/syngenome/ is collected relative to
    /jsonl, sorted, and chunked into batches of at most ``batch_size``.

    Returns:
        dict with jsonls, batches, regions_done, regions_skipped, errors
        (batches that raised), total_mb, mean_region_s and mean_commit_s.
        The last two are 0.0 when no samples were reported. Commit samples
        used to be collected and then discarded; they are now reported like
        in orchestrate_vfdb_lean.
    """
    import os

    paths = []
    for root, _, files in os.walk("/jsonl/syngenome"):
        for fname in files:
            if fname.endswith(".jsonl"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/jsonl"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-syngenome-l26] {len(paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_times: list[float] = []
    commit_times: list[float] = []
    for i, r in enumerate(embed_syngenome_l26.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_times.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_times.append(r["commit_s"])
        mean_t = sum(region_times) / max(len(region_times), 1)
        mean_c = sum(commit_times) / max(len(commit_times), 1)
        print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.2f} GB mean_region_s={mean_t:.2f} mean_commit_s={mean_c:.2f}")
    return {"jsonls": len(paths), "batches": len(batches),
            "regions_done": n_done, "regions_skipped": n_skipped, "errors": errors,
            "total_mb": total_mb,
            "mean_region_s": sum(region_times) / max(len(region_times), 1),
            "mean_commit_s": sum(commit_times) / max(len(commit_times), 1)}
@app.local_entrypoint()
def run_syngenome_l26(batch_size: int = 2):
    """Upload the local SynGenome AMR JSONLs, then run the layer-26-only embed.

        modal run modal/evo2_inference.py::run_syngenome_l26
    """
    import os
    import time

    src_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"SynGenome JSONLs not found at {src_dir}")
    names = sorted(fname for fname in os.listdir(src_dir) if fname.endswith(".jsonl"))

    print(f"[run-syngenome-l26] uploading {len(names)} JSONLs ...")
    started = time.time()
    with jsonl_vol.batch_upload(force=True) as upload:
        for fname in names:
            upload.put_file(os.path.join(src_dir, fname), f"syngenome/{fname}")
    print(f" uploaded in {time.time()-started:.0f} s")

    print(f"\n[run-syngenome-l26] orchestrator (batch_size={batch_size})")
    summary = orchestrate_syngenome_l26.remote(batch_size=batch_size)
    print(f"\n=== DONE ===")
    print(f" regions saved: {summary['regions_done']} (skipped {summary['regions_skipped']}, errors {summary['errors']})")
    print(f" total size: {summary['total_mb']/1024:.2f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_l26_neg(batch_size: int = 2) -> dict:
    """Fan the SynGenome negative JSONLs out to embed_syngenome_l26.

    Every ``*.jsonl`` below /jsonl/syngenome_neg/ is collected relative to
    /jsonl, sorted, and chunked into batches of at most ``batch_size``.

    Returns:
        dict with jsonls, batches, regions_done, regions_skipped, errors
        (batches that raised), total_mb and mean_region_s (0.0 when no
        timing samples were reported).
    """
    import os

    def _avg(samples):
        # Empty-safe mean: an empty sample list yields 0.0.
        return sum(samples) / max(len(samples), 1)

    paths = sorted(
        os.path.relpath(os.path.join(dirpath, name), "/jsonl")
        for dirpath, _, names in os.walk("/jsonl/syngenome_neg")
        for name in names
        if name.endswith(".jsonl")
    )
    batches = [paths[start:start + batch_size] for start in range(0, len(paths), batch_size)]
    n_batches = len(batches)
    print(f"[orchestrator-syngenome-l26-neg] {len(paths)} JSONLs → {n_batches} batches")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_times: list[float] = []
    results = embed_syngenome_l26.map(batches, return_exceptions=True)
    for idx, result in enumerate(results, start=1):
        if isinstance(result, Exception):
            errors += 1
            print(f" [{idx}/{n_batches}] BATCH ERROR: {result}")
            continue
        n_done += result.get("n_done", 0)
        n_skipped += result.get("n_skipped", 0)
        total_mb += result.get("total_mb", 0.0)
        if result.get("per_region_s"):
            region_times.extend(result["per_region_s"])
        print(f" [{idx}/{n_batches}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.2f} GB mean_region_s={_avg(region_times):.2f}")

    return {
        "jsonls": len(paths),
        "batches": n_batches,
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": _avg(region_times),
    }
@app.local_entrypoint()
def run_syngenome_l26_neg(batch_size: int = 2):
    """Upload the local SynGenome NEGATIVE JSONLs, then run the layer-26-only embed.

    Run this only after run_syngenome_l26 has finished, to avoid GPU contention.
    """
    import os
    import time

    src_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg"
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"SynGenome negative JSONLs not found at {src_dir}")
    names = sorted(fname for fname in os.listdir(src_dir) if fname.endswith(".jsonl"))

    print(f"[run-syngenome-l26-neg] uploading {len(names)} JSONLs ...")
    started = time.time()
    with jsonl_vol.batch_upload(force=True) as upload:
        for fname in names:
            upload.put_file(os.path.join(src_dir, fname), f"syngenome_neg/{fname}")
    print(f" uploaded in {time.time()-started:.0f} s")

    print(f"\n[run-syngenome-l26-neg] orchestrator (batch_size={batch_size})")
    summary = orchestrate_syngenome_l26_neg.remote(batch_size=batch_size)
    print(f"\n=== DONE ===")
    print(f" regions saved: {summary['regions_done']} (skipped {summary['regions_skipped']}, errors {summary['errors']})")
    print(f" total size: {summary['total_mb']/1024:.2f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
@app.local_entrypoint()
def pilot_syngenome_lean(target_records: int = 250, batch_size: int = 2):
"""Run a small SynGenome pilot to measure per-region time + commit time, then
project full-run cost. Picks drug-class JSONLs greedily up to ~target_records,
biased toward small classes for cost containment.
modal run modal/evo2_inference.py::pilot_syngenome_lean
modal run modal/evo2_inference.py::pilot_syngenome_lean --target-records 400 --batch-size 4
"""
import os, time
base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
if not os.path.isdir(base):
raise FileNotFoundError(f"SynGenome JSONLs not found at {base}; run scripts/sample_syngenome_amr.py first.")
species_files = []
for fname in sorted(os.listdir(base)):
if not fname.endswith(".jsonl"): continue
with open(os.path.join(base, fname)) as f:
n = sum(1 for line in f if line.strip())
species_files.append((n, fname))
species_files.sort() # smallest first
chosen = []
total = 0
for n, fname in species_files:
if total >= target_records:
break
if n > target_records * 2 and chosen:
continue
chosen.append((n, fname))
total += n
if not chosen:
chosen = [species_files[0]]
print(f"[pilot-syngenome] selected {len(chosen)} drug-class files:")
for n, fname in chosen:
print(f" {fname:30s} {n} records")
print(f" total pilot records: {total}")
print(f"\n[pilot-syngenome] uploading to mgnify-targeted-jsonl ...")
t0 = time.time()
with jsonl_vol.batch_upload(force=True) as batch:
for _, fname in chosen:
batch.put_file(os.path.join(base, fname), f"syngenome/{fname}")
print(f" uploaded in {time.time()-t0:.0f} s")
rel_paths = [f"syngenome/{fname}" for _, fname in chosen]
batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
print(f"\n[pilot-syngenome] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of {batch_size}")
t0 = time.time()
results = list(embed_syngenome_lean.map(batches, return_exceptions=True))
wall = time.time() - t0
ok = [r for r in results if not isinstance(r, Exception)]
n_done = sum(r["n_done"] for r in ok)
total_mb = sum(r["total_mb"] for r in ok)
region_times = [t for r in ok for t in r.get("per_region_s") or []]
commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
load_times = [r["model_load_s"] for r in ok]
if not region_times:
print("ERROR: no records processed")
return
mean_region = sum(region_times) / len(region_times)
mean_commit = sum(commit_times) / max(len(commit_times), 1)
mean_load = sum(load_times) / max(len(load_times), 1)
# Avg seq_len from local JSONLs (used for projection)
import json
seq_lens_local = []
for _, fname in chosen:
with open(os.path.join(base, fname)) as f:
for line in f:
if line.strip():
seq_lens_local.append(json.loads(line).get("cds_length", 0))
pilot_avg_seqlen = sum(seq_lens_local) / max(len(seq_lens_local), 1)
print(f"\n=== SYNGENOME PILOT RESULTS ===")
print(f" records processed: {n_done}")
print(f" output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)")
print(f" wall clock: {wall:.0f} s across {len(batches)} batch(es)")
print(f" per-region inference: {mean_region:.2f} s avg "
f"(min {min(region_times):.2f}, max {max(region_times):.2f}, p95 {sorted(region_times)[int(len(region_times)*0.95)]:.2f})")
print(f" per-region seq len: {pilot_avg_seqlen:.0f} bp avg")
print(f" per-batch commit: {mean_commit:.2f} s avg")
print(f" per-call model load: {mean_load:.1f} s avg")
# Projection to full 8000-record run
full_records = 8000
n_workers = 16
h100_rate = 4.50
full_n_jsonls = 13
full_batches = (full_n_jsonls + batch_size - 1) // batch_size
# Length-adjustment: full set has avg ~5000 bp (mostly macrolide at 5000),
# pilot biased to smaller drug classes which may have shorter sequences
full_avg_seqlen = 5000 # known cap
length_factor = full_avg_seqlen / max(pilot_avg_seqlen, 1)
inference_compute_s = full_records * mean_region * length_factor
commit_compute_s = full_batches * mean_commit
cold_start_s = mean_load * min(n_workers, full_batches)
total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
wall_proj = total_compute_s / min(n_workers, full_batches)
cost = (total_compute_s / 3600) * h100_rate
output_size_gb = (total_mb / max(n_done, 1)) * full_records * length_factor / 1024
print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, "
f"H100=${h100_rate:.2f}/hr, length_factor={length_factor:.2f}):")
print(f" inference compute: {inference_compute_s:7.0f} s")
print(f" commit compute: {commit_compute_s:7.0f} s")
print(f" cold-start total: {cold_start_s:7.0f} s")
print(f" estimated wall clock: {wall_proj/60:5.1f} min")
print(f" estimated cost: ${cost:.2f}")
print(f" estimated total size: {output_size_gb:.1f} GB")
@app.local_entrypoint()
def run_syngenome_lean(batch_size: int = 2):
    """Stage local SynGenome JSONLs on the volume, then run the remote lean embed.

    Every ``*.jsonl`` in the local SynGenome directory is uploaded (with
    ``force=True``) to ``syngenome/<name>`` on ``jsonl_vol``. After that,
    ``orchestrate_syngenome_lean`` is called remotely with ``batch_size``, and
    its summary dict is printed.

    Raises:
        FileNotFoundError: if the local SynGenome directory does not exist.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"SynGenome JSONLs not found at {base}; run scripts/sample_syngenome_amr.py first.")

    jsonls = sorted(filter(lambda name: name.endswith(".jsonl"), os.listdir(base)))
    # (local source, volume destination) for each file, in sorted order.
    uploads = [(os.path.join(base, fname), f"syngenome/{fname}") for fname in jsonls]

    print(f"[run-syngenome] uploading {len(jsonls)} JSONLs ...")
    t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as uploader:
        for local_path, remote_path in uploads:
            uploader.put_file(local_path, remote_path)
    print(f" uploaded in {time.time()-t0:.0f} s")

    print(f"\n[run-syngenome] orchestrator (batch_size={batch_size})")
    r = orchestrate_syngenome_lean.remote(batch_size=batch_size)

    print(f"\n=== DONE ===")
    print(f" regions saved: {r['regions_done']} (skipped {r['regions_skipped']}, errors {r['errors']})")
    print(f" total size: {r['total_mb']/1024:.1f} GB")
    print(f" mean per-region: {r['mean_region_s']:.2f} s")
# Layer-26 slice of the small (qual sample) lean embeddings; this is the volume
# the HF push reads from. It is kept separate from the VFDB layer-26 volume.
embeddings_l26_small_vol = modal.Volume.from_name("mgnify-embeddings-l26-small", create_if_missing=True)
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_small_vol,
    },
    timeout=3600,
    max_containers=8,
)
def slice_l26_small_batch(rel_paths: list[str]) -> dict:
    """Extract layer 26 from a batch of lean ``.npz`` files into the l26-small volume.

    For each relative path, the function reads ``/in/<rel>``. It picks the
    slice of ``per_token_layer_activations_bf16`` whose ``layer_indices`` entry
    is 26, then writes ``/out/<rel>`` containing that slice and copies of
    ``seq_len``, ``hidden_size``, ``model_name`` and ``metadata_json``.
    Outputs that already exist are skipped, so the function can be re-run to
    resume.

    Each output is first written to a sibling ``.partial`` file. It is renamed
    onto ``out_path`` only after ``np.savez`` finishes. A failure mid-write
    therefore never leaves a truncated ``out_path``, which the existence check
    above would otherwise skip forever.

    Args:
        rel_paths: paths relative to the volume roots (as produced by the
            orchestrator, e.g. ``small/.../<name>.npz``).

    Returns:
        dict with ``n_done``, ``n_skipped``, ``n_errors`` and ``total_mb``.
        ``total_mb`` is the size of the newly written outputs, in units of
        1e6 bytes.
    """
    import os
    import numpy as np

    n_done = n_skipped = n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            n_errors += 1
            continue
        tmp_path = out_path + ".partial"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                # Raises ValueError when layer 26 was not stored; counted as an error below.
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # Pass an open handle: given a str path, np.savez would append ".npz" to tmp_path.
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    layer26_activations_bf16=l26,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    **passthrough,
                )
            # Publish the finished file in one rename so out_path is never partial.
            os.replace(tmp_path, out_path)
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # Best-effort removal of a half-written temp file (it may never have been created).
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    embeddings_l26_small_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_small_slice(batch_size: int = 100) -> dict:
    """Fan the layer-26 slicing of ``/in/small`` out over ``slice_l26_small_batch``.

    Collects every ``.npz`` under ``/in/small`` and sorts their paths relative
    to ``/in``. It splits them into chunks of ``batch_size``, maps
    ``slice_l26_small_batch`` over the chunks, and sums the per-chunk counts.

    With ``return_exceptions=True``, a chunk whose worker call raises yields
    the exception instead of aborting the map. That chunk's files are added to
    ``n_errors`` and the chunk is tallied in ``n_batch_errors``, so a crashed
    chunk is not mistaken for success in the totals.

    Args:
        batch_size: number of files handed to each worker call.

    Returns:
        dict with ``files_total``, ``n_done``, ``n_skipped``, ``n_errors``,
        ``total_mb_out`` and ``n_batch_errors``.
    """
    import os

    paths = []
    for root, _, files in os.walk("/in/small"):
        for fname in files:
            if fname.endswith(".npz"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/in"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-small] {len(paths)} npz → {len(batches)} batches")

    n_done = n_skipped = n_errors = n_batch_errors = 0
    total_mb_out = 0.0
    # .map yields results in input order by default, so index i identifies batches[i].
    for i, r in enumerate(slice_l26_small_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            n_batch_errors += 1
            # The worker's per-file counts are lost, so count the whole chunk as errors.
            # Any files it managed to save and commit are skipped on a re-run.
            n_errors += len(batches[i])
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        print(f" done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {"files_total": len(paths), "n_done": n_done, "n_skipped": n_skipped,
            "n_errors": n_errors, "total_mb_out": total_mb_out,
            "n_batch_errors": n_batch_errors}
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_small_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=21600,
)
def upload_l26_small_to_hf(repo_name: str = "mgnify-evo2-l26-small-qual",
                           private: bool = False) -> dict:
    """Push the whole l26-small volume (mounted at ``/vol``) to a HF dataset repo.

    The function logs in to Hugging Face and creates ``<user>/<repo_name>`` as
    a dataset repo if it is missing. It then uploads ``/vol`` with
    ``upload_large_folder`` and counts the ``.npz`` files and their bytes on
    the volume.

    Args:
        repo_name: dataset repo name under the authenticated user's namespace.
        private: whether a newly created repo is private.

    Returns:
        dict with ``repo_url``, ``n_files``, ``bytes_total``, ``elapsed_s`` and
        ``private``.

    Raises:
        RuntimeError: if no Hugging Face token is present in the environment.
            Without this check, ``login(token=None)`` would fall back to an
            interactive prompt inside a headless container.
    """
    import os
    import time
    from huggingface_hub import HfApi, login

    # Prefer the key the "huggingface" secret is created with (HF_TOKEN).
    # Otherwise fall back to scanning for anything token-shaped, which covers
    # secrets stored under another key or with the token itself as the key.
    token = os.environ.get("HF_TOKEN")
    if not token:
        for k, v in os.environ.items():
            if k.startswith("hf_") and len(k) > 30:
                token = k
                break
            if v.startswith("hf_") and len(v) > 30:
                token = v
                break
    if not token:
        raise RuntimeError("No Hugging Face token found in the environment; check the 'huggingface' Modal secret (HF_TOKEN).")
    login(token=token)
    api = HfApi()
    user = api.whoami()["name"]
    full_repo = f"{user}/{repo_name}"
    api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf-push-small] uploading /vol → {full_repo} (private={private})")
    t0 = time.time()
    api.upload_large_folder(folder_path="/vol", repo_id=full_repo, repo_type="dataset")
    elapsed = time.time() - t0
    n_files = bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for fname in files:
            if fname.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, fname))
    return {"repo_url": f"https://huggingface.co/datasets/{full_repo}",
            "n_files": n_files, "bytes_total": bytes_total,
            "elapsed_s": elapsed, "private": private}
@app.local_entrypoint()
def push_l26_small(repo_name: str = "mgnify-evo2-l26-small-qual",
                   private: bool = False, batch_size: int = 100):
    """Two-step job: cut layer 26 out of the ``small/`` lean embeddings, then publish to HF.

    Step 1 runs ``orchestrate_l26_small_slice`` remotely, with ``batch_size``
    files per worker call. Step 2 runs ``upload_l26_small_to_hf`` remotely with
    ``repo_name`` and ``private``. Both summaries are printed locally.
    """
    print("[1/2] slicing layer 26 from small/ ...")
    s = orchestrate_l26_small_slice.remote(batch_size=batch_size)
    slice_report = (
        f" files: {s['files_total']}, done: {s['n_done']}, skipped: {s['n_skipped']}, errors: {s['n_errors']}",
        f" l26 size: {s['total_mb_out']/1024:.2f} GB",
    )
    for line in slice_report:
        print(line)

    print("\n[2/2] pushing to HF ...")
    u = upload_l26_small_to_hf.remote(repo_name=repo_name, private=private)
    upload_report = (
        f"\n=== DONE ===",
        f" repo: {u['repo_url']}",
        f" files: {u['n_files']}",
        f" size: {u['bytes_total']/1e9:.2f} GB",
        f" elapsed: {u['elapsed_s']:.0f} s",
    )
    for line in upload_report:
        print(line)
@app.local_entrypoint()
def run_qual_lean(batch_size: int = 8):
    """Upload qual JSONLs + embed. Tiny job (~860 records, ~$0.30, ~2 min).

    Every local ``*.jsonl`` under the qual directory is uploaded to
    ``qual/<name>`` on ``jsonl_vol``. The function then maps ``embed_qual_lean``
    over chunks of up to ``batch_size`` volume paths and prints aggregate
    counts from the chunks that succeeded. Chunks whose call raised are listed
    afterwards with their files and exception, rather than being dropped
    silently.

    Raises:
        FileNotFoundError: if the local qual directory does not exist.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/qual"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Qual JSONLs not found at {base}; run scripts/sample_qual_jsonl.py first.")
    jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl"))
    print(f"[run-qual] uploading {len(jsonls)} JSONLs to mgnify-targeted-jsonl ...")
    t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for fname in jsonls:
            batch.put_file(os.path.join(base, fname), f"qual/{fname}")
    print(f" uploaded in {time.time()-t0:.0f} s")
    rel_paths = [f"qual/{f}" for f in jsonls]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[run-qual] {len(rel_paths)} JSONLs in {len(batches)} batches of up to {batch_size}")
    t0 = time.time()
    results = list(embed_qual_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0
    ok = [r for r in results if not isinstance(r, Exception)]
    # .map returns results in input order by default, so results[i] belongs to batches[i].
    failed = [(i, r) for i, r in enumerate(results) if isinstance(r, Exception)]
    n_done = sum(r["n_done"] for r in ok)
    n_skipped = sum(r["n_skipped"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    print(f"\n=== QUAL DONE ===")
    print(f" records embedded: {n_done} (skipped {n_skipped})")
    print(f" output size: {total_mb:.0f} MB")
    print(f" wall clock: {wall:.0f} s")
    if failed:
        print(f" FAILED batches: {len(failed)}/{len(batches)} (not included in the totals above)")
        for i, err in failed:
            print(f"  batch {i+1}: {', '.join(batches[i])}: {err!r}")
@app.local_entrypoint()
def push_l26_vfdb(repo_name: str = "mgnify-evo2-l26-vfdb-virulence",
                  private: bool = False, batch_size: int = 100):
    """Cut layer 26 out of the ``vfdb/`` lean embeddings, then publish the slice to HF Datasets.

    The function calls ``orchestrate_l26_vfdb_slice`` remotely with
    ``batch_size``, then ``upload_l26_vfdb_to_hf`` remotely with ``repo_name``
    and ``private``. It prints both summaries, including the observed upload
    throughput.
    """
    print("[1/2] slicing layer 26 from VFDB lean embeddings ...")
    s = orchestrate_l26_vfdb_slice.remote(batch_size=batch_size)
    slice_report = (
        f"\n files total: {s['files_total']}",
        f" done: {s['n_done']}",
        f" skipped: {s['n_skipped']}",
        f" errors: {s['n_errors']}",
        f" l26 vol size: {s['total_mb_out']/1024:.2f} GB",
    )
    for line in slice_report:
        print(line)

    print("\n[2/2] pushing VFDB layer-26 volume to HF Datasets ...")
    u = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private)
    # Throughput uses max(elapsed, 1) so that a near-instant upload cannot divide by zero.
    upload_report = (
        f"\n=== UPLOADED ===",
        f" repo: {u['repo_url']}",
        f" files: {u['n_files']}",
        f" size: {u['bytes_total']/1e9:.2f} GB",
        f" elapsed: {u['elapsed_s']:.0f} s ({u['bytes_total']/1e6/max(u['elapsed_s'],1):.1f} MB/s)",
    )
    for line in upload_report:
        print(line)