"""
Modal scaffold for self-hosted Evo 2-7B inference.

Uses Arc Institute's official Dockerfile (from the ArcInstitute/evo2 repo),
cached HF weights in a Modal Volume, and the existing "huggingface" Modal Secret
(same one used by ~/AIMO3-TIR/compute/modal/).

NO dependency on NVIDIA NIM / NGC / BioNeMo — weights come from HF, the container
is built from Arc's Dockerfile.

ONE-TIME SETUP
  1. Authenticate Modal (already done):
       modal token new
  2. HF secret (already done via AIMO3):
       modal secret create huggingface HF_TOKEN=...
  3. Clone Arc's repo locally:
       git clone https://github.com/ArcInstitute/evo2.git ~/evo2_repo
  4. Upload MGnify data (optional, one-time):
       modal volume create mgnify-data
       modal volume put mgnify-data /home/ror25cal/MGnify/data/ /

USAGE
  `--layer` takes a full layer *name* (it is forwarded unchanged as the only
  entry of `layers`), not a bare block index:

       modal run modal/evo2_inference.py::main --seq "ACGT..." --layer blocks.26.mlp.l3
"""
import os
from pathlib import Path

import modal

# --- paths / names --------------------------------------------------------
APP_NAME = "mgnify-evo2-7b"  # name passed to modal.App
VOL_WEIGHTS = "evo2-7b-weights"  # HF cache (weights persist here)
VOL_DATA = "mgnify-data"  # MGnify FASTAs (optional)
TARGET_LAYER = "blocks.26.mlp.l3"  # default layer name for the `embed` smoke test; adjust as needed

# --- image -----------------------------------------------------------------
# Matches Arc Institute's official Dockerfile (nvcr.io/nvidia/pytorch:25.04-py3
# + pip install evo2), translated to Modal's native Image API because Modal
# doesn't support the `WORKDIR` directive from the raw Dockerfile.
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,  # base image already has python
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")  # pulls flash-attn + vtx + huggingface_hub as transitive deps
)

# --- persistent storage ---------------------------------------------------
weights_vol = modal.Volume.from_name(VOL_WEIGHTS, create_if_missing=True)
data_vol = modal.Volume.from_name(VOL_DATA, create_if_missing=True)

app = modal.App(APP_NAME)


# --- shared private helpers -------------------------------------------------
def _chunk_bounds(contig_len: int, chunk_size: int, overlap: int) -> list[tuple[int, int]]:
    """Return the (start_bp, end_bp) window of every chunk of a contig.

    Windows start every ``chunk_size - overlap`` bp and are clipped to
    ``contig_len``; an empty contig yields an empty list.

    NOTE: with ``overlap > 0`` the last start positions can produce windows that
    lie entirely inside the preceding window. This is kept as-is so chunk
    indices stay compatible with npz files already written under this scheme.

    Raises:
        ValueError: if ``overlap >= chunk_size`` (the step would not advance).
    """
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")
    return [(start, min(start + chunk_size, contig_len)) for start in range(0, contig_len, step)]


def _read_fasta(path) -> dict[str, str]:
    """Parse a FASTA file into ``{record_name: sequence}``.

    The record name is the first whitespace-delimited token after ``>``.
    Sequence lines are stripped and concatenated; text before the first header
    is ignored; a repeated name overwrites the earlier record.

    Raises:
        ValueError: if a header line carries no name.
    """
    records: dict[str, str] = {}
    name = None
    parts: list[str] = []
    with open(path) as fh:
        for line in fh:
            if line.startswith(">"):
                if name:
                    records[name] = "".join(parts)
                fields = line[1:].split()
                if not fields:
                    raise ValueError(f"empty FASTA header in {path}")
                name = fields[0]
                parts = []
            else:
                parts.append(line.strip())
    if name:
        records[name] = "".join(parts)
    return records


def _register_cache_hooks(module_dict, layer_names, cache):
    """Attach forward hooks that store each named module's output in ``cache``.

    For tuple outputs only the first element is stored. Stored tensors are
    detached. Returns the hook handles; the caller must ``.remove()`` them.
    If registration fails part-way, the hooks already attached are removed
    before the error propagates.
    """
    def make_hook(name):
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = []
    try:
        for name in layer_names:
            handles.append(module_dict[name].register_forward_hook(make_hook(name)))
    except Exception:
        for handle in handles:
            handle.remove()
        raise
    return handles


def _sae_batch_topk(acts, W, b_enc, k):
    """Encode activations with the SAE and apply BatchTopK sparsification.

    Computes ``relu(acts @ W + b_enc)`` and keeps only the ``k * n_rows``
    largest entries across the WHOLE matrix (not per row); everything else is
    zeroed. Returns a tensor with the same shape as the pre-activations.
    """
    import torch

    pre = torch.relu(acts @ W + b_enc)
    flat = pre.flatten()
    n_keep = min(k * pre.shape[0], flat.numel())  # never ask topk for more than exists
    top = torch.topk(flat, n_keep, dim=-1)
    return torch.zeros_like(flat).scatter(-1, top.indices, top.values).reshape(pre.shape)


def _mean_pool_windows(acts, pool_size):
    """Mean-pool a 2-D activation tensor over consecutive windows of ``pool_size`` rows.

    A trailing partial window (fewer than ``pool_size`` rows) is pooled on its
    own. Means are computed in float32. Returns an ``[n_windows, width]`` tensor
    (zero rows for empty input).
    """
    import torch

    n_rows = acts.shape[0]
    n_full = n_rows // pool_size
    pieces = []
    if n_full > 0:
        pieces.append(acts[: n_full * pool_size].reshape(n_full, pool_size, -1).float().mean(dim=1))
    if n_rows > n_full * pool_size:
        pieces.append(acts[n_full * pool_size:].float().mean(dim=0, keepdim=True))
    if not pieces:
        return acts.new_empty(0, acts.shape[-1])
    return torch.cat(pieces, dim=0)


def _atomic_savez(out_path: str, **arrays) -> None:
    """Write a compressed npz so that ``out_path`` only ever appears complete.

    The writers in this file treat "output file exists" as "already done", so
    an interrupted in-place write would leave a truncated file that every
    later run skips. Data goes to ``<out_path>.tmp`` first and is then moved
    over ``out_path`` with ``os.replace``; a leftover temp file is removed on
    failure. The temp name does not end in ``.npz``, so the ``.npz`` scans
    elsewhere in this file never pick it up.
    """
    import os

    import numpy as np

    tmp_path = f"{out_path}.tmp"
    try:
        # Passing a file object stops numpy from appending ".npz" to the name.
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(fh, **arrays)
        os.replace(tmp_path, out_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@app.function(
    image=image,
    gpu="H100",  # compute cap 9.0 — needed for Evo2's FP8 kernels (A100 is cc 8.0, fails)
    volumes={
        "/root/.cache/huggingface": weights_vol,  # HF will cache evo2_7b here, persists
        "/data": data_vol,  # MGnify FASTAs if you uploaded them
    },
    secrets=[modal.Secret.from_name("huggingface")],  # sets HF_TOKEN env var
    timeout=3600,
)
def embed(sequences: list[tuple[str, str]], layers: list[str] | None = None) -> dict:
    """
    Run Evo 2-7B forward pass on a batch of (name, sequence) pairs and return
    per-layer embeddings. First call downloads weights into the HF-cache Volume;
    subsequent calls skip the download.

    sequences: list of (name, DNA_string)
    layers:    list of layer names, e.g. ["blocks.26.mlp.l3"]. Default: [TARGET_LAYER]
    returns:   {name: {layer: np.ndarray of shape [seq_len, hidden_dim]}}
               (an empty dict, without loading the model, for an empty input list)
    """
    if not sequences:
        return {}

    import numpy as np
    import torch
    from evo2 import Evo2

    layers = layers or [TARGET_LAYER]
    model = Evo2("evo2_7b")  # loads from /root/.cache/huggingface

    out = {}
    for name, seq in sequences:
        # Arc's canonical API: tokenize → int tensor → batch dim → cuda → model(..., return_embeddings=True)
        input_ids = torch.tensor(
            model.tokenizer.tokenize(seq),
            dtype=torch.int,
        ).unsqueeze(0).to("cuda:0")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            _, embeddings = model(input_ids, return_embeddings=True, layer_names=layers)
        out[name] = {lyr: np.asarray(embeddings[lyr].squeeze(0).float().cpu()) for lyr in layers}
    return out


@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def embed_and_sae(sequence: str, topk: int = 64) -> dict:
    """
    Evo 2-7B-262k forward (hook at whole block-26 output) + Goodfire BatchTopK SAE.

    Follows Arc Institute's reference notebook:
        notebooks/sparse_autoencoder/sparse_autoencoder.ipynb

    Key differences from teammate's simple-ReLU SAE class:
      - model = evo2_7b_262k (the long-context variant Goodfire's SAE was trained on)
      - layer name = 'blocks-26' (whole block output, not blocks.26.mlp.l3)
      - BatchTopK=64 applied at encode — otherwise features are 4-5x too dense

    sequence: DNA string to embed.
    topk:     how many of the largest latents to report per position (clamped to
              the SAE width). Entries may be zero where a position has fewer
              active features than `topk`.
    returns:  JSON-serialisable dict with shapes, per-position top-k values and
              indices, active-feature counts and activation L2 norms.
    """
    import numpy as np
    import torch

    SAE_LAYER = "blocks-26"
    K = 64

    # Model, SAE weights and module map come from the per-container cache.
    evo2, W_sae, b_enc, device, module_dict = _get_models()

    cache: dict = {}
    handles = _register_cache_hooks(module_dict, [SAE_LAYER], cache)
    try:
        input_ids = torch.tensor(evo2.tokenizer.tokenize(sequence), dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        # The model object outlives this call, so hooks must always come off.
        for handle in handles:
            handle.remove()

    acts = cache[SAE_LAYER][0]  # [seq_len, 4096] bf16

    # Encode with BatchTopK (Arc's reference): top-K across the WHOLE batch × features, not per-token
    latents = _sae_batch_topk(acts, W_sae.to(acts.dtype), b_enc.to(acts.dtype), K)  # [seq_len, d_sae] sparse

    # Report the `topk` largest latents per position (zero-padded where fewer are active)
    n_top = min(topk, latents.shape[1])
    topk_vals_per_pos, topk_idx_per_pos = latents.topk(n_top, dim=1)
    active_per_pos = (latents > 0).sum(dim=1)

    return {
        "seq_len": int(acts.shape[0]),
        "d_model": int(acts.shape[1]),
        "d_sae": int(latents.shape[1]),
        "model_name": "evo2_7b_262k",
        "sae_layer": SAE_LAYER,
        "topk_k": K,
        "topk_values": topk_vals_per_pos.float().cpu().numpy().astype(np.float32).tolist(),
        "topk_indices": topk_idx_per_pos.cpu().numpy().astype(np.int32).tolist(),
        "active_features_per_position": active_per_pos.cpu().numpy().astype(np.int32).tolist(),
        "activation_l2_norm": torch.linalg.norm(acts.float(), dim=1).cpu().numpy().astype(np.float32).tolist(),
    }


# --- Embedding output volume ---
embeddings_vol = modal.Volume.from_name("mgnify-embeddings", create_if_missing=True)
embeddings_targeted_vol = modal.Volume.from_name("mgnify-embeddings-targeted", create_if_missing=True)
jsonl_vol = modal.Volume.from_name("mgnify-targeted-jsonl", create_if_missing=True)


def _get_models():
    """Module-level lazy cache. Caches Evo2 + SAE weights + module_dict, but NOT hooks
    (hooks are re-registered per call to avoid stale-state OOM in container reuse).

    returns: (evo2, W_sae, b_enc, device, module_dict) where module_dict maps
             '-'-joined child paths (e.g. 'blocks-26') to sub-modules of evo2.model
             and the SAE tensors are bfloat16 on the model's device.
    """
    module_globals = globals()
    if module_globals.get("_CACHED_EVO2") is not None:
        return (
            module_globals["_CACHED_EVO2"],
            module_globals["_CACHED_SAE_W"],
            module_globals["_CACHED_SAE_BENC"],
            module_globals["_CACHED_DEVICE"],
            module_globals["_CACHED_MODULE_DICT"],
        )

    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    print("[container] loading Evo2 7B-262k (once per container)")
    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    # Walk the module tree; keys are child names joined with '-'.
    module_dict = {}

    def recurse(m, prefix=""):
        for n, c in m.named_children():
            module_dict[prefix + n] = c
            recurse(c, prefix + n + "-")

    recurse(evo2.model)

    sae_path = hf_hub_download(
        repo_id="Goodfire/Evo-2-Layer-26-Mixed",
        filename="sae-layer26-mixed-expansion_8-k_64.pt",
    )
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)
    # Strip the torch.compile / wrapper prefixes from the checkpoint keys
    sae_sd = {k.replace("_orig_mod.", "").replace("module.", ""): v for k, v in sae_sd.items()}
    W_sae = sae_sd["W"].to(device=device).to(torch.bfloat16)
    b_enc = sae_sd["b_enc"].to(device=device).to(torch.bfloat16)

    # Global names are kept unchanged in case other code in this module reads them.
    module_globals["_CACHED_EVO2"] = evo2
    module_globals["_CACHED_SAE_W"] = W_sae
    module_globals["_CACHED_SAE_BENC"] = b_enc
    module_globals["_CACHED_DEVICE"] = device
    module_globals["_CACHED_MODULE_DICT"] = module_dict
    return evo2, W_sae, b_enc, device, module_dict


@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings": embeddings_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def embed_full(
    mag_id: str,
    contig_id: str,
    sequence: str,
    pool_size: int = 1000,
    chunk_size: int = 64000,  # safe under Vortex FFT int32 indexing limit
    overlap: int = 0,  # bp of overlap between consecutive chunks
) -> dict:
    """
    Process one contig of any length, chunking as needed. Saves one
    /embeddings/{mag_id}/{contig_id}_{N}.npz per chunk. Reuses Evo2/SAE
    models across calls within the same Modal container.

    Each npz holds the window-mean-pooled output of all 32 blocks (bf16 bit
    pattern stored as uint16) plus the per-token top-64 BatchTopK SAE latents
    of block 26. Chunks whose file already exists are not recomputed.

    returns: summary dict with per-chunk paths/sizes; "all_cached" is set when
             nothing had to be computed.
    raises:  ValueError if overlap >= chunk_size or pool_size <= 0.
    """
    import os

    import numpy as np
    import torch

    K = 64
    layer_names = [f"blocks-{i}" for i in range(32)]

    # Validate BEFORE touching the shared model: raising or returning after hooks
    # are attached (but outside try/finally) would leak them onto the cached model.
    if pool_size <= 0:
        raise ValueError(f"pool_size ({pool_size}) must be > 0")
    contig_len = len(sequence)
    chunks = _chunk_bounds(contig_len, chunk_size, overlap)
    n_chunks = len(chunks)
    print(f"[{mag_id}/{contig_id}] contig_len={contig_len:,}, {n_chunks} chunks of ~{chunk_size:,} bp (overlap={overlap})")

    out_dir = f"/embeddings/{mag_id}"
    os.makedirs(out_dir, exist_ok=True)
    chunk_paths = [f"{out_dir}/{contig_id}_{ci}.npz" for ci in range(n_chunks)]

    # Idempotent re-runs: if every chunk is already on the volume, skip model load entirely.
    if all(os.path.exists(p) for p in chunk_paths):
        saved = [
            {"path": p, "chunk_idx": ci, "bp_range": [cstart, cend],
             "size_mb": os.path.getsize(p) / 1e6, "skipped": True}
            for ci, (p, (cstart, cend)) in enumerate(zip(chunk_paths, chunks))
        ]
        print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks already on volume, skipping")
        return {"mag_id": mag_id, "contig_id": contig_id, "contig_len": contig_len,
                "n_chunks": n_chunks, "chunks": saved,
                "total_size_mb": sum(s["size_mb"] for s in saved),
                "all_cached": True}

    evo2, W_sae, b_enc, device, module_dict = _get_models()
    hidden = evo2.model.config.hidden_size

    # Register hooks per-call so each contig starts fresh (avoids OOM accumulation)
    cache: dict = {}
    saved = []
    handles = _register_cache_hooks(module_dict, layer_names, cache)
    try:
        for ci, (cstart, cend) in enumerate(chunks):
            out_path = chunk_paths[ci]
            if os.path.exists(out_path):
                saved.append({"path": out_path, "chunk_idx": ci, "bp_range": [cstart, cend],
                              "size_mb": os.path.getsize(out_path) / 1e6, "cached": True})
                continue

            chunk_seq = sequence[cstart:cend]

            # --- forward pass on this chunk ---
            cache.clear()
            input_ids = torch.tensor(evo2.tokenizer.tokenize(chunk_seq), dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(input_ids)

            seq_len = cache["blocks-0"].shape[1]
            print(f" [{ci+1}/{n_chunks}] bp {cstart:,}-{cend:,} seq_len={seq_len}")

            # --- BatchTopK SAE on layer 26 FIRST (so we can free GPU activations after) ---
            latents = _sae_batch_topk(cache["blocks-26"][0], W_sae, b_enc, K)
            top_v, top_i = latents.topk(K, dim=1)
            top_i_cpu = top_i.cpu().numpy().astype(np.int32)
            top_v_cpu = top_v.float().cpu().numpy().astype(np.float16)
            # Free SAE intermediates immediately
            del latents, top_v, top_i
            torch.cuda.empty_cache()

            # --- window mean pooling, dropping each layer from cache as we go ---
            n_windows = -(-seq_len // pool_size)  # ceil division: full windows + optional tail
            # Store as bf16 (same dynamic range as fp32, no overflow on late-layer activations).
            # numpy lacks bf16 support, so we save the bf16 bit-pattern as uint16 and reinterpret on load.
            layer_means = torch.zeros(len(layer_names), n_windows, hidden, dtype=torch.bfloat16)
            for i, ln in enumerate(layer_names):
                pooled = _mean_pool_windows(cache[ln][0], pool_size)
                layer_means[i] = pooled.to(torch.bfloat16).cpu()
                # IMPORTANT: drop the GPU activation for this layer to free memory
                del pooled
                del cache[ln]
            torch.cuda.empty_cache()

            # Reinterpret bf16 buffer as uint16 for numpy storage (bit-exact, 2 bytes/value).
            layer_means_uint16 = layer_means.view(torch.uint16).numpy()

            # --- save chunk npz (atomically: existence of the file == chunk is complete) ---
            _atomic_savez(
                out_path,
                layer_means_bf16=layer_means_uint16,  # bit-pattern of bf16 in uint16 wrapper
                layer_means_dtype="bfloat16",  # marker: how to interpret layer_means_bf16
                layer_names=np.array(layer_names),
                pool_size=np.int32(pool_size),
                sae_topk_indices=top_i_cpu,
                sae_topk_values=top_v_cpu,
                sae_layer="blocks-26",
                seq_len=np.int32(seq_len),
                chunk_start=np.int64(cstart),
                chunk_end=np.int64(cend),
                chunk_idx=np.int32(ci),
                n_chunks=np.int32(n_chunks),
                chunk_size=np.int32(chunk_size),
                overlap=np.int32(overlap),
                contig_id=contig_id,
                contig_len=np.int64(contig_len),
                mag_id=mag_id,
                model_name="evo2_7b_262k",
            )
            file_size = os.path.getsize(out_path)
            print(f" saved {file_size/1e6:.1f} MB -> {out_path}")
            saved.append({"path": out_path, "chunk_idx": ci, "bp_range": [cstart, cend], "size_mb": file_size / 1e6})

            # free for next chunk (SAE intermediates already deleted above)
            del layer_means, layer_means_uint16, input_ids, top_i_cpu, top_v_cpu
            torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()  # always clean up hooks
        cache.clear()
        torch.cuda.empty_cache()

    embeddings_vol.commit()
    print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks saved, committed to volume")

    return {
        "mag_id": mag_id,
        "contig_id": contig_id,
        "contig_len": contig_len,
        "n_chunks": n_chunks,
        "chunks": saved,
        "total_size_mb": sum(s["size_mb"] for s in saved),
    }


@app.local_entrypoint()
def full_mag_test(
    mag_dir: str = "/home/ror25cal/MGnify/data/chicken-gut/species_catalogue/MGYG0003076/MGYG000307601/genome",
    mag_id: str = "MGYG000307601",
    contig_id: str = "MGYG000307601_21",  # biggest contig (217 kb)
    chunk_size: int = 64000,
    overlap: int = 0,
):
    """Process one full contig, chunked. Saves /embeddings/{mag_id}/{contig_id}_{N}.npz per chunk.

    Reads {mag_dir}/{mag_id}.fna locally, sends the chosen contig to embed_full
    and prints the per-chunk summary it returns.
    """
    from pathlib import Path

    fna_path = Path(mag_dir) / f"{mag_id}.fna"
    records = _read_fasta(fna_path)
    if contig_id not in records:
        raise KeyError(f"contig {contig_id!r} not found in {fna_path}")
    seq = records[contig_id]

    # Same enumeration embed_full uses; also rejects overlap >= chunk_size before any remote call.
    n_expected = len(_chunk_bounds(len(seq), chunk_size, overlap))
    print(f"[{mag_id}/{contig_id}] {len(seq):,} bp → ~{n_expected} chunks @ {chunk_size:,} bp (overlap={overlap})")

    result = embed_full.remote(mag_id, contig_id, seq, pool_size=1000, chunk_size=chunk_size, overlap=overlap)

    print("\n=== RESULT ===")
    print(f" contig_len: {result['contig_len']:,}")
    print(f" n_chunks: {result['n_chunks']}")
    print(f" total_size: {result['total_size_mb']:.1f} MB")
    print(" chunks:")
    for c in result["chunks"]:
        print(f" [{c['chunk_idx']}] bp {c['bp_range'][0]:,}-{c['bp_range'][1]:,} {c['size_mb']:.1f} MB → {c['path']}")


@app.local_entrypoint()
def run_top50_skin_amr(
    csv_path: str = "/home/ror25cal/MGnify/modal/top50_skin_amr.csv",
    skin_dir: str = "/home/ror25cal/MGnify/data/human-skin/species_catalogue",
    chunk_size: int = 64000,
    overlap: int = 0,
    min_contig_len: int = 5000,  # skip tiny fragmented contigs — too short for meaningful gene context
):
    """
    Process the top-50 human-skin MAGs (sorted by AMR-per-Mb density) end-to-end.
    One Modal call per contig, parallelism via .map() (Modal default ~10x).
    Container reuse + module-level model cache → model loads ~once per worker.

    Output: /embeddings/{mag_id}/{contig_id}_{chunk_idx}.npz on the mgnify-embeddings Volume.
    """
    import csv
    from pathlib import Path

    # Read top-50 MAG list
    with open(csv_path) as f:
        mag_ids = [row["mag_id"] for row in csv.DictReader(f)]
    print(f"loaded {len(mag_ids)} MAGs from {csv_path}")

    # Build the work list: one tuple per contig (chunking happens inside embed_full)
    work = []  # list of (mag_id, contig_id, sequence, pool_size, chunk_size, overlap)
    total_bp = 0
    for mag_id in mag_ids:
        prefix = mag_id[:11]
        fna_path = Path(skin_dir) / prefix / mag_id / "genome" / f"{mag_id}.fna"
        if not fna_path.exists():
            print(f" skipping {mag_id} — fna not found at {fna_path}")
            continue
        for cid, seq in _read_fasta(fna_path).items():
            if len(seq) < min_contig_len:
                continue
            work.append((mag_id, cid, seq, 1000, chunk_size, overlap))
            total_bp += len(seq)

    # Count chunks exactly as embed_full will enumerate them (accounts for overlap).
    n_chunks_estimate = sum(len(_chunk_bounds(len(w[2]), chunk_size, overlap)) for w in work)
    # Empirical: ~$0.025 per 64kb chunk on H100 (forward + SAE + npz save)
    cost_estimate = n_chunks_estimate * 0.025
    print(f"\nwork queue: {len(work)} contigs (≥ {min_contig_len:,} bp), {total_bp:,} total bp, ~{n_chunks_estimate} chunks")
    print(f"cost estimate: ~${cost_estimate:.0f}")

    # Submit all calls; Modal's .map() default concurrency handles the parallelism
    print("\nsubmitting to Modal...\n")
    n_seen = 0
    n_failed = 0
    n_chunks_done = 0
    total_mb = 0.0
    for result in embed_full.starmap(work, return_exceptions=True):
        n_seen += 1
        if isinstance(result, Exception):
            n_failed += 1
            print(f" [{n_seen}/{len(work)}] ERROR: {result}")
            continue
        n_chunks_done += result.get("n_chunks", 0)
        total_mb += result.get("total_size_mb", 0.0)
        cached_tag = " (all cached)" if result.get("all_cached") else ""
        print(f" [{n_seen}/{len(work)}] {result['mag_id']}/{result['contig_id']}: {result['n_chunks']} chunks, {result['total_size_mb']:.1f} MB{cached_tag}")

    print("\n=== DONE ===")
    # Failed contigs are reported separately instead of being counted as processed.
    print(f" contigs processed: {n_seen - n_failed}/{len(work)}")
    print(f" contigs failed: {n_failed}")
    print(f" total chunks: {n_chunks_done}")
    print(f" total volume size: {total_mb:.1f} MB")


@app.local_entrypoint()
def main(seq: str = "ACGT" * 100, layer: str = TARGET_LAYER):
    """Smoke-test: single short sequence, one layer (just the Evo2 embed, no SAE).

    `layer` is a full layer name and is passed to `embed` unchanged.
    """
    import numpy as np

    print(f"submitting 1 sequence of length {len(seq)} bp to {APP_NAME} @ {layer}")
    results = embed.remote([("smoke-test", seq)], layers=[layer])
    for name, by_layer in results.items():
        for lyr, arr in by_layer.items():
            print(f" {name} / {lyr}: shape={arr.shape} |x|={np.abs(arr).mean():.3e}")


@app.local_entrypoint()
def crispr_test(
    region_json: str = "/home/ror25cal/MGnify/modal/crispr_test_region.json",
    out_path: str = "/home/ror25cal/MGnify/modal/crispr_test_result.json",
):
    """Run the CRISPR-region sanity test: Evo2 + Goodfire SAE, save result locally.

    Reads `sequence` and per-position `labels` from `region_json`, runs
    embed_and_sae remotely, prints a CRISPR-vs-background comparison of the
    top-1 SAE activation and writes the result (plus labels and region
    metadata) to `out_path` as JSON.
    """
    import json

    import numpy as np

    with open(region_json) as fh:
        region = json.load(fh)
    seq = region["sequence"]
    labels = np.array(region["labels"], dtype=np.int8)
    print(f"region: {region['mag']} / {region['contig']} "
          f"bp {region['region_start']}-{region['region_end']} len={len(seq)}")
    print(f"label distribution: "
          f"{dict(zip(*np.unique(labels, return_counts=True)))} "
          f"(0=bg, 1=CRISPR, 2=DR, 3=spacer, 4=flank)")

    result = embed_and_sae.remote(seq, topk=64)

    print("\ngot back:")
    print(f" seq_len: {result['seq_len']}")
    print(f" d_model (Evo2 hidden): {result['d_model']}")
    print(f" d_sae (Goodfire dict): {result['d_sae']}")
    active = np.array(result["active_features_per_position"])
    print(f" active features/pos: median={int(np.median(active))} max={int(active.max())}")

    # quick sanity: mean top-1 activation in CRISPR-labelled positions vs background
    topk_vals = np.array(result["topk_values"])
    # align lengths (Evo2 may add EOS/pad tokens)
    n = min(len(labels), result["seq_len"])
    labels = labels[:n]
    crispr_mask = labels > 0
    bg_mask = labels == 0
    top1 = topk_vals[:n, 0]
    if crispr_mask.any() and bg_mask.any():
        print(f"\n top-1 SAE activation — CRISPR positions: mean={top1[crispr_mask].mean():.3f} N={crispr_mask.sum()}")
        print(f" top-1 SAE activation — background: mean={top1[bg_mask].mean():.3f} N={bg_mask.sum()}")
        print(f" ratio (CRISPR / bg): {top1[crispr_mask].mean() / max(top1[bg_mask].mean(), 1e-9):.2f}x")

    # Save for downstream viz
    result["labels"] = labels.tolist()
    result["region_meta"] = {k: v for k, v in region.items() if k not in ("sequence", "labels")}
    with open(out_path, "w") as fh:
        json.dump(result, fh)
    print(f"\nsaved result -> {out_path}")


# ============================================================================
# Targeted-region per-token × all-32-layer embed pipeline
# Reads JSONL records from mgnify-targeted-jsonl volume, writes per-region
# npz files (per-token activations across all 32 blocks, bf16-as-uint16) to
# mgnify-embeddings-targeted volume. NO SAE — raw activations only.
# ============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_targeted": embeddings_targeted_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
    max_containers=16,
)
def embed_targeted_jsonl(jsonl_rel_path: str) -> dict:
    """
    Process every record in one JSONL file (one MAG × one label). Saves
    /embeddings_targeted/{label}/{mag_id}/{region_id}.npz containing per-token
    activations across all 32 blocks (bf16 stored as uint16).

    Records with a falsy `is_positive` are filed under the "MISC" folder.
    Regions whose npz already exists are skipped.

    jsonl_rel_path: path inside the jsonl volume, e.g. "full/AMR/MGYG000516287.jsonl"
    returns: {"path", "n_done", "n_skipped", "total_mb"}, plus "error": "missing"
             when the JSONL file does not exist.
    """
    import json
    import os

    import numpy as np
    import torch

    src_path = f"/jsonl/{jsonl_rel_path}"
    if not os.path.exists(src_path):
        return {"path": jsonl_rel_path, "error": "missing", "n_done": 0, "n_skipped": 0, "total_mb": 0.0}

    with open(src_path) as f:
        records = [json.loads(line) for line in f if line.strip()]
    if not records:
        return {"path": jsonl_rel_path, "n_done": 0, "n_skipped": 0, "total_mb": 0.0}

    # SAE tensors returned by the shared cache are not used here.
    evo2, _, _, device, module_dict = _get_models()
    layer_names = [f"blocks-{i}" for i in range(32)]  # whole-block output, NOT blocks-{i}-mlp-l3
    hidden = evo2.model.config.hidden_size

    cache: dict = {}
    n_done = 0
    n_skipped = 0
    total_mb = 0.0
    handles = _register_cache_hooks(module_dict, layer_names, cache)
    try:
        for rec in records:
            label_folder = rec["label"] if rec["is_positive"] else "MISC"
            mag_id = rec["mag_id"]
            region_id = rec["region_id"]
            out_dir = f"/embeddings_targeted/{label_folder}/{mag_id}"
            os.makedirs(out_dir, exist_ok=True)
            out_path = f"{out_dir}/{region_id}.npz"
            if os.path.exists(out_path):
                n_skipped += 1
                continue

            seq = rec["sequence"]
            cache.clear()
            input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(input_ids)

            seq_len = cache["blocks-0"].shape[1]
            # [32, seq_len, hidden] bf16 → uint16 bit-pattern for numpy storage
            stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
            for i, ln in enumerate(layer_names):
                stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                del cache[ln]  # release this layer's GPU activation as soon as it is copied
            torch.cuda.empty_cache()
            stack_uint16 = stack.view(torch.uint16).numpy()

            meta = {k: v for k, v in rec.items() if k != "sequence"}
            # Atomic write: the existence check above treats any file as finished work.
            _atomic_savez(
                out_path,
                per_token_layer_activations_bf16=stack_uint16,
                per_token_layer_activations_dtype="bfloat16",
                layer_names=np.array(layer_names),
                seq_len=np.int32(seq_len),
                hidden_size=np.int32(hidden),
                model_name="evo2_7b_262k",
                metadata_json=np.array(json.dumps(meta)),
            )
            sz = os.path.getsize(out_path)
            total_mb += sz / 1e6
            n_done += 1
            print(f" [{label_folder}/{mag_id}/{region_id}] seq_len={seq_len} saved {sz/1e6:.1f} MB")

            del stack, stack_uint16, input_ids
            torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    embeddings_targeted_vol.commit()
    return {"path": jsonl_rel_path, "n_done": n_done, "n_skipped": n_skipped, "total_mb": total_mb}


@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_targeted_full() -> dict:
    """
    CPU-only orchestrator. Walks the JSONL volume, fans every per-MAG-per-label
    file out to embed_targeted_jsonl via .map() and aggregates results.
    Living on Modal means the run survives local-process exit (use --detach).

    returns: {"jsonls", "regions_done", "regions_skipped", "errors", "total_mb"};
             "errors" counts both raised exceptions and worker results that
             carry an "error" field.
    """
    import os

    jsonl_paths = []
    for root, _, files in os.walk("/jsonl/full"):
        for fname in files:
            if fname.endswith(".jsonl"):
                rel = os.path.relpath(os.path.join(root, fname), "/jsonl")
                jsonl_paths.append(rel)
    jsonl_paths.sort()
    print(f"[orchestrator] found {len(jsonl_paths)} JSONL files to process")

    n_total_done = 0
    n_total_skipped = 0
    total_mb = 0.0
    errors = 0
    for i, r in enumerate(embed_targeted_jsonl.map(jsonl_paths, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(jsonl_paths)}] ERROR: {r}")
        else:
            if r.get("error"):
                # Worker-reported failure (e.g. JSONL not visible in its volume view).
                errors += 1
                print(f" [{i+1}/{len(jsonl_paths)}] ERROR: {r.get('path')}: {r['error']}")
            n_total_done += r.get("n_done", 0)
            n_total_skipped += r.get("n_skipped", 0)
            total_mb += r.get("total_mb", 0.0)
        if (i + 1) % 25 == 0 or (i + 1) == len(jsonl_paths):
            # total_mb is bytes/1e6, so GB is /1000 (not /1024).
            print(f" [{i+1}/{len(jsonl_paths)}] running totals: done={n_total_done} skipped={n_total_skipped} errors={errors} {total_mb/1000:.1f} GB")

    return {
        "jsonls": len(jsonl_paths),
        "regions_done": n_total_done,
        "regions_skipped": n_total_skipped,
        "errors": errors,
        "total_mb": total_mb,
    }


@app.local_entrypoint()
def run_targeted_full():
    """
    Launch the full targeted-region embed run. Use `modal run --detach` so it
    keeps running after the local process exits.

        modal run --detach modal/evo2_inference.py::run_targeted_full
    """
    print("[local] submitting orchestrator to Modal — fan-out happens server-side")
    result = orchestrate_targeted_full.remote()
    print("\n=== DONE ===")
    print(f" JSONL files: {result['jsonls']}")
    print(f" regions saved: {result['regions_done']}")
    print(f" regions skipped: {result['regions_skipped']} (already on volume)")
    print(f" errors: {result['errors']}")
    print(f" total volume: {result['total_mb']/1000:.1f} GB")


# ============================================================================
# Layer-26-only slicer — CPU job that reads all-32-layer npz from
# mgnify-embeddings-targeted, extracts the layer-26 slice, and writes a
# ~30x smaller npz to mgnify-embeddings-l26 for cheap teammate sharing.
# ============================================================================
embeddings_l26_vol = modal.Volume.from_name("mgnify-embeddings-l26", create_if_missing=True)


@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_targeted_vol,
        "/out": embeddings_l26_vol,
    },
    timeout=3600,
    max_containers=16,
)
def extract_l26_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 from each input npz; write a compact per-region npz to the l26 volume.

    rel_paths: list like ['AMR/MGYG.../MGYG..._00123_AMR.npz', ...] relative to volume root.
    returns:   {"n_done", "n_skipped", "n_errors", "total_mb_in", "total_mb_out"};
               a missing input or a failure on one file counts as an error and
               does not stop the batch.
    """
    import os

    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_in = 0.0
    total_mb_out = 0.0

    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if not os.path.exists(in_path):
            n_errors += 1
            continue
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]  # uint16 [32, seq_len, 4096]
                l26 = stack[26].copy()  # uint16 [seq_len, 4096]
                passthrough = {
                    "layer_names": d["layer_names"],
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # Atomic write so a failed slice never leaves a file that later runs skip.
            _atomic_savez(
                out_path,
                layer26_activations_bf16=l26,  # bit-pattern of bf16 stored as uint16
                layer26_dtype="bfloat16",
                source_layer_index=np.int32(26),
                source_layer_name="blocks-26",
                **passthrough,
            )
            total_mb_in += os.path.getsize(in_path) / 1e6
            total_mb_out += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            # Deliberate best-effort: log, count, continue with the next file.
            print(f" ERROR on {rel}: {e}")
            n_errors += 1

    embeddings_l26_vol.commit()
    return {
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
    }


@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_targeted_vol},
    timeout=86400,
)
def orchestrate_l26_extract(batch_size: int = 50) -> dict:
    """List every committed all-32-layer npz on /in, batch by N, fan out to extract_l26_batch.

    returns: aggregate counters; every file of a batch whose call raised is
             counted in "n_errors".
    raises:  ValueError if batch_size <= 0.
    """
    import os

    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be > 0")

    paths = []
    for root, _, files in os.walk("/in"):
        for fname in files:
            if fname.endswith(".npz"):
                rel = os.path.relpath(os.path.join(root, fname), "/in")
                paths.append(rel)
    paths.sort()
    print(f"[orchestrator] found {len(paths)} all-32-layer npz files to slice")

    # batch into chunks for fewer container starts and one commit per batch
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator] dispatching {len(batches)} batches of up to {batch_size}")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_in = 0.0
    total_mb_out = 0.0
    # zip pairs each result with its input batch (.map yields results in input order by default).
    for i, (batch, r) in enumerate(zip(batches, extract_l26_batch.map(batches, return_exceptions=True))):
        if isinstance(r, Exception):
            # A failed batch call means none of its files are known to be done.
            n_errors += len(batch)
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
        else:
            n_done += r.get("n_done", 0)
            n_skipped += r.get("n_skipped", 0)
            n_errors += r.get("n_errors", 0)
            total_mb_in += r.get("total_mb_in", 0.0)
            total_mb_out += r.get("total_mb_out", 0.0)
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] running totals: done={n_done} skipped={n_skipped} errors={n_errors} in={total_mb_in/1000:.1f} GB → out={total_mb_out/1000:.2f} GB")

    return {
        "regions_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
    }


@app.local_entrypoint()
def run_l26_extract():
    """
    Slice layer-26 out of every committed all-32-layer npz on
    mgnify-embeddings-targeted, write to mgnify-embeddings-l26.

        modal run --detach modal/evo2_inference.py::run_l26_extract
    """
    print("[local] submitting layer-26 slicer orchestrator to Modal")
    result = orchestrate_l26_extract.remote()
    print("\n=== DONE ===")
    print(f" source npz scanned: {result['regions_total']}")
    print(f" newly sliced: {result['n_done']}")
    print(f" skipped (already done): {result['n_skipped']}")
    print(f" errors: {result['n_errors']}")
    print(f" input bytes read: {result['total_mb_in']/1000:.2f} GB")
    print(f" output bytes written: {result['total_mb_out']/1000:.2f} GB")
    if result["total_mb_out"] > 0:
        print(f" compression ratio: {result['total_mb_in']/result['total_mb_out']:.1f}x smaller")
    else:
        # Nothing newly written (e.g. everything skipped): a ratio would be meaningless.
        print(" compression ratio: n/a")


# ============================================================================
# Modal-side packager: tar/zip the layer-26 volume into a single file on the
# same volume, then we download just that one big file (much faster than
# per-file `modal volume get`, which serializes per-file requests).
# ============================================================================
@app.function(
    image=modal.Image.debian_slim(),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    timeout=3600,
)
def pack_l26_archive(out_name: str = "embeddings_l26.zip") -> dict:
    """Bundle every .npz found under /vol into one uncompressed zip at /vol/{out_name}.

    ZIP_STORED is used because the npz members are already compressed. Any
    existing archive with the same name is deleted first. Member names are
    paths relative to /vol. The volume is committed before returning.

    returns: {"archive_path", "n_files", "bytes_raw", "bytes_archive", "elapsed_s"}
    """
    import os
    import zipfile
    import time

    volume_root = "/vol"
    archive_path = f"/vol/{out_name}"
    if os.path.exists(archive_path):
        os.remove(archive_path)

    def npz_files():
        # Yield the absolute path of each .npz below the volume root, in walk order.
        for dirpath, _, filenames in os.walk(volume_root):
            for filename in filenames:
                if filename.endswith(".npz"):
                    yield os.path.join(dirpath, filename)

    packed = 0
    raw_bytes = 0
    started = time.time()
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_STORED) as archive:
        for src in npz_files():
            archive.write(src, os.path.relpath(src, volume_root))
            raw_bytes += os.path.getsize(src)
            packed += 1
            if packed % 50 == 0:
                print(f" packed {packed} files, {raw_bytes/1e9:.2f} GB raw...")

    archive_bytes = os.path.getsize(archive_path)
    embeddings_l26_vol.commit()
    return {
        "archive_path": archive_path,
        "n_files": packed,
        "bytes_raw": raw_bytes,
        "bytes_archive": archive_bytes,
        "elapsed_s": time.time() - started,
    }


@app.local_entrypoint()
def pack_and_report():
    """Run pack_l26_archive remotely with its defaults and print a summary plus a download command."""
    print("[local] packing layer-26 npz files into single zip on Modal volume...")
    summary = pack_l26_archive.remote()
    raw_gb = summary["bytes_raw"] / 1e9
    archive_gb = summary["bytes_archive"] / 1e9
    print("\n=== PACKED ===")
    print(f" archive path on volume: {summary['archive_path']}")
    print(f" files packed: {summary['n_files']}")
    print(f" raw size: {raw_gb:.2f} GB")
    print(f" archive size: {archive_gb:.2f} GB")
    print(f" elapsed (server-side): {summary['elapsed_s']:.0f} s")
    print("\nDownload locally with:")
    print(" modal volume get mgnify-embeddings-l26 /embeddings_l26.zip /home/ror25cal/MGnify/data/embeddings_l26.zip")


# ============================================================================
# Modal-side HF upload: stream the layer-26 volume directly to a Hugging Face
# Dataset repo.
# Uses HF token from the existing 'huggingface' Modal Secret.
# Avoids the local download entirely — Modal egress → HF ingress is fast.
# ============================================================================
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
)
def upload_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True) -> dict:
    """Push every ``.npz`` under /vol to a Hugging Face Dataset repo.

    The repo is ``{authenticated user}/{repo_name}``; it is created if it does
    not exist. The upload itself goes through ``HfApi.upload_large_folder``.

    Args:
        repo_name: Dataset name (without the user/organisation prefix).
        private: Visibility passed to ``create_repo``.

    Returns:
        dict with ``repo_id``, ``repo_url``, ``n_files``, ``bytes_total`` and
        ``elapsed_s`` (upload time only).

    Raises:
        RuntimeError: if no HF token can be found in the environment.
    """
    import os
    import time

    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Workaround for the legacy 'huggingface' Modal secret which was created
        # with key/value swapped (env var name *is* the token, value == "HF_TOKEN").
        for k, v in os.environ.items():
            if k.startswith("hf_") and v == "HF_TOKEN":
                token = k
                break
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    me = api.whoami()
    user = me.get("name")
    repo_id = f"{user}/{repo_name}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: https://huggingface.co/datasets/{repo_id}")

    # Count what we're about to upload (reporting only; the upload below
    # applies its own pattern filter).
    n_files = 0
    bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for f in files:
            if f.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, f))
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    t0 = time.time()
    api.upload_large_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path="/vol",
        allow_patterns=["**/*.npz"],
        print_report=True,
    )
    elapsed = time.time() - t0
    return {
        "repo_id": repo_id,
        "repo_url": f"https://huggingface.co/datasets/{repo_id}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
    }


@app.local_entrypoint()
def push_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True):
    """Push the layer-26 volume directly to a HF Dataset repo (no local download).

    modal run modal/evo2_inference.py::push_l26_to_hf
    modal run modal/evo2_inference.py::push_l26_to_hf --repo-name foo --no-private
    """
    print(f"[local] launching HF push (repo={repo_name}, private={private})")
    r = upload_l26_to_hf.remote(repo_name=repo_name, private=private)
    print("\n=== UPLOADED ===")
    print(f" repo: {r['repo_url']}")
    print(f" files: {r['n_files']}")
    print(f" size: {r['bytes_total']/1e9:.2f} GB")
    print(f" elapsed: {r['elapsed_s']:.0f} s ({r['bytes_total']/1e6/max(r['elapsed_s'],1):.1f} MB/s)")


# ============================================================================
# Lean targeted pipeline: 5 layers, no compression, no SAE.
# Replaces embed_targeted_jsonl for cost-sensitive reruns. Saves layers
# 14, 20, 24, 26, 28 as bf16-as-uint16 .npz with NO gzip — np.savez_compressed
# was the dominant cost in the prior run (gzip on bf16 noise = 30s/region of
# CPU while H100 idled). Uncompressed is ~3-5x faster and only ~30% bigger.
# ============================================================================
embeddings_lean_vol = modal.Volume.from_name("mgnify-embeddings-lean", create_if_missing=True)

LEAN_LAYERS: list[int] = [14, 20, 24, 26, 28]


def _get_evo2_only():
    """Return ``(evo2, device, module_dict)``, loading the model on first use.

    Lighter than _get_models — skips SAE weight load. The three objects are
    stored in module globals so later calls in the same process reuse them.
    ``module_dict`` maps "-"-joined child-module paths (e.g. "blocks-26") to
    the corresponding sub-module of ``evo2.model``.
    """
    try:
        if _CACHED_EVO2_LEAN is not None:
            return _CACHED_EVO2_LEAN, _CACHED_DEVICE_LEAN, _CACHED_MODULE_DICT_LEAN
    except NameError:
        # First call in this process: the cache globals do not exist yet.
        pass

    from evo2 import Evo2

    print("[container] loading Evo2 7B-262k (no SAE)")
    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    module_dict = {}

    def recurse(m, prefix=""):
        for n, c in m.named_children():
            module_dict[prefix + n] = c
            recurse(c, prefix + n + "-")

    recurse(evo2.model)

    globals()["_CACHED_EVO2_LEAN"] = evo2
    globals()["_CACHED_DEVICE_LEAN"] = device
    globals()["_CACHED_MODULE_DICT_LEAN"] = module_dict
    return evo2, device, module_dict


@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_targeted_lean(jsonl_rel_paths: str | list[str]) -> dict:
    """Embed every record of a batch of JSONL files and save LEAN_LAYERS activations.

    For each record one file is written at
    /embeddings_lean/{label}/{mag}/{region}.npz, where ``label`` is the
    record's ``label`` for positives and "MISC" otherwise. Records whose output
    file already exists are skipped. The volume is committed once, after the
    whole batch.

    Each file is written to a ``.partial`` sibling and renamed into place, so
    an interrupted write can never be mistaken for a finished region by the
    exists-check on a later run.

    Args:
        jsonl_rel_paths: One path (str, back-compat) or a list of paths,
            relative to /jsonl.

    Returns:
        dict with counters (``n_jsonls``, ``n_missing_jsonl``, ``n_done``,
        ``n_skipped``), ``total_mb`` written, and timings (``model_load_s``,
        ``commit_s``, ``per_region_s`` list, ``mean_per_region_s`` or None).
    """
    import json
    import os
    import time

    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start

    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            # Some modules return a tuple; the first element is what is kept.
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue
            for rec in records:
                label_folder = rec["label"] if rec["is_positive"] else "MISC"
                mag_id = rec["mag_id"]
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/{label_folder}/{mag_id}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue

                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(
                    evo2.tokenizer.tokenize(seq), dtype=torch.long
                ).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)

                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                torch.cuda.empty_cache()

                # Reinterpret the bf16 bits as uint16 so numpy can hold them.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}

                # Write to a temp name, then rename. Passing a file object
                # keeps np.savez from appending ".npz" to the temp name, so
                # ".npz" walkers never see a partial file.
                tmp_path = f"{out_path}.partial"
                with open(tmp_path, "wb") as fh:
                    np.savez(  # NO compression — gzip was the bottleneck
                        fh,
                        per_token_layer_activations_bf16=stack_uint16,
                        per_token_layer_activations_dtype="bfloat16",
                        layer_names=np.array(layer_names),
                        layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                        seq_len=np.int32(seq_len),
                        hidden_size=np.int32(hidden),
                        model_name="evo2_7b_262k",
                        metadata_json=np.array(json.dumps(meta)),
                    )
                os.replace(tmp_path, out_path)

                sz = os.path.getsize(out_path)
                total_mb += sz / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        # The model is cached across calls, so hooks must not outlive this call.
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start

    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }


@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_lean_full(batch_size: int = 50) -> dict:
    """Fan all JSONLs under /jsonl/full out to embed_targeted_lean in batches.

    Paths are sorted and grouped into lists of up to ``batch_size``; each list
    is one embed_targeted_lean call (and therefore one volume commit). A batch
    that raises is counted in ``errors`` and skipped.

    Returns:
        dict with ``jsonls``, ``batches``, ``regions_done``,
        ``regions_skipped``, ``errors`` (failed batches), ``total_mb``,
        ``mean_region_s`` and ``mean_commit_s`` (0 when there are no samples).
    """
    import os

    jsonl_paths = []
    for root, _, files in os.walk("/jsonl/full"):
        for fname in files:
            if fname.endswith(".jsonl"):
                rel = os.path.relpath(os.path.join(root, fname), "/jsonl")
                jsonl_paths.append(rel)
    jsonl_paths.sort()

    batches = [jsonl_paths[i:i + batch_size] for i in range(0, len(jsonl_paths), batch_size)]
    print(f"[orchestrator-lean] {len(jsonl_paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = 0
    n_skipped = 0
    total_mb = 0.0
    errors = 0
    region_time_samples: list[float] = []
    commit_time_samples: list[float] = []
    for i, r in enumerate(embed_targeted_lean.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_time_samples.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_time_samples.append(r["commit_s"])
        # Progress line every 5 batches and on the last one.
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            mean_t = sum(region_time_samples) / max(len(region_time_samples), 1)
            mean_c = sum(commit_time_samples) / max(len(commit_time_samples), 1)
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
                  f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f} mean_commit_s={mean_c:.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": len(batches),
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": sum(region_time_samples) / max(len(region_time_samples), 1),
        "mean_commit_s": sum(commit_time_samples) / max(len(commit_time_samples), 1),
    }


@app.local_entrypoint()
def pilot_lean_batched(n_jsonls: int = 100, batch_size: int = 50):
    """Run the *batched* lean pipeline on N JSONLs to measure realistic commit overhead.

    JSONL names are taken from the local targeted_jsonl/full directory (first
    ``n_jsonls`` in label/file order) and are expected to exist under the same
    relative paths on the JSONL volume. Prints measured timings and a
    projection for the full run.

    If no region was embedded (everything already on the volume, or every
    batch failed) the timing summary is skipped instead of crashing on empty
    statistics.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/full"
    paths = []
    for label in ["AMR", "MISC", "VIRULENCE", "STRESS"]:
        d = os.path.join(base, label)
        if os.path.isdir(d):
            for fname in sorted(os.listdir(d)):
                if fname.endswith(".jsonl"):
                    paths.append(f"full/{label}/{fname}")
    paths = paths[:n_jsonls]
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[pilot-batched] {len(paths)} JSONLs in {len(batches)} batches of up to {batch_size}")

    t0 = time.time()
    results = list(embed_targeted_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    # Surface failed batches instead of silently dropping them.
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")

    ok = [r for r in results if not isinstance(r, Exception)]
    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    all_region_times = [t for r in ok for t in r.get("per_region_s") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]

    if not all_region_times:
        # min()/max() and the percentage breakdown below need at least one sample.
        print("ERROR: no records processed in pilot")
        return

    mean_region = sum(all_region_times) / max(len(all_region_times), 1)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)

    print("\n=== PILOT-BATCHED SUMMARY ===")
    print(f" records: {n_done}")
    print(f" total volume size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batches × {len(results)} workers")
    print(f" per-region inference: {mean_region:.2f} s avg (min {min(all_region_times):.2f}, max {max(all_region_times):.2f})")
    print(f" per-batch commit: {mean_commit:.2f} s avg (min {min(commit_times):.2f}, max {max(commit_times):.2f})")
    print(f" per-call model load: {sum(load_times)/max(len(load_times),1):.1f} s avg (cold start)")

    # Project full run with batching
    full_records = 5483
    full_n_jsonls = 2416
    n_workers = 16
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    cold_start_s = sum(load_times) / max(len(load_times), 1) * n_workers  # one cold start per worker
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / n_workers
    cost = (total_compute_s / 3600) * 4.50
    print(f"\n PROJECTION (5483 records, {n_workers}× H100, batch_size={batch_size}):")
    print(f" inference compute: {inference_compute_s:6.0f} s ({inference_compute_s/total_compute_s*100:.0f}%)")
    print(f" commit compute: {commit_compute_s:6.0f} s ({commit_compute_s/total_compute_s*100:.0f}%)")
    print(f" cold-start total: {cold_start_s:6.0f} s ({cold_start_s/total_compute_s*100:.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {total_mb/max(n_done,1) * full_records / 1024:.1f} GB")


@app.local_entrypoint()
def run_lean_full(batch_size: int = 50):
    """Launch the full lean run detached.

    modal run --detach modal/evo2_inference.py::run_lean_full
    modal run --detach modal/evo2_inference.py::run_lean_full --batch-size 50
    """
    print(f"[local] submitting orchestrator (batch_size={batch_size}) to Modal")
    r = orchestrate_lean_full.remote(batch_size=batch_size)
    print("\n=== DONE ===")
    print(f" JSONL files: {r['jsonls']} in {r['batches']} batches")
    print(f" regions saved: {r['regions_done']}")
    print(f" regions skipped: {r['regions_skipped']} (already on volume)")
    print(f" errors: {r['errors']}")
    print(f" total volume size: {r['total_mb']/1024:.1f} GB")
    print(f" mean per-region: {r['mean_region_s']:.2f} s")
    print(f" mean per-batch commit:{r['mean_commit_s']:.2f} s")


@app.local_entrypoint()
def pilot_lean(mag_id: str = "MGYG000516287"):
    """Pilot the lean pipeline on a single MAG (across all labels).

    Calls embed_targeted_lean once per label, sequentially, then reports
    empirical timing and projects the full-run cost. The projection is only
    printed when at least one region was actually embedded.
    """
    import time

    pilot_jsonls = [f"full/{label}/{mag_id}.jsonl" for label in ["AMR", "VIRULENCE", "STRESS", "MISC"]]
    print(f"[pilot] running on MAG {mag_id} across labels (skipping any missing)")

    results = []
    t0 = time.time()
    for path in pilot_jsonls:
        try:
            r = embed_targeted_lean.remote(path)
            print(f" {path}: done={r['n_done']} skipped={r['n_skipped']} "
                  f"model_load={r.get('model_load_s', 0):.1f}s "
                  f"mean_region={r.get('mean_per_region_s') or 0:.2f}s "
                  f"size={r['total_mb']:.1f} MB")
            results.append(r)
        except Exception as e:
            # Best-effort pilot: one failing label must not abort the others.
            print(f" {path}: SKIPPED ({e})")
    wall = time.time() - t0

    # Aggregate timings
    all_region_times = [t for r in results for t in r.get("per_region_s") or []]
    n_total = sum(r["n_done"] for r in results)
    bytes_total = sum(r["total_mb"] for r in results)

    print("\n=== PILOT SUMMARY ===")
    print(f" records processed: {n_total}")
    print(f" total volume size: {bytes_total:.1f} MB ({bytes_total/max(n_total,1):.1f} MB/record avg)")
    print(f" wall clock (single worker): {wall:.0f} s")
    if all_region_times:
        mean_t = sum(all_region_times) / len(all_region_times)
        print(f" per-region time: {mean_t:.2f} s avg "
              f"(min {min(all_region_times):.2f}, max {max(all_region_times):.2f})")

        # Project full run cost
        full_records = 5483
        full_compute_s = full_records * mean_t  # serial-equivalent compute time
        n_workers = 16
        # Cold start ~60s per worker, contributes once each
        cold_start_s = 60
        wall_proj = cold_start_s + full_compute_s / n_workers
        h100_hours = (full_compute_s / 3600 + (cold_start_s * n_workers) / 3600)
        cost_proj = h100_hours * 4.50
        print(f"\n PROJECTION (5483 regions, {n_workers}× H100):")
        print(f" estimated wall clock: {wall_proj/60:.1f} min")
        print(f" estimated cost: ${cost_proj:.2f}")
        print(f" estimated total size: {bytes_total/max(n_total,1) * full_records / 1024:.1f} GB")


# ============================================================================
# Slice layer 26 out of the 5-layer lean volume → new volume → push to HF.
# Self-contained on Modal: parallel slicer workers + single uploader.
# ============================================================================
embeddings_l26_lean_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=True)


@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_lean_vol,
    },
    timeout=3600,
    max_containers=16,
)
def slice_l26_from_lean_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 from each lean npz. Single commit per batch.

    For every relative path, reads /in/{rel}, picks the row of
    ``per_token_layer_activations_bf16`` whose ``layer_indices`` entry is 26,
    and writes it with the pass-through metadata to /out/{rel}. Existing
    outputs are skipped; a missing input or any per-file failure is counted as
    an error and does not stop the batch.

    Outputs are written to a ``.partial`` sibling and renamed into place so a
    failed or interrupted write is never later mistaken for a finished file.

    Returns:
        dict with ``n_done``, ``n_skipped``, ``n_errors`` and ``total_mb``
        (size of newly written files).
    """
    import os

    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            n_errors += 1
            continue

        tmp_path = f"{out_path}.partial"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]  # [5, seq_len, 4096] uint16
                layer_indices = [int(x) for x in d["layer_indices"]]
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # A file object keeps np.savez from appending ".npz" to the temp name.
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    layer26_activations_bf16=l26,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    **passthrough,
                )
            os.replace(tmp_path, out_path)
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    embeddings_l26_lean_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}


@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_lean_slice(batch_size: int = 50) -> dict:
    """List lean npz files, batch them, fan out to slicer workers.

    Walks /in for ``.npz`` files, skipping the top-level ``vfdb`` directory,
    and maps slice_l26_from_lean_batch over batches of up to ``batch_size``
    paths. When a whole batch raises, every file in it is added to
    ``n_errors`` so the summary cannot report zero errors after a failure.

    Returns:
        dict with ``files_total``, ``n_done``, ``n_skipped``, ``n_errors``
        and ``total_mb_out``.
    """
    import os

    paths = []
    for root, dirs, files in os.walk("/in"):
        if root == "/in":
            # VFDB embeddings share this volume under /in/vfdb and are sliced
            # into their own output volume; keep them out of this dataset.
            dirs[:] = [d for d in dirs if d != "vfdb"]
        for fname in files:
            if fname.endswith(".npz"):
                rel = os.path.relpath(os.path.join(root, fname), "/in")
                paths.append(rel)
    paths.sort()

    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-slice] {len(paths)} lean npz → {len(batches)} batches")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_out = 0.0
    for i, r in enumerate(slice_l26_from_lean_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            # Results arrive in input order, so batches[i] is the failed batch.
            n_errors += len(batches[i])
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")

    return {
        "files_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_out": total_mb_out,
    }


@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_lean_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=21600,
)
def upload_l26_lean_to_hf(repo_name: str = "mgnify-evo2-l26-full", private: bool = True) -> dict:
    """Push the layer-26-lean volume to HF Datasets via upload_large_folder.

    The target repo is ``{authenticated user}/{repo_name}`` and is created if
    missing.

    Returns:
        dict with ``repo_id``, ``repo_url``, ``n_files``, ``bytes_total`` and
        ``elapsed_s`` (upload time only).

    Raises:
        RuntimeError: if no HF token can be found in the environment.
    """
    import os
    import time

    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # workaround for legacy 'huggingface' Modal secret with swapped key/value
        for k, v in os.environ.items():
            if k.startswith("hf_") and v == "HF_TOKEN":
                token = k
                break
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    me = api.whoami()
    user = me.get("name")
    repo_id = f"{user}/{repo_name}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: https://huggingface.co/datasets/{repo_id}")

    # Reporting only: tally the .npz files present on the volume.
    n_files = 0
    bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for f in files:
            if f.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, f))
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    t0 = time.time()
    api.upload_large_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path="/vol",
        allow_patterns=["**/*.npz"],
        print_report=True,
    )
    elapsed = time.time() - t0
    return {
        "repo_id": repo_id,
        "repo_url": f"https://huggingface.co/datasets/{repo_id}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
    }


@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=120,
)
def set_hf_dataset_visibility(repo_name: str, private: bool) -> dict:
    """Toggle visibility of an HF dataset. Used to flip private→public after upload.

    Returns:
        dict with ``repo_id``, the ``private`` flag read back from the repo
        info (None if the attribute is absent) and the dataset ``url``.

    Raises:
        RuntimeError: if no HF token can be found in the environment.
    """
    import os

    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Same swapped key/value fallback as the upload functions.
        for k, v in os.environ.items():
            if k.startswith("hf_") and v == "HF_TOKEN":
                token = k
                break
    if not token:
        raise RuntimeError("HF_TOKEN missing")

    api = HfApi(token=token)
    me = api.whoami()
    repo_id = f"{me['name']}/{repo_name}"
    # newer huggingface_hub uses update_repo_settings; older has update_repo_visibility
    if hasattr(api, "update_repo_settings"):
        api.update_repo_settings(repo_id=repo_id, repo_type="dataset", private=private)
    else:
        api.update_repo_visibility(repo_id=repo_id, repo_type="dataset", private=private)
    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    return {
        "repo_id": repo_id,
        "private": getattr(info, "private", None),
        "url": f"https://huggingface.co/datasets/{repo_id}",
    }


@app.local_entrypoint()
def make_l26_dataset_public(repo_name: str = "mgnify-evo2-l26-full"):
    """Flip the layer-26 dataset to public."""
    r = set_hf_dataset_visibility.remote(repo_name=repo_name, private=False)
    print(f" repo: {r['repo_id']}")
    print(f" private: {r['private']}")
    print(f" url: {r['url']}")


@app.local_entrypoint()
def push_l26_lean(repo_name: str = "mgnify-evo2-l26-full", batch_size: int = 50, private: bool = True):
    """Slice layer 26 from the lean volume into its own volume, then upload to HF.

    modal run --detach modal/evo2_inference.py::push_l26_lean
    modal run --detach modal/evo2_inference.py::push_l26_lean --no-private
    """
    print("[1/2] slicing layer-26 from lean volume on Modal…")
    s = orchestrate_l26_lean_slice.remote(batch_size=batch_size)
    print(f" scanned {s['files_total']} lean npz")
    print(f" newly sliced: {s['n_done']}")
    print(f" skipped: {s['n_skipped']}")
    print(f" errors: {s['n_errors']}")
    print(f" l26 volume size: {s['total_mb_out']/1024:.2f} GB")

    print("\n[2/2] pushing layer-26 volume to HF Datasets…")
    u = upload_l26_lean_to_hf.remote(repo_name=repo_name, private=private)
    print("\n=== UPLOADED ===")
    print(f" repo: {u['repo_url']}")
    print(f" files: {u['n_files']}")
    print(f" size: {u['bytes_total']/1e9:.2f} GB")
    print(f" elapsed: {u['elapsed_s']:.0f} s ({u['bytes_total']/1e6/max(u['elapsed_s'],1):.1f} MB/s)")


# =============================================================
# VFDB virulence pipeline (mirror of embed_targeted_lean)
# =============================================================
# Same Evo2 forward pass + 5-layer extraction as the MGnify lean pipeline.
# Differences:
#   - input JSONLs live at /jsonl/vfdb/<file>.jsonl
#   - outputs at /embeddings_lean/vfdb/{label_folder}/{group}/{region_id}.npz
#       label_folder: "VIRULENCE" (positive) or "negative" (no MGnify-MISC collision)
#       group: species_slug (positives) or mag_id (negatives)
#   - record schema: positives lack mag_id natively; pre-processing in
#     scripts/extract_vfdb_negatives.py upload step injects species_slug.
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_vfdb_lean(jsonl_rel_paths: str | list[str]) -> dict:
    """VFDB-targeted version of embed_targeted_lean. Outputs under /embeddings_lean/vfdb/.

    One file per record is written at
    /embeddings_lean/vfdb/{label_folder}/{group}/{region_id}.npz, where
    ``label_folder`` is "VIRULENCE" for positives and "negative" otherwise,
    and ``group`` is the record's ``mag_id``, else ``species``, else "UNKNOWN".
    Existing outputs are skipped; the volume is committed once per call.

    Each file is written to a ``.partial`` sibling and renamed into place, so
    an interrupted write is never later skipped as if it were complete.

    Args:
        jsonl_rel_paths: One path (str) or a list of paths, relative to /jsonl.

    Returns:
        Same keys as embed_targeted_lean plus ``seq_lens`` (one entry per
        embedded region).
    """
    import json
    import os
    import time

    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start

    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            # Some modules return a tuple; the first element is what is kept.
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    seq_lens: list[int] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue
            for rec in records:
                # VFDB-aware path layout
                label_folder = "VIRULENCE" if rec["is_positive"] else "negative"
                # NOTE(review): the section header says pre-processing injects
                # "species_slug", but the key read here is "species" — confirm
                # which key the prepared records actually carry.
                group = rec.get("mag_id") or rec.get("species") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/vfdb/{label_folder}/{group}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                if os.path.exists(out_path):
                    n_skipped += 1
                    continue

                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(
                    evo2.tokenizer.tokenize(seq), dtype=torch.long
                ).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)

                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                torch.cuda.empty_cache()

                # Reinterpret the bf16 bits as uint16 so numpy can hold them.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}

                # Temp file + rename; a file object keeps np.savez from
                # appending ".npz" to the temp name.
                tmp_path = f"{out_path}.partial"
                with open(tmp_path, "wb") as fh:
                    np.savez(
                        fh,
                        per_token_layer_activations_bf16=stack_uint16,
                        per_token_layer_activations_dtype="bfloat16",
                        layer_names=np.array(layer_names),
                        layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                        seq_len=np.int32(seq_len),
                        hidden_size=np.int32(hidden),
                        model_name="evo2_7b_262k",
                        metadata_json=np.array(json.dumps(meta)),
                    )
                os.replace(tmp_path, out_path)

                sz = os.path.getsize(out_path)
                total_mb += sz / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                seq_lens.append(seq_len)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        # The model is cached across calls, so hooks must not outlive this call.
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start

    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "seq_lens": seq_lens,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }


@app.local_entrypoint()
def pilot_vfdb_lean(target_records: int = 250, batch_size: int = 2):
    """Run a small VFDB pilot to measure per-region time + commit time, then project full cost.

    Species files are taken smallest-first until the running record count
    reaches ``target_records``; once something has been chosen, files larger
    than twice the target are passed over. The chosen files are uploaded to
    the JSONL volume under vfdb/ and embedded with embed_vfdb_lean.

    modal run modal/evo2_inference.py::pilot_vfdb_lean
    modal run modal/evo2_inference.py::pilot_vfdb_lean --target-records 500 --batch-size 4

    Raises:
        FileNotFoundError: if the local VFDB directory is missing or contains
            no ``.jsonl`` files.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"VFDB JSONLs not found at {base}; run "
                                "extract_vfdb_negatives.py + the modal-prep step first.")

    species_files = []
    for fname in sorted(os.listdir(base)):
        if not fname.endswith(".jsonl"):
            continue
        path = os.path.join(base, fname)
        with open(path) as f:
            n = sum(1 for line in f if line.strip())
        species_files.append((n, fname))
    if not species_files:
        # Without this the fallback below would fail with a bare IndexError.
        raise FileNotFoundError(f"no .jsonl files found in {base}")
    species_files.sort()  # smallest first

    # Greedy pick of small-to-medium files until we hit the target.
    # Skip files that would by themselves blow the budget by >2×.
    chosen = []
    total = 0
    for n, fname in species_files:
        if total >= target_records:
            break
        if n > target_records * 2 and chosen:
            continue
        chosen.append((n, fname))
        total += n
    if not chosen:
        chosen = [species_files[0]]

    print(f"[pilot-vfdb] selected {len(chosen)} species:")
    total_records = 0
    for n, fname in chosen:
        print(f" {fname:40s} {n} records")
        total_records += n
    print(f" total pilot records: {total_records}")

    # Upload selected JSONLs to jsonl_vol under vfdb/
    print(f"\n[pilot-vfdb] uploading {len(chosen)} JSONLs to volume mgnify-targeted-jsonl ...")
    upload_t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for _, fname in chosen:
            local_path = os.path.join(base, fname)
            remote_path = f"vfdb/{fname}"
            batch.put_file(local_path, remote_path)
    print(f" uploaded in {time.time()-upload_t0:.0f} s")

    rel_paths = [f"vfdb/{fname}" for _, fname in chosen]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[pilot-vfdb] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of up to {batch_size}")

    t0 = time.time()
    results = list(embed_vfdb_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    # Surface failed batches instead of silently dropping them.
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")

    ok = [r for r in results if not isinstance(r, Exception)]
    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    region_times = [t for r in ok for t in r.get("per_region_s") or []]
    seq_lens = [s for r in ok for s in r.get("seq_lens") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]

    if not region_times:
        print("ERROR: no records processed in pilot")
        return

    mean_region = sum(region_times) / len(region_times)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)
    mean_seqlen = sum(seq_lens) / len(seq_lens)

    print("\n=== VFDB PILOT RESULTS ===")
    print(f" records processed: {n_done}")
    print(f" total output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.2f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batch(es)")
    print(f" per-region inference: {mean_region:.2f} s avg "
          f"(min {min(region_times):.2f}, max {max(region_times):.2f}, p95 {sorted(region_times)[int(len(region_times)*0.95)]:.2f})")
    print(f" per-region seq len: {mean_seqlen:.0f} bp avg")
    print(f" per-batch commit: {mean_commit:.2f} s avg "
          f"(min {min(commit_times):.2f}, max {max(commit_times):.2f})")
    print(f" per-call model load: {mean_load:.1f} s avg")

    # Cost projection for full VFDB run
    full_records = 14695
    n_workers = 16
    h100_rate = 4.50  # $/hr — same as projection in pilot_lean_batched
    full_n_jsonls = 34
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    # No more workers can be busy than there are batches.
    cold_start_s = mean_load * min(n_workers, full_batches)
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / min(n_workers, full_batches)
    cost = (total_compute_s / 3600) * h100_rate
    output_size_gb = (total_mb / max(n_done, 1)) * full_records / 1024
    print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, "
          f"H100=${h100_rate:.2f}/hr):")
    print(f" inference compute: {inference_compute_s:7.0f} s ({inference_compute_s/total_compute_s*100:.0f}%)")
    print(f" commit compute: {commit_compute_s:7.0f} s ({commit_compute_s/total_compute_s*100:.0f}%)")
    print(f" cold-start total: {cold_start_s:7.0f} s ({cold_start_s/total_compute_s*100:.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:5.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {output_size_gb:.1f} GB")


@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_vfdb_lean(batch_size: int = 2) -> dict:
    """Walks /jsonl/vfdb/, batches paths by `batch_size`, fans out to embed_vfdb_lean.

    Mirror of orchestrate_lean_full but for VFDB; a progress line is printed
    after every successful batch. A batch that raises is counted in
    ``errors`` and skipped.

    Returns:
        dict with ``jsonls``, ``batches``, ``regions_done``,
        ``regions_skipped``, ``errors`` (failed batches), ``total_mb``,
        ``mean_region_s`` and ``mean_commit_s`` (0 when there are no samples).
    """
    import os

    jsonl_paths = []
    for root, _, files in os.walk("/jsonl/vfdb"):
        for fname in files:
            if fname.endswith(".jsonl"):
                rel = os.path.relpath(os.path.join(root, fname), "/jsonl")
                jsonl_paths.append(rel)
    jsonl_paths.sort()

    batches = [jsonl_paths[i:i + batch_size] for i in range(0, len(jsonl_paths), batch_size)]
    print(f"[orchestrator-vfdb] {len(jsonl_paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = 0
    n_skipped = 0
    total_mb = 0.0
    errors = 0
    region_time_samples: list[float] = []
    commit_time_samples: list[float] = []
    for i, r in enumerate(embed_vfdb_lean.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_time_samples.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_time_samples.append(r["commit_s"])
        mean_t = sum(region_time_samples) / max(len(region_time_samples), 1)
        mean_c = sum(commit_time_samples) / max(len(commit_time_samples), 1)
        print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f} mean_commit_s={mean_c:.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": len(batches),
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": sum(region_time_samples) / max(len(region_time_samples), 1),
        "mean_commit_s": sum(commit_time_samples) / max(len(commit_time_samples), 1),
    }


@app.local_entrypoint()
def run_vfdb_lean(batch_size: int = 2):
    """Upload all VFDB JSONLs to volume + run the full lean pipeline.

    modal run --detach modal/evo2_inference.py::run_vfdb_lean
    modal run --detach modal/evo2_inference.py::run_vfdb_lean --batch-size 4

    Raises:
        FileNotFoundError: if the local VFDB directory does not exist.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"VFDB JSONLs not found at {base}")

    jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl"))
    print(f"[run-vfdb] uploading {len(jsonls)} JSONLs to mgnify-targeted-jsonl volume ...")
    t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for fname in jsonls:
            batch.put_file(os.path.join(base, fname), f"vfdb/{fname}")
    print(f" uploaded in {time.time()-t0:.0f} s")

    print(f"\n[run-vfdb] launching orchestrator (batch_size={batch_size})")
    r = orchestrate_vfdb_lean.remote(batch_size=batch_size)
    print("\n=== DONE ===")
    print(f" JSONL files: {r['jsonls']} in {r['batches']} batches")
    print(f" regions saved: {r['regions_done']}")
    print(f" regions skipped: {r['regions_skipped']} (already on volume)")
    print(f" errors: {r['errors']}")
    print(f" total volume size: {r['total_mb']/1024:.1f} GB")
    print(f" mean per-region: {r['mean_region_s']:.2f} s")
    print(f" mean per-batch commit: {r['mean_commit_s']:.2f} s")


# =============================================================
# VFDB layer-26 slice + HF push
# =============================================================
# Same logic as the MGnify l26 slicer, but reads only /in/vfdb/* and writes to a
# separate output volume so it
doesn't mix with the MGnify l26 dataset. embeddings_l26_vfdb_vol = modal.Volume.from_name( "mgnify-embeddings-l26-vfdb", create_if_missing=True, ) @app.function( image=modal.Image.debian_slim().pip_install("numpy"), cpu=2, volumes={ "/in": embeddings_lean_vol, "/out": embeddings_l26_vfdb_vol, }, timeout=3600, max_containers=16, ) def slice_l26_vfdb_batch(rel_paths: list[str]) -> dict: """Slice layer 26 from each VFDB lean npz. Same schema as slice_l26_from_lean_batch.""" import os import numpy as np n_done = 0 n_skipped = 0 n_errors = 0 total_mb = 0.0 for rel in rel_paths: in_path = f"/in/{rel}" out_path = f"/out/{rel}" if os.path.exists(out_path): n_skipped += 1 continue if not os.path.exists(in_path): n_errors += 1 continue try: with np.load(in_path, allow_pickle=False) as d: stack = d["per_token_layer_activations_bf16"] layer_indices = list(int(x) for x in d["layer_indices"]) pos = layer_indices.index(26) l26 = stack[pos].copy() passthrough = { "seq_len": d["seq_len"], "hidden_size": d["hidden_size"], "model_name": d["model_name"], "metadata_json": d["metadata_json"], } os.makedirs(os.path.dirname(out_path), exist_ok=True) np.savez( out_path, layer26_activations_bf16=l26, layer26_dtype="bfloat16", source_layer_index=np.int32(26), source_layer_name="blocks-26", **passthrough, ) total_mb += os.path.getsize(out_path) / 1e6 n_done += 1 except Exception as e: print(f" ERROR on {rel}: {e}") n_errors += 1 embeddings_l26_vfdb_vol.commit() return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb} @app.function( image=modal.Image.debian_slim().pip_install("modal"), cpu=1, volumes={"/in": embeddings_lean_vol}, timeout=86400, ) def orchestrate_l26_vfdb_slice(batch_size: int = 100) -> dict: """Walk only /in/vfdb/, batch, fan out to slice_l26_vfdb_batch.""" import os paths = [] for root, _, files in os.walk("/in/vfdb"): for fname in files: if fname.endswith(".npz"): rel = os.path.relpath(os.path.join(root, fname), "/in") paths.append(rel) 
paths.sort() batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] print(f"[orchestrator-l26-vfdb] {len(paths)} lean npz → {len(batches)} batches of {batch_size}") n_done = 0 n_skipped = 0 n_errors = 0 total_mb_out = 0.0 for i, r in enumerate(slice_l26_vfdb_batch.map(batches, return_exceptions=True)): if isinstance(r, Exception): print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}") continue n_done += r.get("n_done", 0) n_skipped += r.get("n_skipped", 0) n_errors += r.get("n_errors", 0) total_mb_out += r.get("total_mb", 0.0) if (i + 1) % 5 == 0 or (i + 1) == len(batches): print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB") return { "files_total": len(paths), "n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb_out": total_mb_out, } @app.function( image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"), cpu=8, volumes={"/vol": embeddings_l26_vfdb_vol}, secrets=[modal.Secret.from_name("huggingface")], timeout=86400, # 24h — first try timed out at 6h with single-worker upload ) def upload_l26_vfdb_to_hf(repo_name: str = "mgnify-evo2-l26-vfdb-virulence", private: bool = False, num_workers: int = 8) -> dict: """Push the VFDB layer-26 slice to HF Datasets. Resumes if partial.""" import os import time from huggingface_hub import HfApi, login # The huggingface Modal Secret has key/value swapped (env name is the literal token, # value is the string "HF_TOKEN") — work around it. 
token = None for k, v in os.environ.items(): if k.startswith("hf_") or k.startswith("HF_"): if k.startswith("hf_") and len(k) > 30: token = k break if v.startswith("hf_") and len(v) > 30: token = v break if not token: token = os.environ.get("HF_TOKEN") login(token=token) api = HfApi() user = api.whoami()["name"] full_repo = f"{user}/{repo_name}" api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True) print(f"[hf-push-vfdb] uploading /vol → {full_repo} (private={private}, workers={num_workers})") t0 = time.time() api.upload_large_folder( folder_path="/vol", repo_id=full_repo, repo_type="dataset", num_workers=num_workers, ) elapsed = time.time() - t0 n_files = 0 bytes_total = 0 for root, _, files in os.walk("/vol"): for fname in files: if fname.endswith(".npz"): n_files += 1 bytes_total += os.path.getsize(os.path.join(root, fname)) return { "repo_url": f"https://huggingface.co/datasets/{full_repo}", "n_files": n_files, "bytes_total": bytes_total, "elapsed_s": elapsed, "private": private, } @app.function( image=modal.Image.debian_slim(), cpu=1, volumes={ "/embeddings_lean": embeddings_lean_vol, "/l26_vfdb": embeddings_l26_vfdb_vol, }, timeout=3600, ) def wipe_vfdb_outputs() -> dict: """Remove stale /embeddings_lean/vfdb/ and /l26_vfdb/* before a clean re-run. 
Done after switching positive region_id from source_accession to vfg_id — otherwise old files contaminate slice + HF push.""" import os import shutil import time summary = {} for path, label, vol in [ ("/embeddings_lean/vfdb", "lean_vfdb", embeddings_lean_vol), ("/l26_vfdb", "l26_vfdb", embeddings_l26_vfdb_vol), ]: t0 = time.time() if os.path.exists(path): n_before = sum(1 for _, _, fs in os.walk(path) for _ in fs) # Don't actually remove the top-level mountpoint; remove its contents for entry in os.listdir(path): p = os.path.join(path, entry) if os.path.isdir(p): shutil.rmtree(p) else: os.remove(p) summary[label] = {"existed": True, "files_removed": n_before, "elapsed_s": time.time() - t0} else: summary[label] = {"existed": False} vol.commit() return summary @app.local_entrypoint() def run_vfdb_full(repo_name: str = "mgnify-evo2-l26-vfdb-virulence", private: bool = False, embed_batch_size: int = 2, slice_batch_size: int = 100, wipe_first: bool = True): """One-shot: optionally wipe stale outputs, embed, slice l26, push to HF. Use --no-wipe-first to skip the wipe (e.g., for re-runs after an interrupted push). modal run --detach modal/evo2_inference.py::run_vfdb_full """ import os, time if wipe_first: print("[0/4] wiping stale VFDB outputs (vfdb/ and l26_vfdb/) ...") w = wipe_vfdb_outputs.remote() for label, info in w.items(): if info.get("existed"): print(f" {label}: removed {info['files_removed']} files in {info['elapsed_s']:.0f} s") else: print(f" {label}: nothing to remove") # 1. 
Upload latest VFDB JSONLs base = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready" if not os.path.isdir(base): raise FileNotFoundError(f"VFDB JSONLs not found at {base}") jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl")) print(f"\n[1/4] uploading {len(jsonls)} JSONLs to mgnify-targeted-jsonl ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for fname in jsonls: batch.put_file(os.path.join(base, fname), f"vfdb/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") # 2. Embed (lean, 5 layers) print(f"\n[2/4] running embed_vfdb_lean (batch_size={embed_batch_size}) ...") r = orchestrate_vfdb_lean.remote(batch_size=embed_batch_size) print(f" regions saved: {r['regions_done']} (skipped {r['regions_skipped']}, errors {r['errors']})") print(f" total volume size: {r['total_mb']/1024:.1f} GB") # 3. Slice layer 26 print(f"\n[3/4] slicing layer 26 (batch_size={slice_batch_size}) ...") s = orchestrate_l26_vfdb_slice.remote(batch_size=slice_batch_size) print(f" files: {s['files_total']}, done: {s['n_done']}, skipped: {s['n_skipped']}, errors: {s['n_errors']}") print(f" l26 vol size: {s['total_mb_out']/1024:.2f} GB") # 4. Push to HF print(f"\n[4/4] pushing to HF as {repo_name} (private={private}) ...") u = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private) print(f"\n=== ALL DONE ===") print(f" HF repo: {u['repo_url']}") print(f" files: {u['n_files']}") print(f" size: {u['bytes_total']/1e9:.2f} GB") print(f" upload: {u['elapsed_s']:.0f} s ({u['bytes_total']/1e6/max(u['elapsed_s'],1):.1f} MB/s)") # ============================================================= # Qualitative-sample pipeline (~20 records per "true" secondary label) # ============================================================= # Used by Thread C in THREADS.md. ~860 records across ~61 categories, # all positives (no matched negatives). Outputs at /embeddings_lean/qual/. 
@app.function( image=image, gpu="H100", volumes={ "/root/.cache/huggingface": weights_vol, "/embeddings_lean": embeddings_lean_vol, "/jsonl": jsonl_vol, }, secrets=[modal.Secret.from_name("huggingface")], timeout=3600, max_containers=16, ) def embed_qual_lean(jsonl_rel_paths) -> dict: """Mirror of embed_vfdb_lean for the qualitative sample. Outputs at /embeddings_lean/qual///.npz.""" import json, os, time import numpy as np import torch if isinstance(jsonl_rel_paths, str): jsonl_rel_paths = [jsonl_rel_paths] t_load_start = time.time() evo2, device, module_dict = _get_evo2_only() t_load = time.time() - t_load_start layer_names = [f"blocks-{i}" for i in LEAN_LAYERS] cache: dict = {} def make_hook(name): def hook(module, inp, out): cache[name] = (out[0] if isinstance(out, tuple) else out).detach() return hook handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names] n_done = n_skipped = n_missing_jsonl = 0 total_mb = 0.0 per_region_times: list[float] = [] try: for jsonl_rel in jsonl_rel_paths: src_path = f"/jsonl/{jsonl_rel}" if not os.path.exists(src_path): n_missing_jsonl += 1 continue with open(src_path) as f: records = [json.loads(line) for line in f if line.strip()] for rec in records: group = rec.get("label_group") or "UNKNOWN" slug = rec.get("mag_id") or "UNKNOWN" # mag_id field repurposed for slug region_id = rec["region_id"] out_dir = f"/embeddings_lean/qual/{group}/{slug}" os.makedirs(out_dir, exist_ok=True) out_path = f"{out_dir}/{region_id}.npz" if os.path.exists(out_path): n_skipped += 1 continue t_region = time.time() seq = rec["sequence"] cache.clear() input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device) with torch.no_grad(): evo2.model(input_ids) seq_len = cache[layer_names[0]].shape[1] hidden = evo2.model.config.hidden_size stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16) for i, ln in enumerate(layer_names): stack[i] = 
cache[ln][0].to(torch.bfloat16).cpu() del cache[ln] torch.cuda.empty_cache() stack_uint16 = stack.view(torch.uint16).numpy() meta = {k: v for k, v in rec.items() if k != "sequence"} np.savez( out_path, per_token_layer_activations_bf16=stack_uint16, per_token_layer_activations_dtype="bfloat16", layer_names=np.array(layer_names), layer_indices=np.array(LEAN_LAYERS, dtype=np.int32), seq_len=np.int32(seq_len), hidden_size=np.int32(hidden), model_name="evo2_7b_262k", metadata_json=np.array(json.dumps(meta)), ) total_mb += os.path.getsize(out_path) / 1e6 n_done += 1 per_region_times.append(time.time() - t_region) del stack, stack_uint16, input_ids torch.cuda.empty_cache() finally: for h in handles: h.remove() cache.clear() torch.cuda.empty_cache() t_commit_start = time.time() embeddings_lean_vol.commit() t_commit = time.time() - t_commit_start return { "n_jsonls": len(jsonl_rel_paths), "n_missing_jsonl": n_missing_jsonl, "n_done": n_done, "n_skipped": n_skipped, "total_mb": total_mb, "model_load_s": t_load, "commit_s": t_commit, "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None, } @app.function( image=modal.Image.debian_slim(), cpu=1, volumes={"/embeddings_lean": embeddings_lean_vol}, timeout=1800, ) def rename_qual_to_small() -> dict: """Rename /embeddings_lean/qual → /embeddings_lean/small in-place.""" import os, shutil, time src = "/embeddings_lean/qual" dst = "/embeddings_lean/small" if not os.path.exists(src): return {"renamed": False, "reason": f"{src} does not exist"} if os.path.exists(dst): return {"renamed": False, "reason": f"{dst} already exists"} n_files = sum(1 for _, _, fs in os.walk(src) for _ in fs) t0 = time.time() shutil.move(src, dst) embeddings_lean_vol.commit() return {"renamed": True, "src": src, "dst": dst, "files_moved": n_files, "elapsed_s": time.time() - t0} @app.local_entrypoint() def rename_qual_small(): """One-shot: rename qual/ → small/ on the lean volume.""" r = rename_qual_to_small.remote() 
print(r) # ============================================================= # SynGenome AMR validation pipeline (mirror of embed_vfdb_lean) # ============================================================= # Inputs: /jsonl/syngenome/.jsonl (built by scripts/sample_syngenome_amr.py) # Outputs: /embeddings_lean/syngenome/AMR//.npz @app.function( image=image, gpu="H100", volumes={ "/root/.cache/huggingface": weights_vol, "/embeddings_lean": embeddings_lean_vol, "/jsonl": jsonl_vol, }, secrets=[modal.Secret.from_name("huggingface")], timeout=7200, max_containers=16, ) def embed_syngenome_lean(jsonl_rel_paths) -> dict: """SynGenome version of embed_vfdb_lean. All records are AMR positives.""" import json, os, time import numpy as np import torch if isinstance(jsonl_rel_paths, str): jsonl_rel_paths = [jsonl_rel_paths] t_load_start = time.time() evo2, device, module_dict = _get_evo2_only() t_load = time.time() - t_load_start layer_names = [f"blocks-{i}" for i in LEAN_LAYERS] cache: dict = {} def make_hook(name): def hook(module, inp, out): cache[name] = (out[0] if isinstance(out, tuple) else out).detach() return hook handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names] n_done = n_skipped = n_missing_jsonl = 0 total_mb = 0.0 per_region_times: list[float] = [] try: for jsonl_rel in jsonl_rel_paths: src_path = f"/jsonl/{jsonl_rel}" if not os.path.exists(src_path): n_missing_jsonl += 1 continue with open(src_path) as f: records = [json.loads(line) for line in f if line.strip()] for rec in records: drug_class = rec.get("mag_id") or "UNKNOWN" region_id = rec["region_id"] out_dir = f"/embeddings_lean/syngenome/AMR/{drug_class}" os.makedirs(out_dir, exist_ok=True) out_path = f"{out_dir}/{region_id}.npz" if os.path.exists(out_path): n_skipped += 1 continue t_region = time.time() seq = rec["sequence"] cache.clear() input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device) with torch.no_grad(): evo2.model(input_ids) 
seq_len = cache[layer_names[0]].shape[1] hidden = evo2.model.config.hidden_size stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16) for i, ln in enumerate(layer_names): stack[i] = cache[ln][0].to(torch.bfloat16).cpu() del cache[ln] torch.cuda.empty_cache() stack_uint16 = stack.view(torch.uint16).numpy() meta = {k: v for k, v in rec.items() if k != "sequence"} np.savez( out_path, per_token_layer_activations_bf16=stack_uint16, per_token_layer_activations_dtype="bfloat16", layer_names=np.array(layer_names), layer_indices=np.array(LEAN_LAYERS, dtype=np.int32), seq_len=np.int32(seq_len), hidden_size=np.int32(hidden), model_name="evo2_7b_262k", metadata_json=np.array(json.dumps(meta)), ) total_mb += os.path.getsize(out_path) / 1e6 n_done += 1 per_region_times.append(time.time() - t_region) del stack, stack_uint16, input_ids torch.cuda.empty_cache() finally: for h in handles: h.remove() cache.clear() torch.cuda.empty_cache() t_commit_start = time.time() embeddings_lean_vol.commit() t_commit = time.time() - t_commit_start return { "n_jsonls": len(jsonl_rel_paths), "n_missing_jsonl": n_missing_jsonl, "n_done": n_done, "n_skipped": n_skipped, "total_mb": total_mb, "model_load_s": t_load, "commit_s": t_commit, "per_region_s": per_region_times, "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None, } @app.function( image=modal.Image.debian_slim().pip_install("modal"), cpu=1, volumes={"/jsonl": jsonl_vol}, timeout=86400, ) def orchestrate_syngenome_lean(batch_size: int = 2) -> dict: """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_lean.""" import os paths = [] for root, _, files in os.walk("/jsonl/syngenome"): for fname in files: if fname.endswith(".jsonl"): rel = os.path.relpath(os.path.join(root, fname), "/jsonl") paths.append(rel) paths.sort() batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] print(f"[orchestrator-syngenome] {len(paths)} JSONLs → {len(batches)} 
batches of up to {batch_size}") n_done = n_skipped = errors = 0 total_mb = 0.0 region_times: list[float] = [] commit_times: list[float] = [] for i, r in enumerate(embed_syngenome_lean.map(batches, return_exceptions=True)): if isinstance(r, Exception): errors += 1 print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}") continue n_done += r.get("n_done", 0) n_skipped += r.get("n_skipped", 0) total_mb += r.get("total_mb", 0.0) if r.get("per_region_s"): region_times.extend(r["per_region_s"]) if r.get("commit_s") is not None: commit_times.append(r["commit_s"]) mean_t = sum(region_times) / max(len(region_times), 1) print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} " f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f}") return {"jsonls": len(paths), "batches": len(batches), "regions_done": n_done, "regions_skipped": n_skipped, "errors": errors, "total_mb": total_mb, "mean_region_s": sum(region_times) / max(len(region_times), 1)} # Single-layer (layer 26 only) variant — for SynGenome validation use # (no SAE work planned; only probes consume this; layer 26 is the informative one) embeddings_l26_syngenome_vol = modal.Volume.from_name( "mgnify-embeddings-l26-syngenome", create_if_missing=True, ) @app.function( image=image, gpu="H100", volumes={ "/root/.cache/huggingface": weights_vol, "/embeddings_l26_syn": embeddings_l26_syngenome_vol, "/jsonl": jsonl_vol, }, secrets=[modal.Secret.from_name("huggingface")], timeout=7200, max_containers=16, ) def embed_syngenome_l26(jsonl_rel_paths) -> dict: """Layer-26-only embed for SynGenome AMRs. Same forward pass as the lean variant but only blocks-26 is hooked + saved. 
~5× less storage.""" import json, os, time import numpy as np import torch if isinstance(jsonl_rel_paths, str): jsonl_rel_paths = [jsonl_rel_paths] t_load_start = time.time() evo2, device, module_dict = _get_evo2_only() t_load = time.time() - t_load_start layer_name = "blocks-26" cache: dict = {} def hook(module, inp, out): cache[layer_name] = (out[0] if isinstance(out, tuple) else out).detach() handle = module_dict[layer_name].register_forward_hook(hook) n_done = n_skipped = n_missing_jsonl = 0 total_mb = 0.0 per_region_times: list[float] = [] try: for jsonl_rel in jsonl_rel_paths: src_path = f"/jsonl/{jsonl_rel}" if not os.path.exists(src_path): n_missing_jsonl += 1 continue with open(src_path) as f: records = [json.loads(line) for line in f if line.strip()] for rec in records: # Top-level folder by label: positives → /AMR/, negatives → /negative/. # Drug class / functional-class slug is the second-level grouping (mag_id field). top = "AMR" if rec.get("is_positive", True) else "negative" group = rec.get("mag_id") or "UNKNOWN" region_id = rec["region_id"] out_dir = f"/embeddings_l26_syn/{top}/{group}" os.makedirs(out_dir, exist_ok=True) out_path = f"{out_dir}/{region_id}.npz" if os.path.exists(out_path): n_skipped += 1 continue t_region = time.time() seq = rec["sequence"] cache.clear() input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device) with torch.no_grad(): evo2.model(input_ids) seq_len = cache[layer_name].shape[1] hidden = evo2.model.config.hidden_size l26 = cache[layer_name][0].to(torch.bfloat16).cpu() del cache[layer_name] torch.cuda.empty_cache() l26_uint16 = l26.view(torch.uint16).numpy() meta = {k: v for k, v in rec.items() if k != "sequence"} np.savez( out_path, layer26_activations_bf16=l26_uint16, layer26_dtype="bfloat16", source_layer_index=np.int32(26), source_layer_name="blocks-26", seq_len=np.int32(seq_len), hidden_size=np.int32(hidden), model_name="evo2_7b_262k", metadata_json=np.array(json.dumps(meta)), ) 
total_mb += os.path.getsize(out_path) / 1e6 n_done += 1 per_region_times.append(time.time() - t_region) del l26, l26_uint16, input_ids torch.cuda.empty_cache() finally: handle.remove() cache.clear() torch.cuda.empty_cache() t_commit_start = time.time() embeddings_l26_syngenome_vol.commit() t_commit = time.time() - t_commit_start return { "n_jsonls": len(jsonl_rel_paths), "n_missing_jsonl": n_missing_jsonl, "n_done": n_done, "n_skipped": n_skipped, "total_mb": total_mb, "model_load_s": t_load, "commit_s": t_commit, "per_region_s": per_region_times, "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None, } @app.function( image=modal.Image.debian_slim().pip_install("modal"), cpu=1, volumes={"/jsonl": jsonl_vol}, timeout=86400, ) def orchestrate_syngenome_l26(batch_size: int = 2) -> dict: """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_l26.""" import os paths = [] for root, _, files in os.walk("/jsonl/syngenome"): for fname in files: if fname.endswith(".jsonl"): rel = os.path.relpath(os.path.join(root, fname), "/jsonl") paths.append(rel) paths.sort() batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] print(f"[orchestrator-syngenome-l26] {len(paths)} JSONLs → {len(batches)} batches of up to {batch_size}") n_done = n_skipped = errors = 0 total_mb = 0.0 region_times: list[float] = [] commit_times: list[float] = [] for i, r in enumerate(embed_syngenome_l26.map(batches, return_exceptions=True)): if isinstance(r, Exception): errors += 1 print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}") continue n_done += r.get("n_done", 0) n_skipped += r.get("n_skipped", 0) total_mb += r.get("total_mb", 0.0) if r.get("per_region_s"): region_times.extend(r["per_region_s"]) if r.get("commit_s") is not None: commit_times.append(r["commit_s"]) mean_t = sum(region_times) / max(len(region_times), 1) print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} " f"{total_mb/1024:.2f} GB 
mean_region_s={mean_t:.2f}") return {"jsonls": len(paths), "batches": len(batches), "regions_done": n_done, "regions_skipped": n_skipped, "errors": errors, "total_mb": total_mb, "mean_region_s": sum(region_times) / max(len(region_times), 1)} @app.local_entrypoint() def run_syngenome_l26(batch_size: int = 2): """Upload SynGenome AMR JSONLs + run layer-26-only embed.""" import os, time base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome" if not os.path.isdir(base): raise FileNotFoundError(f"SynGenome JSONLs not found at {base}") jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl")) print(f"[run-syngenome-l26] uploading {len(jsonls)} JSONLs ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for fname in jsonls: batch.put_file(os.path.join(base, fname), f"syngenome/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") print(f"\n[run-syngenome-l26] orchestrator (batch_size={batch_size})") r = orchestrate_syngenome_l26.remote(batch_size=batch_size) print(f"\n=== DONE ===") print(f" regions saved: {r['regions_done']} (skipped {r['regions_skipped']}, errors {r['errors']})") print(f" total size: {r['total_mb']/1024:.2f} GB") print(f" mean per-region: {r['mean_region_s']:.2f} s") @app.function( image=modal.Image.debian_slim().pip_install("modal"), cpu=1, volumes={"/jsonl": jsonl_vol}, timeout=86400, ) def orchestrate_syngenome_l26_neg(batch_size: int = 2) -> dict: """Walks /jsonl/syngenome_neg/, batches, fans out to embed_syngenome_l26.""" import os paths = [] for root, _, files in os.walk("/jsonl/syngenome_neg"): for fname in files: if fname.endswith(".jsonl"): rel = os.path.relpath(os.path.join(root, fname), "/jsonl") paths.append(rel) paths.sort() batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] print(f"[orchestrator-syngenome-l26-neg] {len(paths)} JSONLs → {len(batches)} batches") n_done = n_skipped = errors = 0 total_mb = 0.0 region_times: list[float] = [] for i, r in 
enumerate(embed_syngenome_l26.map(batches, return_exceptions=True)): if isinstance(r, Exception): errors += 1 print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}") continue n_done += r.get("n_done", 0) n_skipped += r.get("n_skipped", 0) total_mb += r.get("total_mb", 0.0) if r.get("per_region_s"): region_times.extend(r["per_region_s"]) mean_t = sum(region_times) / max(len(region_times), 1) print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} " f"{total_mb/1024:.2f} GB mean_region_s={mean_t:.2f}") return {"jsonls": len(paths), "batches": len(batches), "regions_done": n_done, "regions_skipped": n_skipped, "errors": errors, "total_mb": total_mb, "mean_region_s": sum(region_times) / max(len(region_times), 1)} @app.local_entrypoint() def run_syngenome_l26_neg(batch_size: int = 2): """Upload SynGenome NEGATIVE JSONLs + run layer-26-only embed. Run AFTER run_syngenome_l26 finishes to avoid GPU contention.""" import os, time base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg" if not os.path.isdir(base): raise FileNotFoundError(f"SynGenome negative JSONLs not found at {base}") jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl")) print(f"[run-syngenome-l26-neg] uploading {len(jsonls)} JSONLs ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for fname in jsonls: batch.put_file(os.path.join(base, fname), f"syngenome_neg/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") print(f"\n[run-syngenome-l26-neg] orchestrator (batch_size={batch_size})") r = orchestrate_syngenome_l26_neg.remote(batch_size=batch_size) print(f"\n=== DONE ===") print(f" regions saved: {r['regions_done']} (skipped {r['regions_skipped']}, errors {r['errors']})") print(f" total size: {r['total_mb']/1024:.2f} GB") print(f" mean per-region: {r['mean_region_s']:.2f} s") @app.local_entrypoint() def pilot_syngenome_lean(target_records: int = 250, batch_size: int = 2): """Run a small SynGenome pilot to measure per-region time + commit 
time, then project full-run cost. Picks drug-class JSONLs greedily up to ~target_records, biased toward small classes for cost containment. modal run modal/evo2_inference.py::pilot_syngenome_lean modal run modal/evo2_inference.py::pilot_syngenome_lean --target-records 400 --batch-size 4 """ import os, time base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome" if not os.path.isdir(base): raise FileNotFoundError(f"SynGenome JSONLs not found at {base}; run scripts/sample_syngenome_amr.py first.") species_files = [] for fname in sorted(os.listdir(base)): if not fname.endswith(".jsonl"): continue with open(os.path.join(base, fname)) as f: n = sum(1 for line in f if line.strip()) species_files.append((n, fname)) species_files.sort() # smallest first chosen = [] total = 0 for n, fname in species_files: if total >= target_records: break if n > target_records * 2 and chosen: continue chosen.append((n, fname)) total += n if not chosen: chosen = [species_files[0]] print(f"[pilot-syngenome] selected {len(chosen)} drug-class files:") for n, fname in chosen: print(f" {fname:30s} {n} records") print(f" total pilot records: {total}") print(f"\n[pilot-syngenome] uploading to mgnify-targeted-jsonl ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for _, fname in chosen: batch.put_file(os.path.join(base, fname), f"syngenome/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") rel_paths = [f"syngenome/{fname}" for _, fname in chosen] batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)] print(f"\n[pilot-syngenome] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of {batch_size}") t0 = time.time() results = list(embed_syngenome_lean.map(batches, return_exceptions=True)) wall = time.time() - t0 ok = [r for r in results if not isinstance(r, Exception)] n_done = sum(r["n_done"] for r in ok) total_mb = sum(r["total_mb"] for r in ok) region_times = [t for r in ok for t in r.get("per_region_s") or []] commit_times = 
[r["commit_s"] for r in ok if r.get("commit_s") is not None] load_times = [r["model_load_s"] for r in ok] if not region_times: print("ERROR: no records processed") return mean_region = sum(region_times) / len(region_times) mean_commit = sum(commit_times) / max(len(commit_times), 1) mean_load = sum(load_times) / max(len(load_times), 1) # Avg seq_len from local JSONLs (used for projection) import json seq_lens_local = [] for _, fname in chosen: with open(os.path.join(base, fname)) as f: for line in f: if line.strip(): seq_lens_local.append(json.loads(line).get("cds_length", 0)) pilot_avg_seqlen = sum(seq_lens_local) / max(len(seq_lens_local), 1) print(f"\n=== SYNGENOME PILOT RESULTS ===") print(f" records processed: {n_done}") print(f" output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)") print(f" wall clock: {wall:.0f} s across {len(batches)} batch(es)") print(f" per-region inference: {mean_region:.2f} s avg " f"(min {min(region_times):.2f}, max {max(region_times):.2f}, p95 {sorted(region_times)[int(len(region_times)*0.95)]:.2f})") print(f" per-region seq len: {pilot_avg_seqlen:.0f} bp avg") print(f" per-batch commit: {mean_commit:.2f} s avg") print(f" per-call model load: {mean_load:.1f} s avg") # Projection to full 8000-record run full_records = 8000 n_workers = 16 h100_rate = 4.50 full_n_jsonls = 13 full_batches = (full_n_jsonls + batch_size - 1) // batch_size # Length-adjustment: full set has avg ~5000 bp (mostly macrolide at 5000), # pilot biased to smaller drug classes which may have shorter sequences full_avg_seqlen = 5000 # known cap length_factor = full_avg_seqlen / max(pilot_avg_seqlen, 1) inference_compute_s = full_records * mean_region * length_factor commit_compute_s = full_batches * mean_commit cold_start_s = mean_load * min(n_workers, full_batches) total_compute_s = inference_compute_s + commit_compute_s + cold_start_s wall_proj = total_compute_s / min(n_workers, full_batches) cost = (total_compute_s / 3600) * h100_rate 
output_size_gb = (total_mb / max(n_done, 1)) * full_records * length_factor / 1024 print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, " f"H100=${h100_rate:.2f}/hr, length_factor={length_factor:.2f}):") print(f" inference compute: {inference_compute_s:7.0f} s") print(f" commit compute: {commit_compute_s:7.0f} s") print(f" cold-start total: {cold_start_s:7.0f} s") print(f" estimated wall clock: {wall_proj/60:5.1f} min") print(f" estimated cost: ${cost:.2f}") print(f" estimated total size: {output_size_gb:.1f} GB") @app.local_entrypoint() def run_syngenome_lean(batch_size: int = 2): """Upload SynGenome JSONLs + run lean embed.""" import os, time base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome" if not os.path.isdir(base): raise FileNotFoundError(f"SynGenome JSONLs not found at {base}; run scripts/sample_syngenome_amr.py first.") jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl")) print(f"[run-syngenome] uploading {len(jsonls)} JSONLs ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for fname in jsonls: batch.put_file(os.path.join(base, fname), f"syngenome/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") print(f"\n[run-syngenome] orchestrator (batch_size={batch_size})") r = orchestrate_syngenome_lean.remote(batch_size=batch_size) print(f"\n=== DONE ===") print(f" regions saved: {r['regions_done']} (skipped {r['regions_skipped']}, errors {r['errors']})") print(f" total size: {r['total_mb']/1024:.1f} GB") print(f" mean per-region: {r['mean_region_s']:.2f} s") # Small (qual sample) layer-26 slice + HF push — separate volume from VFDB embeddings_l26_small_vol = modal.Volume.from_name( "mgnify-embeddings-l26-small", create_if_missing=True, ) @app.function( image=modal.Image.debian_slim().pip_install("numpy"), cpu=2, volumes={ "/in": embeddings_lean_vol, "/out": embeddings_l26_small_vol, }, timeout=3600, max_containers=8, ) def slice_l26_small_batch(rel_paths: list[str]) -> 
dict: import os import numpy as np n_done = n_skipped = n_errors = 0 total_mb = 0.0 for rel in rel_paths: in_path = f"/in/{rel}" out_path = f"/out/{rel}" if os.path.exists(out_path): n_skipped += 1 continue if not os.path.exists(in_path): n_errors += 1 continue try: with np.load(in_path, allow_pickle=False) as d: stack = d["per_token_layer_activations_bf16"] layer_indices = list(int(x) for x in d["layer_indices"]) pos = layer_indices.index(26) l26 = stack[pos].copy() passthrough = { "seq_len": d["seq_len"], "hidden_size": d["hidden_size"], "model_name": d["model_name"], "metadata_json": d["metadata_json"], } os.makedirs(os.path.dirname(out_path), exist_ok=True) np.savez( out_path, layer26_activations_bf16=l26, layer26_dtype="bfloat16", source_layer_index=np.int32(26), source_layer_name="blocks-26", **passthrough, ) total_mb += os.path.getsize(out_path) / 1e6 n_done += 1 except Exception as e: print(f" ERROR on {rel}: {e}") n_errors += 1 embeddings_l26_small_vol.commit() return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb} @app.function( image=modal.Image.debian_slim().pip_install("modal"), cpu=1, volumes={"/in": embeddings_lean_vol}, timeout=86400, ) def orchestrate_l26_small_slice(batch_size: int = 100) -> dict: import os paths = [] for root, _, files in os.walk("/in/small"): for fname in files: if fname.endswith(".npz"): rel = os.path.relpath(os.path.join(root, fname), "/in") paths.append(rel) paths.sort() batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] print(f"[orchestrator-l26-small] {len(paths)} npz → {len(batches)} batches") n_done = n_skipped = n_errors = 0 total_mb_out = 0.0 for i, r in enumerate(slice_l26_small_batch.map(batches, return_exceptions=True)): if isinstance(r, Exception): print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}") continue n_done += r.get("n_done", 0) n_skipped += r.get("n_skipped", 0) n_errors += r.get("n_errors", 0) total_mb_out += r.get("total_mb", 0.0) print(f" 
done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB") return {"files_total": len(paths), "n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb_out": total_mb_out} @app.function( image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"), cpu=4, volumes={"/vol": embeddings_l26_small_vol}, secrets=[modal.Secret.from_name("huggingface")], timeout=21600, ) def upload_l26_small_to_hf(repo_name: str = "mgnify-evo2-l26-small-qual", private: bool = False) -> dict: import os, time from huggingface_hub import HfApi, login token = None for k, v in os.environ.items(): if k.startswith("hf_") and len(k) > 30: token = k; break if v.startswith("hf_") and len(v) > 30: token = v; break if not token: token = os.environ.get("HF_TOKEN") login(token=token) api = HfApi() user = api.whoami()["name"] full_repo = f"{user}/{repo_name}" api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True) print(f"[hf-push-small] uploading /vol → {full_repo} (private={private})") t0 = time.time() api.upload_large_folder(folder_path="/vol", repo_id=full_repo, repo_type="dataset") elapsed = time.time() - t0 n_files = bytes_total = 0 for root, _, files in os.walk("/vol"): for fname in files: if fname.endswith(".npz"): n_files += 1 bytes_total += os.path.getsize(os.path.join(root, fname)) return {"repo_url": f"https://huggingface.co/datasets/{full_repo}", "n_files": n_files, "bytes_total": bytes_total, "elapsed_s": elapsed, "private": private} @app.local_entrypoint() def push_l26_small(repo_name: str = "mgnify-evo2-l26-small-qual", private: bool = False, batch_size: int = 100): """Slice /embeddings_lean/small/ to layer-26 and push to HF.""" print("[1/2] slicing layer 26 from small/ ...") s = orchestrate_l26_small_slice.remote(batch_size=batch_size) print(f" files: {s['files_total']}, done: {s['n_done']}, skipped: {s['n_skipped']}, errors: {s['n_errors']}") print(f" l26 size: {s['total_mb_out']/1024:.2f} GB") print("\n[2/2] pushing 
to HF ...") u = upload_l26_small_to_hf.remote(repo_name=repo_name, private=private) print(f"\n=== DONE ===") print(f" repo: {u['repo_url']}") print(f" files: {u['n_files']}") print(f" size: {u['bytes_total']/1e9:.2f} GB") print(f" elapsed: {u['elapsed_s']:.0f} s") @app.local_entrypoint() def run_qual_lean(batch_size: int = 8): """Upload qual JSONLs + embed. Tiny job (~860 records, ~$0.30, ~2 min).""" import os, time base = "/home/ror25cal/MGnify/data/targeted_jsonl/qual" if not os.path.isdir(base): raise FileNotFoundError(f"Qual JSONLs not found at {base}; run scripts/sample_qual_jsonl.py first.") jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl")) print(f"[run-qual] uploading {len(jsonls)} JSONLs to mgnify-targeted-jsonl ...") t0 = time.time() with jsonl_vol.batch_upload(force=True) as batch: for fname in jsonls: batch.put_file(os.path.join(base, fname), f"qual/{fname}") print(f" uploaded in {time.time()-t0:.0f} s") rel_paths = [f"qual/{f}" for f in jsonls] batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)] print(f"\n[run-qual] {len(rel_paths)} JSONLs in {len(batches)} batches of up to {batch_size}") t0 = time.time() results = list(embed_qual_lean.map(batches, return_exceptions=True)) wall = time.time() - t0 ok = [r for r in results if not isinstance(r, Exception)] n_done = sum(r["n_done"] for r in ok) n_skipped = sum(r["n_skipped"] for r in ok) total_mb = sum(r["total_mb"] for r in ok) print(f"\n=== QUAL DONE ===") print(f" records embedded: {n_done} (skipped {n_skipped})") print(f" output size: {total_mb:.0f} MB") print(f" wall clock: {wall:.0f} s") @app.local_entrypoint() def push_l26_vfdb(repo_name: str = "mgnify-evo2-l26-vfdb-virulence", private: bool = False, batch_size: int = 100): """Slice layer-26 from /embeddings_lean/vfdb/ then push to HF Datasets.""" print("[1/2] slicing layer 26 from VFDB lean embeddings ...") s = orchestrate_l26_vfdb_slice.remote(batch_size=batch_size) print(f"\n files total: 
{s['files_total']}") print(f" done: {s['n_done']}") print(f" skipped: {s['n_skipped']}") print(f" errors: {s['n_errors']}") print(f" l26 vol size: {s['total_mb_out']/1024:.2f} GB") print("\n[2/2] pushing VFDB layer-26 volume to HF Datasets ...") u = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private) print(f"\n=== UPLOADED ===") print(f" repo: {u['repo_url']}") print(f" files: {u['n_files']}") print(f" size: {u['bytes_total']/1e9:.2f} GB") print(f" elapsed: {u['elapsed_s']:.0f} s ({u['bytes_total']/1e6/max(u['elapsed_s'],1):.1f} MB/s)")