| """ |
| Modal scaffold for self-hosted Evo 2-7B inference. |
| |
| Uses Arc Institute's official Dockerfile (from the ArcInstitute/evo2 repo), |
| cached HF weights in a Modal Volume, and the existing "huggingface" Modal |
| Secret (same one used by ~/AIMO3-TIR/compute/modal/). |
| |
| NO dependency on NVIDIA NIM / NGC / BioNeMo — weights come from HF, the |
| container is built from Arc's Dockerfile. |
| |
| ONE-TIME SETUP |
| 1. Authenticate Modal (already done): modal token new |
| 2. HF secret (already done via AIMO3): modal secret create huggingface HF_TOKEN=... |
| 3. Clone Arc's repo locally: git clone https://github.com/ArcInstitute/evo2.git ~/evo2_repo |
| 4. Upload MGnify data (optional, one-time): modal volume create mgnify-data |
| modal volume put mgnify-data /home/ror25cal/MGnify/data/ / |
| |
| USAGE |
| modal run modal/evo2_inference.py::main --seq "ACGT..." --layer blocks.26.mlp.l3 |
| """ |
| import os |
| from pathlib import Path |
|
|
| import modal |
|
|
| |
# Name of the Modal app and of the persistent Volumes looked up (or created) below.
APP_NAME = "mgnify-evo2-7b"
VOL_WEIGHTS = "evo2-7b-weights"
VOL_DATA = "mgnify-data"
# Default layer name used by `embed` / `main` when the caller does not pick one.
TARGET_LAYER = "blocks.26.mlp.l3"


# Container image shared by the GPU functions: an NVIDIA PyTorch registry image
# with git / pip / tomli added via apt and the `evo2` package pip-installed.
# add_python=None: Modal is not asked to add a Python build (presumably the
# interpreter shipped in the base image is used — confirm).
# NOTE(review): the module docstring says the container is built from Arc's
# Dockerfile; what is built here is the registry image below plus
# `pip install evo2` — confirm which description is intended.
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")
)


# Volume handles; both are created on first use if they do not exist yet.
# weights_vol is mounted at /root/.cache/huggingface by the GPU functions,
# data_vol at /data by `embed`.
weights_vol = modal.Volume.from_name(VOL_WEIGHTS, create_if_missing=True)
data_vol = modal.Volume.from_name(VOL_DATA, create_if_missing=True)


# The Modal app every function and local entrypoint in this file registers on.
app = modal.App(APP_NAME)
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/data": data_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def embed(sequences: list[tuple[str, str]], layers: list[str] | None = None) -> dict:
    """
    Embed a batch of named DNA sequences with Evo 2-7B.

    Each sequence is tokenized, run through the model on ``cuda:0`` one at a
    time, and the requested layers' embeddings are returned as float arrays
    with the batch dimension squeezed away. The Hugging Face cache directory
    is backed by the weights Volume, so downloaded weights are expected to
    persist between calls.

    sequences: list of (name, DNA_string) pairs
    layers:    layer names, e.g. ["blocks.26.mlp.l3"]; falls back to
               [TARGET_LAYER] when None or empty
    returns:   {name: {layer: np.ndarray}}
    """
    import numpy as np
    import torch
    from evo2 import Evo2

    if not layers:
        layers = [TARGET_LAYER]
    model = Evo2("evo2_7b")

    results = {}
    for seq_name, dna in sequences:
        token_ids = model.tokenizer.tokenize(dna)
        # Add a leading batch dimension of 1 before moving to the GPU.
        input_ids = torch.tensor(token_ids, dtype=torch.int).unsqueeze(0).to("cuda:0")
        _, embeddings = model(input_ids, return_embeddings=True, layer_names=layers)

        per_layer = {}
        for lyr in layers:
            # Upcast to float32 on the way out so numpy can represent the values.
            per_layer[lyr] = np.asarray(embeddings[lyr].squeeze(0).float().cpu())
        results[seq_name] = per_layer
    return results
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def embed_and_sae(sequence: str, topk: int = 64) -> dict:
    """
    Run Evo 2-7B-262k on one sequence and encode block 26 with Goodfire's SAE.

    Modelled on Arc Institute's reference notebook
    (notebooks/sparse_autoencoder/sparse_autoencoder.ipynb):
      - the model is evo2_7b_262k, the long-context variant
      - activations are hooked at the whole 'blocks-26' output rather than
        at blocks.26.mlp.l3
      - a batch top-k (64 per position on average, selected over the whole
        sequence) sparsifies the ReLU pre-activations

    sequence: DNA string to embed
    topk:     how many of the strongest latents to report per position
    returns:  JSON-friendly dict with shapes, per-position top-`topk`
              latent values/indices, the count of active latents per
              position and the L2 norm of each position's activation
    """
    import numpy as np
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    SAE_LAYER = "blocks-26"
    D_HIDDEN = 4096
    D_SAE = D_HIDDEN * 8
    K = 64

    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    # Index every submodule under a '-'-joined path (e.g. 'blocks-26').
    modules_by_path = {}
    pending = [("", evo2.model)]
    while pending:
        prefix, parent = pending.pop()
        for child_name, child in parent.named_children():
            path = prefix + child_name
            modules_by_path[path] = child
            pending.append((path + "-", child))
    target_module = modules_by_path[SAE_LAYER]

    captured = {}

    def capture_output(module, inputs, output):
        # Some blocks return a tuple; the activations are its first element.
        tensor = output[0] if isinstance(output, tuple) else output
        captured["acts"] = tensor.detach()

    handle = target_module.register_forward_hook(capture_output)
    try:
        token_ids = evo2.tokenizer.tokenize(sequence)
        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        # Always detach the hook, even if the forward pass fails.
        handle.remove()

    acts = captured["acts"][0]  # drop the batch dimension

    sae_file = hf_hub_download(
        repo_id="Goodfire/Evo-2-Layer-26-Mixed",
        filename="sae-layer26-mixed-expansion_8-k_64.pt",
    )
    raw_state = torch.load(sae_file, map_location=device, weights_only=True)
    # Strip wrapper prefixes from the checkpoint's parameter names.
    sae_sd = {}
    for key, value in raw_state.items():
        sae_sd[key.replace("_orig_mod.", "").replace("module.", "")] = value
    W = sae_sd["W"].to(device=device, dtype=acts.dtype)
    b_enc = sae_sd["b_enc"].to(device=device, dtype=acts.dtype)

    # Batch top-k: keep the K * n_positions largest pre-activations overall.
    pre = torch.relu(acts @ W + b_enc)
    flat = pre.flatten()
    n_keep = K * pre.shape[0]
    winners = torch.topk(flat, n_keep, dim=-1)
    sparse = torch.zeros_like(flat).scatter(-1, winners.indices, winners.values)
    latents = sparse.reshape(pre.shape)

    pos_values, pos_indices = latents.topk(topk, dim=1)
    n_active = (latents > 0).sum(dim=1)
    act_norms = torch.linalg.norm(acts.float(), dim=1)

    summary = {}
    summary["seq_len"] = int(acts.shape[0])
    summary["d_model"] = int(acts.shape[1])
    summary["d_sae"] = int(latents.shape[1])
    summary["model_name"] = "evo2_7b_262k"
    summary["sae_layer"] = SAE_LAYER
    summary["topk_k"] = K
    summary["topk_values"] = pos_values.float().cpu().numpy().astype(np.float32).tolist()
    summary["topk_indices"] = pos_indices.cpu().numpy().astype(np.int32).tolist()
    summary["active_features_per_position"] = n_active.cpu().numpy().astype(np.int32).tolist()
    summary["activation_l2_norm"] = act_norms.cpu().numpy().astype(np.float32).tolist()
    return summary
|
|
|
|
| |
# Additional Volume handles (created on first use):
#   embeddings_vol          — mounted at /embeddings by embed_full (per-chunk npz output)
#   embeddings_targeted_vol — mounted at /embeddings_targeted by embed_targeted_jsonl
#   jsonl_vol               — mounted at /jsonl; holds the JSONL inputs that are read
embeddings_vol = modal.Volume.from_name("mgnify-embeddings", create_if_missing=True)
embeddings_targeted_vol = modal.Volume.from_name("mgnify-embeddings-targeted", create_if_missing=True)
jsonl_vol = modal.Volume.from_name("mgnify-targeted-jsonl", create_if_missing=True)
|
|
|
|
def _get_models():
    """Return (evo2, W_sae, b_enc, device, module_dict), loading them at most once.

    The five objects are stashed in module globals (``_CACHED_*``) so later
    calls in the same process reuse them. Forward hooks are deliberately NOT
    part of the cache: callers register and remove their own hooks per call
    so no stale hook state survives container reuse.
    """
    cache_keys = (
        "_CACHED_EVO2",
        "_CACHED_SAE_W",
        "_CACHED_SAE_BENC",
        "_CACHED_DEVICE",
        "_CACHED_MODULE_DICT",
    )
    module_globals = globals()
    # Only trust the cache when the model is set and every entry is present;
    # anything less falls through to a fresh load.
    if module_globals.get("_CACHED_EVO2") is not None and all(
        key in module_globals for key in cache_keys
    ):
        return tuple(module_globals[key] for key in cache_keys)

    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    print("[container] loading Evo2 7B-262k (once per container)")
    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    # Index every submodule under a '-'-joined path (e.g. 'blocks-26').
    module_dict = {}
    pending = [("", evo2.model)]
    while pending:
        prefix, parent = pending.pop()
        for child_name, child in parent.named_children():
            path = prefix + child_name
            module_dict[path] = child
            pending.append((path + "-", child))

    sae_file = hf_hub_download(
        repo_id="Goodfire/Evo-2-Layer-26-Mixed",
        filename="sae-layer26-mixed-expansion_8-k_64.pt",
    )
    raw_state = torch.load(sae_file, map_location=device, weights_only=True)
    # Strip wrapper prefixes from the checkpoint's parameter names.
    sae_sd = {}
    for key, value in raw_state.items():
        sae_sd[key.replace("_orig_mod.", "").replace("module.", "")] = value
    W_sae = sae_sd["W"].to(device=device).to(torch.bfloat16)
    b_enc = sae_sd["b_enc"].to(device=device).to(torch.bfloat16)

    loaded = (evo2, W_sae, b_enc, device, module_dict)
    module_globals.update(zip(cache_keys, loaded))
    return loaded
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings": embeddings_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def embed_full(
    mag_id: str,
    contig_id: str,
    sequence: str,
    pool_size: int = 1000,
    chunk_size: int = 64000,
    overlap: int = 0,
) -> dict:
    """
    Process one contig of any length, chunking as needed.

    For every chunk this runs Evo 2-7B-262k once, then writes
    /embeddings/{mag_id}/{contig_id}_{N}.npz containing:
      - mean-pooled activations of all 32 blocks over windows of `pool_size`
        positions (bf16 bit patterns stored as uint16), and
      - the top-64 SAE latents per position for block 26.
    Chunks whose file already exists are skipped, so a call can be re-run to
    resume. Evo2/SAE weights are reused across calls within a container via
    `_get_models()`; forward hooks are registered per call and always removed.

    mag_id, contig_id: used for the output path and stored as metadata
    sequence:          full contig DNA string
    pool_size:         positions per mean-pooling window (must be > 0)
    chunk_size:        bases per chunk
    overlap:           bases shared by consecutive chunks (must be < chunk_size)
    returns:           summary dict with per-chunk paths, ranges and sizes;
                       carries "all_cached": True when nothing had to be computed
    raises:            ValueError for a non-positive pool_size or when
                       overlap >= chunk_size
    """
    import contextlib
    import os

    import numpy as np
    import torch

    K = 64
    layer_names = [f"blocks-{i}" for i in range(32)]

    # Validate before loading the model or registering hooks, so bad
    # arguments fail fast and can never leave hooks on the cached model.
    if pool_size <= 0:
        raise ValueError(f"pool_size ({pool_size}) must be > 0")
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")

    contig_len = len(sequence)
    chunks = [(s, min(s + chunk_size, contig_len)) for s in range(0, contig_len, step)]
    n_chunks = len(chunks)
    print(f"[{mag_id}/{contig_id}] contig_len={contig_len:,}, {n_chunks} chunks of ~{chunk_size:,} bp (overlap={overlap})")

    out_dir = f"/embeddings/{mag_id}"
    os.makedirs(out_dir, exist_ok=True)
    chunk_paths = [f"{out_dir}/{contig_id}_{ci}.npz" for ci in range(n_chunks)]

    # Fast path: everything is already on the volume. This runs before any
    # hook is registered; returning from here used to leak 32 hooks per call.
    if all(os.path.exists(p) for p in chunk_paths):
        saved = [
            {
                "path": path,
                "chunk_idx": ci,
                "bp_range": [cstart, cend],
                "size_mb": os.path.getsize(path) / 1e6,
                "skipped": True,
            }
            for ci, ((cstart, cend), path) in enumerate(zip(chunks, chunk_paths))
        ]
        print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks already on volume, skipping")
        return {
            "mag_id": mag_id,
            "contig_id": contig_id,
            "contig_len": contig_len,
            "n_chunks": n_chunks,
            "chunks": saved,
            "total_size_mb": sum(s["size_mb"] for s in saved),
            "all_cached": True,
        }

    evo2, W_sae, b_enc, device, module_dict = _get_models()
    hidden_size = evo2.model.config.hidden_size

    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            # Some blocks return a tuple; the activations are its first element.
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = []
    saved = []
    try:
        # Registered inside the try so the finally removes whatever was added.
        for ln in layer_names:
            handles.append(module_dict[ln].register_forward_hook(make_hook(ln)))

        for ci, (cstart, cend) in enumerate(chunks):
            out_path = chunk_paths[ci]
            if os.path.exists(out_path):
                saved.append({
                    "path": out_path,
                    "chunk_idx": ci,
                    "bp_range": [cstart, cend],
                    "size_mb": os.path.getsize(out_path) / 1e6,
                    "cached": True,
                })
                continue

            # Forward pass; the hooks fill `cache` with every block's output.
            cache.clear()
            input_ids = torch.tensor(
                evo2.tokenizer.tokenize(sequence[cstart:cend]), dtype=torch.long
            ).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(input_ids)
            seq_len = cache["blocks-0"].shape[1]
            print(f" [{ci+1}/{n_chunks}] bp {cstart:,}-{cend:,} seq_len={seq_len}")

            # SAE on block 26: ReLU encode, batch top-k over the whole chunk
            # (K * positions survivors), then the K strongest per position.
            acts26 = cache["blocks-26"][0]
            pre = torch.relu(acts26 @ W_sae + b_enc)
            flat = pre.flatten()
            numel = K * pre.shape[0]
            tk = torch.topk(flat, numel, dim=-1)
            sparse_flat = torch.zeros_like(flat).scatter(-1, tk.indices, tk.values)
            latents = sparse_flat.reshape(pre.shape)
            top_v, top_i = latents.topk(K, dim=1)
            top_i_cpu = top_i.cpu().numpy().astype(np.int32)
            top_v_cpu = top_v.float().cpu().numpy().astype(np.float16)

            # Free the large SAE intermediates before pooling the 32 layers.
            del acts26, pre, flat, sparse_flat, latents, top_v, top_i, tk
            torch.cuda.empty_cache()

            # Mean-pool each layer over full windows plus one shorter tail window.
            n_full = seq_len // pool_size
            has_tail = seq_len > n_full * pool_size
            n_windows = n_full + (1 if has_tail else 0)

            layer_means = torch.zeros(len(layer_names), n_windows, hidden_size, dtype=torch.bfloat16)
            for i, ln in enumerate(layer_names):
                acts = cache[ln][0]
                if n_full > 0:
                    full = acts[:n_full * pool_size].view(n_full, pool_size, -1).float().mean(dim=1)
                else:
                    full = acts.new_empty(0, acts.shape[-1])
                if has_tail:
                    tail = acts[n_full * pool_size:].float().mean(dim=0, keepdim=True)
                    pooled = torch.cat([full, tail], dim=0) if n_full > 0 else tail
                else:
                    pooled = full
                layer_means[i] = pooled.to(torch.bfloat16).cpu()

                # Release this layer's activations as soon as they are pooled.
                del acts, full, pooled
                del cache[ln]
                torch.cuda.empty_cache()

            # numpy has no bfloat16, so store the raw bit pattern as uint16.
            layer_means_uint16 = layer_means.view(torch.uint16).numpy()

            try:
                np.savez_compressed(
                    out_path,
                    layer_means_bf16=layer_means_uint16,
                    layer_means_dtype="bfloat16",
                    layer_names=np.array(layer_names),
                    pool_size=np.int32(pool_size),
                    sae_topk_indices=top_i_cpu,
                    sae_topk_values=top_v_cpu,
                    sae_layer="blocks-26",
                    seq_len=np.int32(seq_len),
                    chunk_start=np.int64(cstart),
                    chunk_end=np.int64(cend),
                    chunk_idx=np.int32(ci),
                    n_chunks=np.int32(n_chunks),
                    chunk_size=np.int32(chunk_size),
                    overlap=np.int32(overlap),
                    contig_id=contig_id,
                    contig_len=np.int64(contig_len),
                    mag_id=mag_id,
                    model_name="evo2_7b_262k",
                )
            except BaseException:
                # A half-written file would pass the exists() check on the
                # next run and be reported as cached; drop it and re-raise.
                with contextlib.suppress(OSError):
                    os.remove(out_path)
                raise
            file_size = os.path.getsize(out_path)
            print(f" saved {file_size/1e6:.1f} MB -> {out_path}")
            saved.append({"path": out_path, "chunk_idx": ci, "bp_range": [cstart, cend], "size_mb": file_size / 1e6})

            del layer_means, layer_means_uint16, input_ids, top_i_cpu, top_v_cpu
            torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    embeddings_vol.commit()
    print(f"[{mag_id}/{contig_id}] all {n_chunks} chunks saved, committed to volume")
    return {
        "mag_id": mag_id, "contig_id": contig_id,
        "contig_len": contig_len, "n_chunks": n_chunks,
        "chunks": saved,
        "total_size_mb": sum(s["size_mb"] for s in saved),
    }
|
|
|
|
@app.local_entrypoint()
def full_mag_test(
    mag_dir: str = "/home/ror25cal/MGnify/data/chicken-gut/species_catalogue/MGYG0003076/MGYG000307601/genome",
    mag_id: str = "MGYG000307601",
    contig_id: str = "MGYG000307601_21",
    chunk_size: int = 64000,
    overlap: int = 0,
):
    """Process one full contig, chunked. Saves /embeddings/{mag_id}/{contig_id}_{N}.npz per chunk.

    Reads {mag_dir}/{mag_id}.fna locally, picks the record named `contig_id`
    (the first whitespace-delimited token of its FASTA header) and sends it to
    `embed_full` with pool_size=1000, then prints the returned summary.

    raises: ValueError when overlap >= chunk_size,
            KeyError when `contig_id` is not in the FASTA file
    """
    from pathlib import Path

    # Check locally first: the chunk estimate below divides by this step, and
    # failing here avoids starting a GPU container for a bad argument.
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")

    # Minimal FASTA parse: header name -> sequence lines (joined below).
    fna_path = Path(f"{mag_dir}/{mag_id}.fna")
    record_lines = {}
    current = None
    for line in fna_path.read_text().splitlines():
        if line.startswith(">"):
            current = line[1:].split()[0]
            record_lines[current] = []
        elif current is not None:
            record_lines[current].append(line.strip())

    if contig_id not in record_lines:
        raise KeyError(
            f"contig {contig_id!r} not found in {fna_path} "
            f"({len(record_lines)} records present)"
        )
    seq = "".join(record_lines[contig_id])

    n_expected = (len(seq) + step - 1) // step  # ceil(len / step)
    print(f"[{mag_id}/{contig_id}] {len(seq):,} bp → ~{n_expected} chunks @ {chunk_size:,} bp (overlap={overlap})")

    result = embed_full.remote(mag_id, contig_id, seq, pool_size=1000, chunk_size=chunk_size, overlap=overlap)
    print(f"\n=== RESULT ===")
    print(f" contig_len: {result['contig_len']:,}")
    print(f" n_chunks: {result['n_chunks']}")
    print(f" total_size: {result['total_size_mb']:.1f} MB")
    print(f" chunks:")
    for c in result["chunks"]:
        print(f" [{c['chunk_idx']}] bp {c['bp_range'][0]:,}-{c['bp_range'][1]:,} {c['size_mb']:.1f} MB → {c['path']}")
|
|
|
|
@app.local_entrypoint()
def run_top50_skin_amr(
    csv_path: str = "/home/ror25cal/MGnify/modal/top50_skin_amr.csv",
    skin_dir: str = "/home/ror25cal/MGnify/data/human-skin/species_catalogue",
    chunk_size: int = 64000,
    overlap: int = 0,
    min_contig_len: int = 5000,
):
    """
    Process the top-50 human-skin MAGs (sorted by AMR-per-Mb density) end-to-end.

    Reads the `mag_id` column of `csv_path`, loads each MAG's FASTA from
    {skin_dir}/{first 11 chars of mag_id}/{mag_id}/genome/{mag_id}.fna, drops
    contigs shorter than `min_contig_len`, and submits one `embed_full` call
    per remaining contig through `.starmap()` with pool_size=1000. Failed
    calls are reported and counted separately from successful ones.

    Output: /embeddings/{mag_id}/{contig_id}_{chunk_idx}.npz on the mgnify-embeddings Volume.

    raises: ValueError when overlap >= chunk_size
    """
    import csv
    from pathlib import Path

    # Same step embed_full uses; needed for an estimate that honours overlap.
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")

    with open(csv_path, newline="") as f:
        mag_ids = [row["mag_id"] for row in csv.DictReader(f)]
    print(f"loaded {len(mag_ids)} MAGs from {csv_path}")

    # Build the work queue: one argument tuple per contig.
    work = []
    total_bp = 0
    for mag_id in mag_ids:
        prefix = mag_id[:11]
        fna_path = Path(skin_dir) / prefix / mag_id / "genome" / f"{mag_id}.fna"
        if not fna_path.exists():
            print(f" skipping {mag_id} — fna not found at {fna_path}")
            continue

        # Minimal FASTA parse: header name -> list of sequence lines.
        record_lines = {}
        current = None
        for line in fna_path.read_text().splitlines():
            if line.startswith(">"):
                current = line[1:].split()[0]
                record_lines[current] = []
            elif current is not None:
                record_lines[current].append(line.strip())

        for cid, parts in record_lines.items():
            seq = "".join(parts)
            if len(seq) < min_contig_len:
                continue
            work.append((mag_id, cid, seq, 1000, chunk_size, overlap))
            total_bp += len(seq)

    # ceil(len / step) per contig — matches how embed_full lays out chunks.
    n_chunks_estimate = sum((len(w[2]) + step - 1) // step for w in work)
    cost_estimate = n_chunks_estimate * 0.025
    print(f"\nwork queue: {len(work)} contigs (≥ {min_contig_len:,} bp), {total_bp:,} total bp, ~{n_chunks_estimate} chunks")
    print(f"cost estimate: ~${cost_estimate:.0f}")

    print(f"\nsubmitting to Modal...\n")
    n_seen = 0
    n_ok = 0
    n_errors = 0
    n_chunks_done = 0
    total_mb = 0.0
    for result in embed_full.starmap(work, return_exceptions=True):
        n_seen += 1
        if isinstance(result, Exception):
            n_errors += 1
            print(f" [{n_seen}/{len(work)}] ERROR: {result}")
            continue
        n_ok += 1
        n_chunks_done += result.get("n_chunks", 0)
        total_mb += result.get("total_size_mb", 0.0)
        cached_tag = " (all cached)" if result.get("all_cached") else ""
        print(f" [{n_seen}/{len(work)}] {result['mag_id']}/{result['contig_id']}: {result['n_chunks']} chunks, {result['total_size_mb']:.1f} MB{cached_tag}")

    print(f"\n=== DONE ===")
    # Only successful calls count as processed; failures get their own line.
    print(f" contigs processed: {n_ok}/{len(work)}")
    print(f" errors: {n_errors}")
    print(f" total chunks: {n_chunks_done}")
    print(f" total volume size: {total_mb:.1f} MB")
|
|
|
|
@app.local_entrypoint()
def main(seq: str = "ACGT" * 100, layer: str = TARGET_LAYER):
    """Smoke test: embed one short sequence at one layer (Evo2 only, no SAE).

    `layer` is forwarded verbatim as the single layer name to `embed`; the
    shape and mean absolute value of each returned array are printed.
    """
    import numpy as np

    print(f"submitting 1 sequence of length {len(seq)} bp to {APP_NAME} @ {layer}")
    batch = [("smoke-test", seq)]
    results = embed.remote(batch, layers=[layer])
    for name in results:
        for lyr, arr in results[name].items():
            print(f" {name} / {lyr}: shape={arr.shape} |x|={np.abs(arr).mean():.3e}")
|
|
|
|
@app.local_entrypoint()
def crispr_test(
    region_json: str = "/home/ror25cal/MGnify/modal/crispr_test_region.json",
    out_path: str = "/home/ror25cal/MGnify/modal/crispr_test_result.json",
):
    """Run the CRISPR-region sanity test: Evo2 + Goodfire SAE, save result locally.

    region_json: JSON file providing at least "sequence", "labels", "mag",
                 "contig", "region_start" and "region_end"
    out_path:    where the `embed_and_sae` result, extended with the labels
                 and the remaining region metadata, is written as JSON
    """
    import json
    import numpy as np

    # Context managers so the handles are closed deterministically.
    with open(region_json) as f:
        region = json.load(f)
    seq = region["sequence"]
    labels = np.array(region["labels"], dtype=np.int8)

    print(f"region: {region['mag']} / {region['contig']} "
          f"bp {region['region_start']}-{region['region_end']} len={len(seq)}")
    print(f"label distribution: "
          f"{dict(zip(*np.unique(labels, return_counts=True)))} "
          f"(0=bg, 1=CRISPR, 2=DR, 3=spacer, 4=flank)")

    result = embed_and_sae.remote(seq, topk=64)
    print(f"\ngot back:")
    print(f" seq_len: {result['seq_len']}")
    print(f" d_model (Evo2 hidden): {result['d_model']}")
    print(f" d_sae (Goodfire dict): {result['d_sae']}")
    active = np.array(result["active_features_per_position"])
    print(f" active features/pos: median={int(np.median(active))} max={int(active.max())}")

    # Compare the strongest SAE latent on labelled vs. background positions.
    topk_vals = np.array(result["topk_values"])
    # Labels and model positions may differ in length; compare the common prefix.
    n = min(len(labels), result["seq_len"])
    labels = labels[:n]
    crispr_mask = labels > 0
    bg_mask = labels == 0
    top1 = topk_vals[:n, 0]
    if crispr_mask.any() and bg_mask.any():
        crispr_mean = top1[crispr_mask].mean()
        bg_mean = top1[bg_mask].mean()
        print(f"\n top-1 SAE activation — CRISPR positions: mean={crispr_mean:.3f} N={crispr_mask.sum()}")
        print(f" top-1 SAE activation — background: mean={bg_mean:.3f} N={bg_mask.sum()}")
        # Floor the denominator so an all-zero background cannot divide by zero.
        print(f" ratio (CRISPR / bg): {crispr_mean / max(bg_mean, 1e-9):.2f}x")

    result["labels"] = labels.tolist()
    result["region_meta"] = {k: v for k, v in region.items() if k not in ("sequence", "labels")}
    with open(out_path, "w") as f:
        json.dump(result, f)
    print(f"\nsaved result -> {out_path}")
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_targeted": embeddings_targeted_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
    max_containers=16,
)
def embed_targeted_jsonl(jsonl_rel_path: str) -> dict:
    """
    Embed every record of one JSONL file (one MAG × one label).

    For each record the sequence is run through the cached Evo2 model and the
    per-token outputs of all 32 blocks are written to
    /embeddings_targeted/{label}/{mag_id}/{region_id}.npz (bf16 bit patterns
    stored as uint16); records that are not positive go under "MISC".
    Regions whose file already exists are skipped.

    jsonl_rel_path: path inside the jsonl volume, e.g. "full/AMR/MGYG000516287.jsonl"
    returns: {"path", "n_done", "n_skipped", "total_mb"}, plus
             "error": "missing" when the JSONL file does not exist
    """
    import json
    import os
    import numpy as np
    import torch

    src_path = f"/jsonl/{jsonl_rel_path}"
    if not os.path.exists(src_path):
        return {"path": jsonl_rel_path, "error": "missing", "n_done": 0, "n_skipped": 0, "total_mb": 0.0}

    # One JSON object per non-blank line.
    records = []
    with open(src_path) as fh:
        for raw_line in fh:
            if raw_line.strip():
                records.append(json.loads(raw_line))
    if not records:
        return {"path": jsonl_rel_path, "n_done": 0, "n_skipped": 0, "total_mb": 0.0}

    evo2, _, _, device, module_dict = _get_models()
    layer_names = [f"blocks-{i}" for i in range(32)]

    cache: dict = {}

    def make_hook(layer):
        def hook(module, inp, out):
            # Some blocks return a tuple; the activations are its first element.
            tensor = out[0] if isinstance(out, tuple) else out
            cache[layer] = tensor.detach()
        return hook

    handles = []
    for ln in layer_names:
        handles.append(module_dict[ln].register_forward_hook(make_hook(ln)))

    n_done, n_skipped, total_mb = 0, 0, 0.0
    try:
        for rec in records:
            label_folder = rec["label"] if rec["is_positive"] else "MISC"
            mag_id = rec["mag_id"]
            region_id = rec["region_id"]
            out_dir = f"/embeddings_targeted/{label_folder}/{mag_id}"
            os.makedirs(out_dir, exist_ok=True)
            out_path = f"{out_dir}/{region_id}.npz"

            # Resume support: never recompute a region that is already saved.
            if os.path.exists(out_path):
                n_skipped += 1
                continue

            cache.clear()
            token_ids = evo2.tokenizer.tokenize(rec["sequence"])
            input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                evo2.model(input_ids)
            seq_len = cache["blocks-0"].shape[1]
            hidden = evo2.model.config.hidden_size

            # Copy layer by layer into a CPU buffer, dropping each GPU tensor
            # as soon as it has been copied.
            stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
            for row, ln in enumerate(layer_names):
                stack[row] = cache[ln][0].to(torch.bfloat16).cpu()
                del cache[ln]
            torch.cuda.empty_cache()
            # numpy has no bfloat16, so store the raw bit pattern as uint16.
            stack_uint16 = stack.view(torch.uint16).numpy()

            # Everything except the sequence travels along as JSON metadata.
            meta = {key: value for key, value in rec.items() if key != "sequence"}
            np.savez_compressed(
                out_path,
                per_token_layer_activations_bf16=stack_uint16,
                per_token_layer_activations_dtype="bfloat16",
                layer_names=np.array(layer_names),
                seq_len=np.int32(seq_len),
                hidden_size=np.int32(hidden),
                model_name="evo2_7b_262k",
                metadata_json=np.array(json.dumps(meta)),
            )
            size_bytes = os.path.getsize(out_path)
            total_mb += size_bytes / 1e6
            n_done += 1
            print(f" [{label_folder}/{mag_id}/{region_id}] seq_len={seq_len} saved {size_bytes/1e6:.1f} MB")
            del stack, stack_uint16, input_ids
            torch.cuda.empty_cache()
    finally:
        # Hooks must not outlive the call: the model is cached per container.
        for handle in handles:
            handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    embeddings_targeted_vol.commit()
    return {"path": jsonl_rel_path, "n_done": n_done, "n_skipped": n_skipped, "total_mb": total_mb}
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_targeted_full() -> dict:
    """
    CPU-only orchestrator for the targeted-region run.

    Collects every *.jsonl file under /jsonl/full (sorted, as paths relative
    to /jsonl), maps them over `embed_targeted_jsonl` and sums the per-file
    results. Because it runs on Modal, the fan-out keeps going after the
    local process exits (use --detach).

    returns: {"jsonls", "regions_done", "regions_skipped", "errors", "total_mb"}
    """
    import os

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(root, fname), "/jsonl")
        for root, _, files in os.walk("/jsonl/full")
        for fname in files
        if fname.endswith(".jsonl")
    )
    n_files = len(jsonl_paths)
    print(f"[orchestrator] found {n_files} JSONL files to process")

    n_total_done = n_total_skipped = errors = 0
    total_mb = 0.0
    outcomes = embed_targeted_jsonl.map(jsonl_paths, return_exceptions=True)
    for pos, outcome in enumerate(outcomes, start=1):
        if isinstance(outcome, Exception):
            errors += 1
            print(f" [{pos}/{n_files}] ERROR: {outcome}")
            continue
        n_total_done += outcome.get("n_done", 0)
        n_total_skipped += outcome.get("n_skipped", 0)
        total_mb += outcome.get("total_mb", 0.0)
        # Progress line every 25 files and after the last one.
        if pos % 25 == 0 or pos == n_files:
            print(f" [{pos}/{n_files}] running totals: done={n_total_done} skipped={n_total_skipped} errors={errors} {total_mb/1024:.1f} GB")

    return {
        "jsonls": n_files,
        "regions_done": n_total_done,
        "regions_skipped": n_total_skipped,
        "errors": errors,
        "total_mb": total_mb,
    }
|
|
|
|
@app.local_entrypoint()
def run_targeted_full():
    """
    Kick off the full targeted-region embedding run and print its totals.

    The fan-out itself happens inside `orchestrate_targeted_full` on Modal, so
    run with `modal run --detach` to let it continue after the local process
    exits:

    modal run --detach modal/evo2_inference.py::run_targeted_full
    """
    print("[local] submitting orchestrator to Modal — fan-out happens server-side")
    summary = orchestrate_targeted_full.remote()

    total_gb = summary['total_mb'] / 1024
    print("\n=== DONE ===")
    print(f" JSONL files: {summary['jsonls']}")
    print(f" regions saved: {summary['regions_done']}")
    print(f" regions skipped: {summary['regions_skipped']} (already on volume)")
    print(f" errors: {summary['errors']}")
    print(f" total volume: {total_gb:.1f} GB")
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Destination Volume for the layer-26-only slices (mounted at /out by
# extract_l26_batch and at /vol by the pack/upload functions); created on first use.
embeddings_l26_vol = modal.Volume.from_name("mgnify-embeddings-l26", create_if_missing=True)
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_targeted_vol,
        "/out": embeddings_l26_vol,
    },
    timeout=3600,
    max_containers=16,
)
def extract_l26_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 out of each input npz and write a compact per-region npz.

    For every relative path, /in/{rel} is read and index 26 of its
    `per_token_layer_activations_bf16` array is written to /out/{rel} together
    with the passthrough metadata arrays. Inputs that are missing count as
    errors, outputs that already exist are skipped, and a file that fails
    mid-way has its partial output removed so a later run retries it instead
    of skipping it.

    rel_paths: list like ['AMR/MGYG.../MGYG..._00123_AMR.npz', ...] relative to volume root.
    returns:   {"n_done", "n_skipped", "n_errors", "total_mb_in", "total_mb_out"}
    """
    import contextlib
    import os
    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_in = 0.0
    total_mb_out = 0.0

    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if not os.path.exists(in_path):
            n_errors += 1
            continue
        if os.path.exists(out_path):
            n_skipped += 1
            continue

        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                # Copy so the slice no longer references the full 32-layer array.
                l26 = stack[26].copy()
                passthrough = {
                    "layer_names": d["layer_names"],
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            np.savez_compressed(
                out_path,
                layer26_activations_bf16=l26,
                layer26_dtype="bfloat16",
                source_layer_index=np.int32(26),
                source_layer_name="blocks-26",
                **passthrough,
            )
            total_mb_in += os.path.getsize(in_path) / 1e6
            total_mb_out += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # out_path did not exist before this attempt, so anything there
            # now is a partial write; remove it or the exists() check above
            # would treat it as finished on the next run.
            with contextlib.suppress(OSError):
                os.remove(out_path)

    embeddings_l26_vol.commit()
    return {
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_targeted_vol},
    timeout=86400,
)
def orchestrate_l26_extract(batch_size: int = 50) -> dict:
    """List every all-32-layer npz on /in, batch by N, fan out to extract_l26_batch.

    batch_size: number of files handed to each `extract_l26_batch` call (> 0)
    returns:    {"regions_total", "n_done", "n_skipped", "n_errors",
                 "total_mb_in", "total_mb_out"}; when a whole batch call
                fails, every file in that batch is counted in "n_errors"
    raises:     ValueError when batch_size is not positive
    """
    import os

    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be > 0")

    paths = sorted(
        os.path.relpath(os.path.join(root, fname), "/in")
        for root, _, files in os.walk("/in")
        for fname in files
        if fname.endswith(".npz")
    )
    print(f"[orchestrator] found {len(paths)} all-32-layer npz files to slice")

    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator] dispatching {len(batches)} batches of up to {batch_size}")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb_in = 0.0
    total_mb_out = 0.0
    # Results are consumed in input order, so result i belongs to batches[i].
    for i, r in enumerate(extract_l26_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            # The whole call failed: none of its files is known to be done,
            # so count them all instead of reporting zero errors.
            n_errors += len(batches[i])
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
        else:
            n_done += r.get("n_done", 0)
            n_skipped += r.get("n_skipped", 0)
            n_errors += r.get("n_errors", 0)
            total_mb_in += r.get("total_mb_in", 0.0)
            total_mb_out += r.get("total_mb_out", 0.0)
        # Not gated on success, so the final totals line always appears.
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] running totals: done={n_done} skipped={n_skipped} errors={n_errors} in={total_mb_in/1024:.1f} GB → out={total_mb_out/1024:.2f} GB")

    return {
        "regions_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_in": total_mb_in,
        "total_mb_out": total_mb_out,
    }
|
|
|
|
@app.local_entrypoint()
def run_l26_extract():
    """
    Launch the layer-26 slicer over mgnify-embeddings-targeted and print totals.

    The slices are written to mgnify-embeddings-l26 by the server-side
    orchestrator; run detached to let it finish on its own:

    modal run --detach modal/evo2_inference.py::run_l26_extract
    """
    print("[local] submitting layer-26 slicer orchestrator to Modal")
    summary = orchestrate_l26_extract.remote()

    mb_in = summary['total_mb_in']
    mb_out = summary['total_mb_out']
    # Floor the denominator at 1 MB so an empty run cannot divide by zero.
    ratio = mb_in / max(mb_out, 1)

    print("\n=== DONE ===")
    print(f" source npz scanned: {summary['regions_total']}")
    print(f" newly sliced: {summary['n_done']}")
    print(f" skipped (already done): {summary['n_skipped']}")
    print(f" errors: {summary['n_errors']}")
    print(f" input bytes read: {mb_in/1024:.2f} GB")
    print(f" output bytes written: {mb_out/1024:.2f} GB")
    print(f" compression ratio: {ratio:.1f}x smaller")
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
@app.function(
    image=modal.Image.debian_slim(),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    timeout=3600,
)
def pack_l26_archive(out_name: str = "embeddings_l26.zip") -> dict:
    """Bundle every .npz under /vol into one uncompressed zip at /vol/{out_name}.

    Entries are stored without zip compression and named by their path
    relative to /vol. An existing archive of the same name is replaced.

    returns: {"archive_path", "n_files", "bytes_raw", "bytes_archive", "elapsed_s"}
    """
    import os
    import time
    import zipfile

    archive_path = f"/vol/{out_name}"
    # Start from a clean slate rather than appending to an old archive.
    if os.path.exists(archive_path):
        os.remove(archive_path)

    file_count = 0
    raw_bytes = 0
    started = time.time()
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_STORED) as archive:
        for root, _, files in os.walk("/vol"):
            for fname in files:
                if fname.endswith(".npz"):
                    src = os.path.join(root, fname)
                    archive.write(src, os.path.relpath(src, "/vol"))
                    raw_bytes += os.path.getsize(src)
                    file_count += 1
                    if file_count % 50 == 0:
                        print(f" packed {file_count} files, {raw_bytes/1e9:.2f} GB raw...")

    archive_bytes = os.path.getsize(archive_path)
    embeddings_l26_vol.commit()
    # elapsed_s is taken after the commit, so it includes the commit time.
    return {
        "archive_path": archive_path,
        "n_files": file_count,
        "bytes_raw": raw_bytes,
        "bytes_archive": archive_bytes,
        "elapsed_s": time.time() - started,
    }
|
|
|
|
@app.local_entrypoint()
def pack_and_report():
    """Pack the l26 volume into one zip on the volume and print a download hint.

    Calls `pack_l26_archive` with its default archive name, reports the
    returned counts and sizes, then prints the `modal volume get` command
    for fetching the archive.
    """
    print("[local] packing layer-26 npz files into single zip on Modal volume...")
    report = pack_l26_archive.remote()

    raw_gb = report['bytes_raw'] / 1e9
    archive_gb = report['bytes_archive'] / 1e9

    print("\n=== PACKED ===")
    print(f" archive path on volume: {report['archive_path']}")
    print(f" files packed: {report['n_files']}")
    print(f" raw size: {raw_gb:.2f} GB")
    print(f" archive size: {archive_gb:.2f} GB")
    print(f" elapsed (server-side): {report['elapsed_s']:.0f} s")
    print(f"\nDownload locally with:")
    print(f" modal volume get mgnify-embeddings-l26 /embeddings_l26.zip /home/ror25cal/MGnify/data/embeddings_l26.zip")
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
)
def upload_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True) -> dict:
    """Upload the ``.npz`` files under ``/vol`` to a Hugging Face dataset repo.

    The repo id is ``{authenticated user}/{repo_name}``; it is created if needed
    (``exist_ok=True``) and filled with ``HfApi.upload_large_folder``.

    Args:
        repo_name: dataset name, without the user/organisation prefix.
        private: visibility passed to ``create_repo``.

    Returns:
        dict with ``repo_id``, ``repo_url``, ``n_files``, ``bytes_total`` and
        ``elapsed_s`` (time spent inside ``upload_large_folder``).

    Raises:
        RuntimeError: if no Hugging Face token can be found in the environment.
    """
    import os
    import time
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Fallback: accept an env var whose NAME starts with "hf_" and whose value
        # is the literal "HF_TOKEN" — presumably a secret created with key and
        # value swapped (TODO confirm this is still needed).
        token = next(
            (
                name
                for name, value in os.environ.items()
                if name.startswith("hf_") and value == "HF_TOKEN"
            ),
            None,
        )
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    user = api.whoami().get("name")
    repo_id = f"{user}/{repo_name}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")

    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: https://huggingface.co/datasets/{repo_id}")

    # Pre-scan the volume so the summary can report file count and total size.
    npz_sizes = [
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _subdirs, filenames in os.walk("/vol")
        for filename in filenames
        if filename.endswith(".npz")
    ]
    n_files = len(npz_sizes)
    bytes_total = sum(npz_sizes)
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    upload_kwargs = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "folder_path": "/vol",
        "allow_patterns": ["**/*.npz"],
        "print_report": True,
    }
    t0 = time.time()
    api.upload_large_folder(**upload_kwargs)
    elapsed = time.time() - t0

    return {
        "repo_id": repo_id,
        "repo_url": f"https://huggingface.co/datasets/{repo_id}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
    }
|
|
|
|
@app.local_entrypoint()
def push_l26_to_hf(repo_name: str = "mgnify-evo2-l26-amr-pilot", private: bool = True):
    """Upload the layer-26 volume to a HF Dataset repo from Modal, with no local download.

    modal run modal/evo2_inference.py::push_l26_to_hf
    modal run modal/evo2_inference.py::push_l26_to_hf --repo-name foo --no-private
    """
    print(f"[local] launching HF push (repo={repo_name}, private={private})")
    r = upload_l26_to_hf.remote(repo_name=repo_name, private=private)
    # Throughput uses at least 1 s of elapsed time to avoid dividing by zero.
    summary = (
        "\n=== UPLOADED ===",
        f" repo: {r['repo_url']}",
        f" files: {r['n_files']}",
        f" size: {r['bytes_total']/1e9:.2f} GB",
        f" elapsed: {r['elapsed_s']:.0f} s ({r['bytes_total']/1e6/max(r['elapsed_s'],1):.1f} MB/s)",
    )
    for line in summary:
        print(line)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# Volume that receives the multi-layer "lean" embeddings; it is mounted at
# /embeddings_lean by embed_targeted_lean and embed_vfdb_lean.
embeddings_lean_vol = modal.Volume.from_name("mgnify-embeddings-lean", create_if_missing=True)
# Evo 2 block indices captured by the lean pipeline (hooked as "blocks-{i}").
LEAN_LAYERS: list[int] = [14, 20, 24, 26, 28]
|
|
|
|
def _get_evo2_only():
    """Return ``(evo2, device, module_dict)``, loading Evo 2 only on the first call.

    The three objects are stored in module globals, so later calls in the same
    process reuse them. ``module_dict`` maps every descendant module of
    ``evo2.model`` to a name built by joining the child names along its path
    with ``-`` (e.g. ``blocks-26``). No SAE weights are loaded here.
    """
    module_globals = globals()
    cached_model = module_globals.get("_CACHED_EVO2_LEAN")
    if cached_model is not None:
        try:
            return (
                cached_model,
                module_globals["_CACHED_DEVICE_LEAN"],
                module_globals["_CACHED_MODULE_DICT_LEAN"],
            )
        except KeyError:
            # Cache only partially populated: fall through and rebuild it.
            pass

    from evo2 import Evo2
    print("[container] loading Evo2 7B-262k (no SAE)")
    evo2 = Evo2("evo2_7b_262k")
    device = next(evo2.model.parameters()).device

    module_dict = {}

    def _index_children(parent, prefix=""):
        # Pre-order walk: register each child, then descend into it.
        for child_name, child in parent.named_children():
            module_dict[prefix + child_name] = child
            _index_children(child, prefix + child_name + "-")

    _index_children(evo2.model)

    module_globals.update(
        _CACHED_EVO2_LEAN=evo2,
        _CACHED_DEVICE_LEAN=device,
        _CACHED_MODULE_DICT_LEAN=module_dict,
    )
    return evo2, device, module_dict
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_targeted_lean(jsonl_rel_paths) -> dict:
    """
    Embed every record of a batch of JSONL files and save the LEAN_LAYERS activations.

    Accepts either str (single JSONL, back-compat) or list[str] (batch); paths are
    relative to the /jsonl mount and missing files are counted, not fatal.
    Each record must provide "label", "is_positive", "mag_id", "region_id" and
    "sequence". Output goes to /embeddings_lean/{label}/{mag}/{region}.npz, where
    {label} is "MISC" for records with a falsy "is_positive". Regions whose output
    file already exists are skipped. Single volume.commit() at end.

    Each .npz is written to a temporary file and renamed into place, so an
    interrupted write can never leave a truncated file that later runs would
    skip as "already done".

    Returns:
        dict with n_jsonls, n_missing_jsonl, n_done, n_skipped, total_mb,
        model_load_s, commit_s, per_region_s (list) and mean_per_region_s
        (None when nothing was embedded).
    """
    import json
    import os
    import time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]

    # Forward hooks drop each target block's output here, keyed by layer name.
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    handles = []
    try:
        # Register inside the try: the model is cached across calls, so hooks
        # from a partially failed registration must not stay attached to it.
        for ln in layer_names:
            handles.append(module_dict[ln].register_forward_hook(make_hook(ln)))

        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue

            for rec in records:
                label_folder = rec["label"] if rec["is_positive"] else "MISC"
                mag_id = rec["mag_id"]
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/{label_folder}/{mag_id}"
                out_path = f"{out_dir}/{region_id}.npz"

                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                os.makedirs(out_dir, exist_ok=True)

                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size

                # Move layers to the CPU one at a time, releasing GPU memory as we go.
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                    torch.cuda.empty_cache()
                # NumPy has no bfloat16 dtype: store the raw 16-bit patterns as uint16.
                stack_uint16 = stack.view(torch.uint16).numpy()

                meta = {k: v for k, v in rec.items() if k != "sequence"}
                # Write to a temp name that does not end in ".npz", then rename.
                # Passing a file object stops np.savez from appending ".npz".
                tmp_path = f"{out_path}.tmp"
                try:
                    with open(tmp_path, "wb") as fh:
                        np.savez(
                            fh,
                            per_token_layer_activations_bf16=stack_uint16,
                            per_token_layer_activations_dtype="bfloat16",
                            layer_names=np.array(layer_names),
                            layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                            seq_len=np.int32(seq_len),
                            hidden_size=np.int32(hidden),
                            model_name="evo2_7b_262k",
                            metadata_json=np.array(json.dumps(meta)),
                        )
                    os.replace(tmp_path, out_path)
                finally:
                    # No-op after a successful rename; removes the partial file otherwise.
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start

    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_lean_full(batch_size: int = 50) -> dict:
    """Fan out every JSONL under /jsonl/full to embed_targeted_lean in batches.

    Paths are collected relative to /jsonl, sorted, and split into chunks of
    `batch_size`, so each worker call issues one commit for many files.
    A batch that raises is counted in `errors` and otherwise ignored.

    Returns:
        dict with jsonls, batches, regions_done, regions_skipped, errors,
        total_mb, mean_region_s and mean_commit_s (means are 0 with no samples).
    """
    import os

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(dirpath, filename), "/jsonl")
        for dirpath, _subdirs, filenames in os.walk("/jsonl/full")
        for filename in filenames
        if filename.endswith(".jsonl")
    )
    batches = [jsonl_paths[start:start + batch_size] for start in range(0, len(jsonl_paths), batch_size)]
    print(f"[orchestrator-lean] {len(jsonl_paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    def _mean(samples):
        # Average that yields 0 instead of raising on an empty list.
        return sum(samples) / max(len(samples), 1)

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_time_samples: list[float] = []
    commit_time_samples: list[float] = []

    for i, r in enumerate(embed_targeted_lean.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue

        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        region_time_samples.extend(r.get("per_region_s") or [])
        commit_s = r.get("commit_s")
        if commit_s is not None:
            commit_time_samples.append(commit_s)

        # Progress line every 5 batches and after the final one.
        is_last = (i + 1) == len(batches)
        if is_last or (i + 1) % 5 == 0:
            mean_t = _mean(region_time_samples)
            mean_c = _mean(commit_time_samples)
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
                  f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f} mean_commit_s={mean_c:.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": len(batches),
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": _mean(region_time_samples),
        "mean_commit_s": _mean(commit_time_samples),
    }
|
|
|
|
@app.local_entrypoint()
def pilot_lean_batched(n_jsonls: int = 100, batch_size: int = 50):
    """Run the *batched* lean pipeline on N JSONLs to measure realistic commit overhead.

    The JSONL list is built from the local copy of the data (label folders
    AMR, MISC, VIRULENCE, STRESS), truncated to `n_jsonls`, split into batches
    of `batch_size` and sent to embed_targeted_lean. Prints the measured
    timings, then a cost/size projection for the full run.

    Failed batches are reported and excluded from the statistics. If nothing
    was embedded (every batch failed, or every region was already on the
    volume) the function prints an error and returns without a projection.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/full"
    paths = []
    for label in ["AMR", "MISC", "VIRULENCE", "STRESS"]:
        d = os.path.join(base, label)
        if os.path.isdir(d):
            paths.extend(
                f"full/{label}/{fname}"
                for fname in sorted(os.listdir(d))
                if fname.endswith(".jsonl")
            )
    paths = paths[:n_jsonls]
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[pilot-batched] {len(paths)} JSONLs in {len(batches)} batches of up to {batch_size}")

    t0 = time.time()
    results = list(embed_targeted_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    ok = [r for r in results if not isinstance(r, Exception)]
    failed = [r for r in results if isinstance(r, Exception)]
    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    all_region_times = [t for r in ok for t in r.get("per_region_s") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]

    print("\n=== PILOT-BATCHED SUMMARY ===")
    if failed:
        print(f" failed batches: {len(failed)}/{len(results)} (first error: {failed[0]})")
    if not all_region_times:
        # min()/max() and the projection below need at least one timing sample.
        print("ERROR: no records processed in pilot "
              "(all batches failed or every region already exists on the volume)")
        return

    mean_region = sum(all_region_times) / len(all_region_times)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)
    print(f" records: {n_done}")
    print(f" total volume size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batch(es)")
    print(f" per-region inference: {mean_region:.2f} s avg (min {min(all_region_times):.2f}, max {max(all_region_times):.2f})")
    if commit_times:
        print(f" per-batch commit: {mean_commit:.2f} s avg (min {min(commit_times):.2f}, max {max(commit_times):.2f})")
    print(f" per-call model load: {mean_load:.1f} s avg (cold start)")

    # Projection to the full dataset.
    full_records = 5483
    full_n_jsonls = 2416
    n_workers = 16  # matches max_containers on embed_targeted_lean
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    # No more workers than batches can ever run concurrently.
    active_workers = max(min(n_workers, full_batches), 1)
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    cold_start_s = mean_load * active_workers
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / active_workers
    cost = (total_compute_s / 3600) * 4.50

    print(f"\n PROJECTION (5483 records, {n_workers}× H100, batch_size={batch_size}):")
    print(f" inference compute: {inference_compute_s:6.0f} s ({inference_compute_s/total_compute_s*100:.0f}%)")
    print(f" commit compute: {commit_compute_s:6.0f} s ({commit_compute_s/total_compute_s*100:.0f}%)")
    print(f" cold-start total: {cold_start_s:6.0f} s ({cold_start_s/total_compute_s*100:.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {total_mb/max(n_done,1) * full_records / 1024:.1f} GB")
|
|
|
|
@app.local_entrypoint()
def run_lean_full(batch_size: int = 50):
    """Submit orchestrate_lean_full and print its summary; intended to be run detached.

    modal run --detach modal/evo2_inference.py::run_lean_full
    modal run --detach modal/evo2_inference.py::run_lean_full --batch-size 50
    """
    print(f"[local] submitting orchestrator (batch_size={batch_size}) to Modal")
    r = orchestrate_lean_full.remote(batch_size=batch_size)
    summary = [
        "\n=== DONE ===",
        f" JSONL files: {r['jsonls']} in {r['batches']} batches",
        f" regions saved: {r['regions_done']}",
        f" regions skipped: {r['regions_skipped']} (already on volume)",
        f" errors: {r['errors']}",
        f" total volume size: {r['total_mb']/1024:.1f} GB",
        f" mean per-region: {r['mean_region_s']:.2f} s",
        f" mean per-batch commit:{r['mean_commit_s']:.2f} s",
    ]
    for line in summary:
        print(line)
|
|
|
|
@app.local_entrypoint()
def pilot_lean(mag_id: str = "MGYG000516287"):
    """Time the lean pipeline on one MAG and extrapolate to the full run.

    Calls embed_targeted_lean once per label (AMR, VIRULENCE, STRESS, MISC)
    for ``full/{label}/{mag_id}.jsonl``, sequentially. A call that raises is
    reported as SKIPPED. Prints measured timings and, when at least one region
    was embedded, a wall-clock/cost/size projection.
    """
    import time

    labels = ("AMR", "VIRULENCE", "STRESS", "MISC")
    pilot_jsonls = [f"full/{label}/{mag_id}.jsonl" for label in labels]
    print(f"[pilot] running on MAG {mag_id} across labels (skipping any missing)")

    results = []
    t0 = time.time()
    for path in pilot_jsonls:
        try:
            r = embed_targeted_lean.remote(path)
            print(f" {path}: done={r['n_done']} skipped={r['n_skipped']} "
                  f"model_load={r.get('model_load_s', 0):.1f}s "
                  f"mean_region={r.get('mean_per_region_s') or 0:.2f}s "
                  f"size={r['total_mb']:.1f} MB")
            results.append(r)
        except Exception as e:
            # Best effort: one failing label must not abort the pilot.
            print(f" {path}: SKIPPED ({e})")
    wall = time.time() - t0

    # Aggregate across the calls that succeeded.
    all_region_times = []
    n_total = 0
    bytes_total = 0
    for r in results:
        all_region_times.extend(r.get("per_region_s") or [])
        n_total += r["n_done"]
        bytes_total += r["total_mb"]

    print("\n=== PILOT SUMMARY ===")
    print(f" records processed: {n_total}")
    print(f" total volume size: {bytes_total:.1f} MB ({bytes_total/max(n_total,1):.1f} MB/record avg)")
    print(f" wall clock (single worker): {wall:.0f} s")
    if not all_region_times:
        # Without timing samples there is nothing to project from.
        return

    mean_t = sum(all_region_times) / len(all_region_times)
    print(f" per-region time: {mean_t:.2f} s avg "
          f"(min {min(all_region_times):.2f}, max {max(all_region_times):.2f})")

    # Projection: per-region mean scaled to the full record count, split over
    # n_workers, plus a fixed cold-start allowance per worker.
    full_records = 5483
    n_workers = 16
    cold_start_s = 60
    full_compute_s = full_records * mean_t
    wall_proj = cold_start_s + full_compute_s / n_workers
    h100_hours = (full_compute_s / 3600 + (cold_start_s * n_workers) / 3600)
    cost_proj = h100_hours * 4.50
    print(f"\n PROJECTION (5483 regions, {n_workers}× H100):")
    print(f" estimated wall clock: {wall_proj/60:.1f} min")
    print(f" estimated cost: ${cost_proj:.2f}")
    print(f" estimated total size: {bytes_total/max(n_total,1) * full_records / 1024:.1f} GB")
|
|
|
|
| |
| |
| |
| |
|
|
# Volume for the layer-26-only slices: written by slice_l26_from_lean_batch
# (mounted at /out) and read by upload_l26_lean_to_hf (mounted at /vol).
embeddings_l26_lean_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=True)
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_lean_vol,
    },
    timeout=3600,
    max_containers=16,
)
def slice_l26_from_lean_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 from each lean npz. Single commit per batch.

    For every relative path, reads /in/{rel}, picks the row of
    ``per_token_layer_activations_bf16`` whose ``layer_indices`` entry is 26,
    and writes it to /out/{rel} together with seq_len, hidden_size, model_name
    and metadata_json copied from the source file.

    Existing outputs are skipped. A missing input, or any exception while
    processing one file, is counted in ``n_errors`` and does not stop the batch.
    Outputs are written to a temporary file and renamed into place, so a failed
    or interrupted write cannot leave a partial file that later runs would skip.

    Returns:
        dict with n_done, n_skipped, n_errors and total_mb (written output size).
    """
    import os
    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            n_errors += 1
            continue
        # Temp name deliberately lacks the ".npz" suffix so directory walkers ignore it.
        tmp_path = f"{out_path}.tmp"
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                # ValueError here (layer 26 absent) is handled as a per-file error.
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # A file object stops np.savez from appending ".npz" to the temp name.
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    layer26_activations_bf16=l26,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    **passthrough,
                )
            os.replace(tmp_path, out_path)
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
        finally:
            # No-op after a successful rename; removes a partial temp file otherwise.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    embeddings_l26_lean_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_lean_slice(batch_size: int = 50) -> dict:
    """List lean npz files, batch them, fan out to slicer workers.

    Walks /in for ``.npz`` files, sorts the relative paths, splits them into
    chunks of `batch_size` and maps slice_l26_from_lean_batch over the chunks.

    A batch that fails as a whole is counted in ``n_batch_errors`` and every
    file of that batch is added to ``n_errors``, so the summary can no longer
    report zero errors after a failed batch.

    Returns:
        dict with files_total, n_done, n_skipped, n_errors, n_batch_errors and
        total_mb_out.
    """
    import os

    paths = []
    for root, _, files in os.walk("/in"):
        for fname in files:
            if fname.endswith(".npz"):
                paths.append(os.path.relpath(os.path.join(root, fname), "/in"))
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-slice] {len(paths)} lean npz → {len(batches)} batches")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    n_batch_errors = 0
    total_mb_out = 0.0
    # Results come back in input order, so index i identifies the batch that failed.
    for i, r in enumerate(slice_l26_from_lean_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            n_batch_errors += 1
            n_errors += len(batches[i])
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {
        "files_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "n_batch_errors": n_batch_errors,
        "total_mb_out": total_mb_out,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_lean_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=21600,
)
def upload_l26_lean_to_hf(repo_name: str = "mgnify-evo2-l26-full", private: bool = True) -> dict:
    """Upload the ``.npz`` files of the layer-26-lean volume to a HF dataset repo.

    The repo id is ``{authenticated user}/{repo_name}``; it is created if needed
    (``exist_ok=True``) and filled with ``HfApi.upload_large_folder``.

    Returns:
        dict with ``repo_id``, ``repo_url``, ``n_files``, ``bytes_total`` and
        ``elapsed_s`` (time spent inside ``upload_large_folder``).

    Raises:
        RuntimeError: if no Hugging Face token can be found in the environment.
    """
    import os
    import time
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Fallback: accept an env var whose NAME starts with "hf_" and whose value
        # is the literal "HF_TOKEN" — presumably a secret created with key and
        # value swapped (TODO confirm this is still needed).
        token = next(
            (
                name
                for name, value in os.environ.items()
                if name.startswith("hf_") and value == "HF_TOKEN"
            ),
            None,
        )
    if not token:
        raise RuntimeError("HF_TOKEN env var missing — check the 'huggingface' Modal Secret")

    api = HfApi(token=token)
    user = api.whoami().get("name")
    repo_id = f"{user}/{repo_name}"
    print(f"[hf] authenticated as: {user}")
    print(f"[hf] target repo: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf] repo ready: https://huggingface.co/datasets/{repo_id}")

    # Pre-scan the volume so the summary can report file count and total size.
    npz_sizes = [
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _subdirs, filenames in os.walk("/vol")
        for filename in filenames
        if filename.endswith(".npz")
    ]
    n_files = len(npz_sizes)
    bytes_total = sum(npz_sizes)
    print(f"[hf] uploading {n_files} files, {bytes_total/1e9:.2f} GB total")

    upload_kwargs = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "folder_path": "/vol",
        "allow_patterns": ["**/*.npz"],
        "print_report": True,
    }
    t0 = time.time()
    api.upload_large_folder(**upload_kwargs)
    elapsed = time.time() - t0

    return {
        "repo_id": repo_id,
        "repo_url": f"https://huggingface.co/datasets/{repo_id}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=120,
)
def set_hf_dataset_visibility(repo_name: str, private: bool) -> dict:
    """Set the ``private`` flag of the dataset ``{authenticated user}/{repo_name}``.

    Returns:
        dict with ``repo_id``, ``private`` (as reported by ``repo_info`` after
        the change, or None if that attribute is absent) and ``url``.

    Raises:
        RuntimeError: if no Hugging Face token can be found in the environment.
    """
    import os
    from huggingface_hub import HfApi

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Fallback: an env var NAMED "hf_..." whose value is the literal "HF_TOKEN".
        token = next(
            (
                name
                for name, value in os.environ.items()
                if name.startswith("hf_") and value == "HF_TOKEN"
            ),
            None,
        )
    if not token:
        raise RuntimeError("HF_TOKEN missing")

    api = HfApi(token=token)
    me = api.whoami()
    repo_id = f"{me['name']}/{repo_name}"

    # Use update_repo_settings when the installed HfApi provides it, otherwise
    # fall back to update_repo_visibility.
    change_visibility = getattr(api, "update_repo_settings", None)
    if change_visibility is None:
        change_visibility = api.update_repo_visibility
    change_visibility(repo_id=repo_id, repo_type="dataset", private=private)

    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    return {
        "repo_id": repo_id,
        "private": getattr(info, "private", None),
        "url": f"https://huggingface.co/datasets/{repo_id}",
    }
|
|
|
|
@app.local_entrypoint()
def make_l26_dataset_public(repo_name: str = "mgnify-evo2-l26-full"):
    """Call set_hf_dataset_visibility with private=False and print the resulting state."""
    r = set_hf_dataset_visibility.remote(repo_name=repo_name, private=False)
    report = (
        f" repo: {r['repo_id']}",
        f" private: {r['private']}",
        f" url: {r['url']}",
    )
    for line in report:
        print(line)
|
|
|
|
@app.local_entrypoint()
def push_l26_lean(repo_name: str = "mgnify-evo2-l26-full", batch_size: int = 50, private: bool = True):
    """Two-step pipeline: slice layer 26 out of the lean volume, then upload the slices to HF.

    modal run --detach modal/evo2_inference.py::push_l26_lean
    modal run --detach modal/evo2_inference.py::push_l26_lean --no-private
    """
    # Step 1: slicing, fanned out by orchestrate_l26_lean_slice.
    print("[1/2] slicing layer-26 from lean volume on Modal…")
    s = orchestrate_l26_lean_slice.remote(batch_size=batch_size)
    slice_report = [
        f" scanned {s['files_total']} lean npz",
        f" newly sliced: {s['n_done']}",
        f" skipped: {s['n_skipped']}",
        f" errors: {s['n_errors']}",
        f" l26 volume size: {s['total_mb_out']/1024:.2f} GB",
    ]
    for line in slice_report:
        print(line)

    # Step 2: upload; runs regardless of how many slice errors were reported.
    print("\n[2/2] pushing layer-26 volume to HF Datasets…")
    u = upload_l26_lean_to_hf.remote(repo_name=repo_name, private=private)
    upload_report = [
        "\n=== UPLOADED ===",
        f" repo: {u['repo_url']}",
        f" files: {u['n_files']}",
        f" size: {u['bytes_total']/1e9:.2f} GB",
        f" elapsed: {u['elapsed_s']:.0f} s ({u['bytes_total']/1e6/max(u['elapsed_s'],1):.1f} MB/s)",
    ]
    for line in upload_report:
        print(line)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_vfdb_lean(jsonl_rel_paths) -> dict:
    """VFDB-targeted version of embed_targeted_lean. Outputs under /embeddings_lean/vfdb/.

    Accepts either str (single JSONL) or list[str] (batch); paths are relative
    to the /jsonl mount and missing files are counted, not fatal. Each record
    must provide "is_positive", "region_id" and "sequence". Output goes to
    /embeddings_lean/vfdb/{VIRULENCE|negative}/{group}/{region}.npz, where
    {group} is the record's "mag_id", else its "species", else "UNKNOWN".
    Regions whose output file already exists are skipped. Single commit at end.

    Each .npz is written to a temporary file and renamed into place, so an
    interrupted write can never leave a truncated file that later runs would
    skip as "already done".

    Returns:
        dict with n_jsonls, n_missing_jsonl, n_done, n_skipped, total_mb,
        model_load_s, commit_s, per_region_s (list), seq_lens (list) and
        mean_per_region_s (None when nothing was embedded).
    """
    import json
    import os
    import time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]

    # Forward hooks drop each target block's output here, keyed by layer name.
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    seq_lens: list[int] = []
    handles = []
    try:
        # Register inside the try: the model is cached across calls, so hooks
        # from a partially failed registration must not stay attached to it.
        for ln in layer_names:
            handles.append(module_dict[ln].register_forward_hook(make_hook(ln)))

        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            if not records:
                continue

            for rec in records:
                label_folder = "VIRULENCE" if rec["is_positive"] else "negative"
                group = rec.get("mag_id") or rec.get("species") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/vfdb/{label_folder}/{group}"
                out_path = f"{out_dir}/{region_id}.npz"

                if os.path.exists(out_path):
                    n_skipped += 1
                    continue
                os.makedirs(out_dir, exist_ok=True)

                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size

                # Move layers to the CPU one at a time, releasing GPU memory as we go.
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                    torch.cuda.empty_cache()
                # NumPy has no bfloat16 dtype: store the raw 16-bit patterns as uint16.
                stack_uint16 = stack.view(torch.uint16).numpy()

                meta = {k: v for k, v in rec.items() if k != "sequence"}
                # Write to a temp name that does not end in ".npz", then rename.
                # Passing a file object stops np.savez from appending ".npz".
                tmp_path = f"{out_path}.tmp"
                try:
                    with open(tmp_path, "wb") as fh:
                        np.savez(
                            fh,
                            per_token_layer_activations_bf16=stack_uint16,
                            per_token_layer_activations_dtype="bfloat16",
                            layer_names=np.array(layer_names),
                            layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                            seq_len=np.int32(seq_len),
                            hidden_size=np.int32(hidden),
                            model_name="evo2_7b_262k",
                            metadata_json=np.array(json.dumps(meta)),
                        )
                    os.replace(tmp_path, out_path)
                finally:
                    # No-op after a successful rename; removes the partial file otherwise.
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                seq_lens.append(seq_len)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start

    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        "per_region_s": per_region_times,
        "seq_lens": seq_lens,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
|
|
|
|
@app.local_entrypoint()
def pilot_vfdb_lean(target_records: int = 250, batch_size: int = 2):
    """Run a small VFDB pilot to measure per-region time + commit time, then project full cost.

    Species JSONLs are ordered by record count (smallest first) and taken until
    the running total reaches `target_records`; once something has been chosen,
    files with more than twice `target_records` records are skipped. The chosen
    files are uploaded to the JSONL volume under vfdb/ and embedded with
    embed_vfdb_lean in batches of `batch_size`.

    Failed batches are reported and excluded from the statistics.

    modal run modal/evo2_inference.py::pilot_vfdb_lean
    modal run modal/evo2_inference.py::pilot_vfdb_lean --target-records 500 --batch-size 4

    Raises:
        FileNotFoundError: if the local VFDB directory is missing or holds no .jsonl file.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"VFDB JSONLs not found at {base}; run "
                                "extract_vfdb_negatives.py + the modal-prep step first.")

    # (record count, file name) pairs, sorted so the smallest species come first.
    species_files = []
    for fname in sorted(os.listdir(base)):
        if not fname.endswith(".jsonl"):
            continue
        with open(os.path.join(base, fname)) as f:
            n = sum(1 for line in f if line.strip())
        species_files.append((n, fname))
    species_files.sort()
    if not species_files:
        # Without this guard the species_files[0] fallback below raises IndexError.
        raise FileNotFoundError(f"no .jsonl files found in {base}")

    chosen = []
    total = 0
    for n, fname in species_files:
        if total >= target_records:
            break
        if n > target_records * 2 and chosen:
            continue
        chosen.append((n, fname))
        total += n
    if not chosen:
        # e.g. target_records <= 0: still run the pilot on the smallest file.
        chosen = [species_files[0]]

    print(f"[pilot-vfdb] selected {len(chosen)} species:")
    for n, fname in chosen:
        print(f" {fname:40s} {n} records")
    total_records = sum(n for n, _ in chosen)
    print(f" total pilot records: {total_records}")

    print(f"\n[pilot-vfdb] uploading {len(chosen)} JSONLs to volume mgnify-targeted-jsonl ...")
    upload_t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for _, fname in chosen:
            batch.put_file(os.path.join(base, fname), f"vfdb/{fname}")
    print(f" uploaded in {time.time()-upload_t0:.0f} s")

    rel_paths = [f"vfdb/{fname}" for _, fname in chosen]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[pilot-vfdb] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of up to {batch_size}")

    t0 = time.time()
    results = list(embed_vfdb_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    ok = [r for r in results if not isinstance(r, Exception)]
    failed = [r for r in results if isinstance(r, Exception)]
    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    region_times = [t for r in ok for t in r.get("per_region_s") or []]
    seq_lens = [s for r in ok for s in r.get("seq_lens") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]

    if failed:
        print(f"[pilot-vfdb] failed batches: {len(failed)}/{len(results)} (first error: {failed[0]})")
    if not region_times:
        print("ERROR: no records processed in pilot")
        return

    mean_region = sum(region_times) / len(region_times)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)
    mean_seqlen = sum(seq_lens) / max(len(seq_lens), 1)
    p95_region = sorted(region_times)[int(len(region_times) * 0.95)]

    print("\n=== VFDB PILOT RESULTS ===")
    print(f" records processed: {n_done}")
    print(f" total output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.2f} MB/record)")
    print(f" wall clock (parallel): {wall:.0f} s across {len(batches)} batch(es)")
    print(f" per-region inference: {mean_region:.2f} s avg "
          f"(min {min(region_times):.2f}, max {max(region_times):.2f}, p95 {p95_region:.2f})")
    print(f" per-region seq len: {mean_seqlen:.0f} bp avg")
    if commit_times:
        print(f" per-batch commit: {mean_commit:.2f} s avg "
              f"(min {min(commit_times):.2f}, max {max(commit_times):.2f})")
    print(f" per-call model load: {mean_load:.1f} s avg")

    # Projection to the full VFDB run.
    full_records = 14695
    n_workers = 16  # matches max_containers on embed_vfdb_lean
    h100_rate = 4.50
    full_n_jsonls = 34
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size
    # No more workers than batches can ever run concurrently.
    active_workers = max(min(n_workers, full_batches), 1)
    inference_compute_s = full_records * mean_region
    commit_compute_s = full_batches * mean_commit
    cold_start_s = mean_load * active_workers
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / active_workers
    cost = (total_compute_s / 3600) * h100_rate
    output_size_gb = (total_mb / max(n_done, 1)) * full_records / 1024

    print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, "
          f"H100=${h100_rate:.2f}/hr):")
    print(f" inference compute: {inference_compute_s:7.0f} s ({inference_compute_s/total_compute_s*100:.0f}%)")
    print(f" commit compute: {commit_compute_s:7.0f} s ({commit_compute_s/total_compute_s*100:.0f}%)")
    print(f" cold-start total: {cold_start_s:7.0f} s ({cold_start_s/total_compute_s*100:.0f}%)")
    print(f" estimated wall clock: {wall_proj/60:5.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {output_size_gb:.1f} GB")
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_vfdb_lean(batch_size: int = 2) -> dict:
    """Discover VFDB JSONLs under /jsonl/vfdb and fan them out to embed_vfdb_lean.

    Paths are collected relative to /jsonl, sorted, and grouped into lists of
    at most `batch_size`; each list is one embed_vfdb_lean call. A progress
    line is printed after every successful batch, and an error line for every
    batch that came back as an exception.

    Returns:
        dict with keys "jsonls", "batches", "regions_done", "regions_skipped",
        "errors" (number of failed batches), "total_mb", "mean_region_s" and
        "mean_commit_s" (both means are 0 when no samples were collected).
    """
    import os

    def _avg(samples):
        # max(..., 1) keeps the empty case at 0 instead of dividing by zero.
        return sum(samples) / max(len(samples), 1)

    jsonl_paths = sorted(
        os.path.relpath(os.path.join(dirpath, name), "/jsonl")
        for dirpath, _, names in os.walk("/jsonl/vfdb")
        for name in names
        if name.endswith(".jsonl")
    )
    batches = [
        jsonl_paths[start:start + batch_size]
        for start in range(0, len(jsonl_paths), batch_size)
    ]
    n_batches = len(batches)
    print(f"[orchestrator-vfdb] {len(jsonl_paths)} JSONLs → {n_batches} batches of up to {batch_size}")

    regions_done = 0
    regions_skipped = 0
    failed_batches = 0
    written_mb = 0.0
    region_time_samples: list[float] = []
    commit_time_samples: list[float] = []

    outcomes = embed_vfdb_lean.map(batches, return_exceptions=True)
    for position, outcome in enumerate(outcomes, start=1):
        if isinstance(outcome, Exception):
            failed_batches += 1
            print(f" [{position}/{n_batches}] BATCH ERROR: {outcome}")
            continue

        regions_done += outcome.get("n_done", 0)
        regions_skipped += outcome.get("n_skipped", 0)
        written_mb += outcome.get("total_mb", 0.0)
        region_time_samples.extend(outcome.get("per_region_s") or [])
        commit_s = outcome.get("commit_s")
        if commit_s is not None:
            commit_time_samples.append(commit_s)

        print(f" [{position}/{n_batches}] done={regions_done} skipped={regions_skipped} errors={failed_batches} "
              f"{written_mb/1024:.1f} GB mean_region_s={_avg(region_time_samples):.2f} "
              f"mean_commit_s={_avg(commit_time_samples):.2f}")

    return {
        "jsonls": len(jsonl_paths),
        "batches": n_batches,
        "regions_done": regions_done,
        "regions_skipped": regions_skipped,
        "errors": failed_batches,
        "total_mb": written_mb,
        "mean_region_s": _avg(region_time_samples),
        "mean_commit_s": _avg(commit_time_samples),
    }
|
|
|
|
@app.local_entrypoint()
def run_vfdb_lean(batch_size: int = 2):
    """Upload every local VFDB JSONL to the JSONL volume, then run the lean pipeline.

        modal run --detach modal/evo2_inference.py::run_vfdb_lean
        modal run --detach modal/evo2_inference.py::run_vfdb_lean --batch-size 4

    Args:
        batch_size: forwarded to orchestrate_vfdb_lean (JSONLs per embed call).

    Raises:
        FileNotFoundError: if the hard-coded local JSONL directory is missing.
    """
    import os, time

    local_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(f"VFDB JSONLs not found at {local_dir}")

    names = [entry for entry in os.listdir(local_dir) if entry.endswith(".jsonl")]
    names.sort()
    print(f"[run-vfdb] uploading {len(names)} JSONLs to mgnify-targeted-jsonl volume ...")
    upload_started = time.time()
    # force=True overwrites files that already exist on the volume.
    with jsonl_vol.batch_upload(force=True) as uploader:
        for name in names:
            uploader.put_file(os.path.join(local_dir, name), f"vfdb/{name}")
    print(f" uploaded in {time.time()-upload_started:.0f} s")

    print(f"\n[run-vfdb] launching orchestrator (batch_size={batch_size})")
    summary = orchestrate_vfdb_lean.remote(batch_size=batch_size)
    print("\n=== DONE ===")
    print(f" JSONL files: {summary['jsonls']} in {summary['batches']} batches")
    print(f" regions saved: {summary['regions_done']}")
    print(f" regions skipped: {summary['regions_skipped']} (already on volume)")
    print(f" errors: {summary['errors']}")
    print(f" total volume size: {summary['total_mb']/1024:.1f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
    print(f" mean per-batch commit: {summary['mean_commit_s']:.2f} s")
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Volume holding the layer-26-only VFDB slices; created on first use if absent.
embeddings_l26_vfdb_vol = modal.Volume.from_name("mgnify-embeddings-l26-vfdb", create_if_missing=True)
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_vfdb_vol,
    },
    timeout=3600,
    max_containers=16,
)
def slice_l26_vfdb_batch(rel_paths: list[str]) -> dict:
    """Slice layer 26 from each VFDB lean npz. Same schema as slice_l26_from_lean_batch.

    For every relative path, reads /in/<rel>, picks the row of
    `per_token_layer_activations_bf16` whose entry in `layer_indices` is 26,
    and writes it with the pass-through fields to /out/<rel>. Outputs that
    already exist are skipped, so the function can be re-run after a partial
    failure.

    Args:
        rel_paths: npz paths relative to the /in (and /out) mount.

    Returns:
        dict with "n_done", "n_skipped", "n_errors" and "total_mb" (size of
        the files written by this call, in units of 1e6 bytes).
    """
    import os
    import numpy as np

    n_done = 0
    n_skipped = 0
    n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            # Previously counted silently; log it so the failure is traceable.
            print(f" ERROR on {rel}: input not found at {in_path}")
            n_errors += 1
            continue
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                # ValueError here (layer 26 absent) is handled like any other per-file error.
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            np.savez(
                out_path,
                layer26_activations_bf16=l26,
                layer26_dtype="bfloat16",
                source_layer_index=np.int32(26),
                source_layer_name="blocks-26",
                **passthrough,
            )
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # out_path did not exist before this attempt, so anything there now
            # is a partial write. Remove it; otherwise the exists() check above
            # would treat the broken file as done on the next run.
            if os.path.exists(out_path):
                try:
                    os.remove(out_path)
                except OSError as cleanup_err:
                    print(f" WARNING: could not remove partial output {out_path}: {cleanup_err}")
    embeddings_l26_vfdb_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_vfdb_slice(batch_size: int = 100) -> dict:
    """Walk only /in/vfdb/, batch, fan out to slice_l26_vfdb_batch.

    Args:
        batch_size: number of npz paths handed to each slice_l26_vfdb_batch call.

    Returns:
        dict with "files_total", "n_done", "n_skipped", "n_errors",
        "total_mb_out" and "batch_errors". "n_errors" includes every file of a
        batch whose call failed outright; "batch_errors" counts those batches.
    """
    import os
    paths = []
    for root, _, files in os.walk("/in/vfdb"):
        for fname in files:
            if fname.endswith(".npz"):
                rel = os.path.relpath(os.path.join(root, fname), "/in")
                paths.append(rel)
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-vfdb] {len(paths)} lean npz → {len(batches)} batches of {batch_size}")

    n_done = 0
    n_skipped = 0
    n_errors = 0
    batch_errors = 0
    total_mb_out = 0.0
    for i, r in enumerate(slice_l26_vfdb_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            # A failed call processed none of its files. Count them so callers
            # printing n_errors do not report a clean run. Results are assumed
            # to arrive in input order (map's default), so batches[i] is the
            # batch that failed.
            batch_errors += 1
            n_errors += len(batches[i])
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
        else:
            n_done += r.get("n_done", 0)
            n_skipped += r.get("n_skipped", 0)
            n_errors += r.get("n_errors", 0)
            total_mb_out += r.get("total_mb", 0.0)
        # Report progress even when the current batch failed, so the final
        # summary line is never skipped.
        if (i + 1) % 5 == 0 or (i + 1) == len(batches):
            print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {
        "files_total": len(paths),
        "n_done": n_done,
        "n_skipped": n_skipped,
        "n_errors": n_errors,
        "total_mb_out": total_mb_out,
        "batch_errors": batch_errors,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=8,
    volumes={"/vol": embeddings_l26_vfdb_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=86400,
)
def upload_l26_vfdb_to_hf(repo_name: str = "mgnify-evo2-l26-vfdb-virulence", private: bool = False,
                          num_workers: int = 8) -> dict:
    """Push the VFDB layer-26 slice to HF Datasets. Resumes if partial.

    Args:
        repo_name: dataset name; the repo id is "<authenticated user>/<repo_name>".
        private: visibility passed to create_repo.
        num_workers: passed through to upload_large_folder.

    Returns:
        dict with "repo_url", "n_files" and "bytes_total" (count and size of
        the .npz files under /vol), "elapsed_s" (upload duration) and "private".

    Raises:
        RuntimeError: if no Hugging Face token can be found in the environment.
    """
    import os
    import time
    from huggingface_hub import HfApi, login

    # Token discovery is deliberately lenient: accept any hf_/HF_-prefixed
    # environment entry whose NAME or VALUE looks like a token ("hf_" prefix
    # and longer than 30 characters), then fall back to HF_TOKEN.
    token = None
    for k, v in os.environ.items():
        if not k.startswith(("hf_", "HF_")):
            continue
        if k.startswith("hf_") and len(k) > 30:
            token = k
            break
        if v.startswith("hf_") and len(v) > 30:
            token = v
            break
    if not token:
        token = os.environ.get("HF_TOKEN")
    if not token:
        # login(token=None) would fall back to an interactive prompt, which
        # cannot be answered inside a remote container. Fail with a clear error.
        raise RuntimeError(
            "No Hugging Face token found in the environment; "
            "check the 'huggingface' Modal secret (expected HF_TOKEN=hf_...)."
        )
    login(token=token)

    api = HfApi()
    user = api.whoami()["name"]
    full_repo = f"{user}/{repo_name}"
    api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf-push-vfdb] uploading /vol → {full_repo} (private={private}, workers={num_workers})")

    t0 = time.time()
    api.upload_large_folder(
        folder_path="/vol",
        repo_id=full_repo,
        repo_type="dataset",
        num_workers=num_workers,
    )
    elapsed = time.time() - t0

    # Report what is on the volume (only .npz files are counted).
    n_files = 0
    bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for fname in files:
            if fname.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, fname))

    return {
        "repo_url": f"https://huggingface.co/datasets/{full_repo}",
        "n_files": n_files,
        "bytes_total": bytes_total,
        "elapsed_s": elapsed,
        "private": private,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim(),
    cpu=1,
    volumes={
        "/embeddings_lean": embeddings_lean_vol,
        "/l26_vfdb": embeddings_l26_vfdb_vol,
    },
    timeout=3600,
)
def wipe_vfdb_outputs() -> dict:
    """Empty /embeddings_lean/vfdb/ and /l26_vfdb/ ahead of a clean re-run.

    Needed after the positive region_id switched from source_accession to
    vfg_id: files written under the old ids would otherwise leak into the
    slice step and the HF push.

    Returns:
        dict keyed by "lean_vfdb" / "l26_vfdb". Each value is
        {"existed": False} when the directory was absent, otherwise
        {"existed": True, "files_removed": int, "elapsed_s": float}.
    """
    import os
    import shutil
    import time

    targets = (
        ("/embeddings_lean/vfdb", "lean_vfdb", embeddings_lean_vol),
        ("/l26_vfdb", "l26_vfdb", embeddings_l26_vfdb_vol),
    )
    report = {}
    for root_dir, key, volume in targets:
        started = time.time()
        if not os.path.exists(root_dir):
            report[key] = {"existed": False}
        else:
            file_count = 0
            for _, _, names in os.walk(root_dir):
                file_count += len(names)
            # Delete the children rather than root_dir itself; for /l26_vfdb
            # that path is the volume's mount point.
            for child in os.listdir(root_dir):
                child_path = os.path.join(root_dir, child)
                if os.path.isdir(child_path):
                    shutil.rmtree(child_path)
                else:
                    os.remove(child_path)
            report[key] = {
                "existed": True,
                "files_removed": file_count,
                "elapsed_s": time.time() - started,
            }
        volume.commit()
    return report
|
|
|
|
@app.local_entrypoint()
def run_vfdb_full(repo_name: str = "mgnify-evo2-l26-vfdb-virulence",
                  private: bool = False,
                  embed_batch_size: int = 2,
                  slice_batch_size: int = 100,
                  wipe_first: bool = True):
    """End-to-end VFDB run: optional wipe, JSONL upload, embed, layer-26 slice, HF push.

    Pass --no-wipe-first to keep existing outputs (e.g. when re-running after
    an interrupted push).

        modal run --detach modal/evo2_inference.py::run_vfdb_full

    Raises:
        FileNotFoundError: if the hard-coded local JSONL directory is missing.
    """
    import os, time

    # Step 0: clear stale outputs.
    if wipe_first:
        print("[0/4] wiping stale VFDB outputs (vfdb/ and l26_vfdb/) ...")
        wiped = wipe_vfdb_outputs.remote()
        for label, info in wiped.items():
            if not info.get("existed"):
                print(f" {label}: nothing to remove")
            else:
                print(f" {label}: removed {info['files_removed']} files in {info['elapsed_s']:.0f} s")

    # Step 1: upload the local JSONLs.
    local_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/vfdb_modal_ready"
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(f"VFDB JSONLs not found at {local_dir}")
    names = [entry for entry in os.listdir(local_dir) if entry.endswith(".jsonl")]
    names.sort()
    print(f"\n[1/4] uploading {len(names)} JSONLs to mgnify-targeted-jsonl ...")
    upload_started = time.time()
    with jsonl_vol.batch_upload(force=True) as uploader:
        for name in names:
            uploader.put_file(os.path.join(local_dir, name), f"vfdb/{name}")
    print(f" uploaded in {time.time()-upload_started:.0f} s")

    # Step 2: lean embedding.
    print(f"\n[2/4] running embed_vfdb_lean (batch_size={embed_batch_size}) ...")
    embed_summary = orchestrate_vfdb_lean.remote(batch_size=embed_batch_size)
    print(f" regions saved: {embed_summary['regions_done']} "
          f"(skipped {embed_summary['regions_skipped']}, errors {embed_summary['errors']})")
    print(f" total volume size: {embed_summary['total_mb']/1024:.1f} GB")

    # Step 3: slice layer 26 out of the lean npz files.
    print(f"\n[3/4] slicing layer 26 (batch_size={slice_batch_size}) ...")
    slice_summary = orchestrate_l26_vfdb_slice.remote(batch_size=slice_batch_size)
    print(f" files: {slice_summary['files_total']}, done: {slice_summary['n_done']}, "
          f"skipped: {slice_summary['n_skipped']}, errors: {slice_summary['n_errors']}")
    print(f" l26 vol size: {slice_summary['total_mb_out']/1024:.2f} GB")

    # Step 4: push the slice volume to Hugging Face.
    print(f"\n[4/4] pushing to HF as {repo_name} (private={private}) ...")
    push = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private)
    mb_per_s = push['bytes_total'] / 1e6 / max(push['elapsed_s'], 1)
    print(f"\n=== ALL DONE ===")
    print(f" HF repo: {push['repo_url']}")
    print(f" files: {push['n_files']}")
    print(f" size: {push['bytes_total']/1e9:.2f} GB")
    print(f" upload: {push['elapsed_s']:.0f} s ({mb_per_s:.1f} MB/s)")
|
|
|
|
| |
| |
| |
| |
| |
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
    max_containers=16,
)
def embed_qual_lean(jsonl_rel_paths) -> dict:
    """Mirror of embed_vfdb_lean for the qualitative sample. Outputs at
    /embeddings_lean/qual/<label_group>/<category_slug>/<region_id>.npz.

    Args:
        jsonl_rel_paths: one path or a list of paths, relative to /jsonl.

    Returns:
        dict with "n_jsonls", "n_missing_jsonl", "n_done", "n_skipped",
        "total_mb", "model_load_s", "commit_s", "per_region_s" (list of
        per-record wall times, matching the sibling embed functions) and
        "mean_per_region_s" (None when nothing was processed).
    """
    import json, os, time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    t_load_start = time.time()
    evo2, device, module_dict = _get_evo2_only()
    t_load = time.time() - t_load_start
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]

    # Forward hooks stash each hooked layer's output, keyed by layer name.
    cache: dict = {}

    def make_hook(name):
        def hook(module, inp, out):
            cache[name] = (out[0] if isinstance(out, tuple) else out).detach()
        return hook

    handles = [module_dict[ln].register_forward_hook(make_hook(ln)) for ln in layer_names]

    n_done = n_skipped = n_missing_jsonl = 0
    total_mb = 0.0
    per_region_times: list[float] = []
    try:
        for jsonl_rel in jsonl_rel_paths:
            src_path = f"/jsonl/{jsonl_rel}"
            if not os.path.exists(src_path):
                n_missing_jsonl += 1
                continue
            with open(src_path) as f:
                records = [json.loads(line) for line in f if line.strip()]
            for rec in records:
                group = rec.get("label_group") or "UNKNOWN"
                slug = rec.get("mag_id") or "UNKNOWN"
                region_id = rec["region_id"]
                out_dir = f"/embeddings_lean/qual/{group}/{slug}"
                os.makedirs(out_dir, exist_ok=True)
                out_path = f"{out_dir}/{region_id}.npz"
                if os.path.exists(out_path):
                    # Already embedded on a previous run.
                    n_skipped += 1
                    continue
                t_region = time.time()
                seq = rec["sequence"]
                cache.clear()
                input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)
                seq_len = cache[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stack = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for i, ln in enumerate(layer_names):
                    stack[i] = cache[ln][0].to(torch.bfloat16).cpu()
                    del cache[ln]
                torch.cuda.empty_cache()
                # The bf16 payload is reinterpreted as uint16 before handing it
                # to numpy; the dtype name is stored alongside it.
                stack_uint16 = stack.view(torch.uint16).numpy()
                meta = {k: v for k, v in rec.items() if k != "sequence"}
                np.savez(
                    out_path,
                    per_token_layer_activations_bf16=stack_uint16,
                    per_token_layer_activations_dtype="bfloat16",
                    layer_names=np.array(layer_names),
                    layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(meta)),
                )
                total_mb += os.path.getsize(out_path) / 1e6
                n_done += 1
                per_region_times.append(time.time() - t_region)
                del stack, stack_uint16, input_ids
                torch.cuda.empty_cache()
    finally:
        # Always detach the hooks, even when a record fails.
        for h in handles:
            h.remove()
        cache.clear()
        torch.cuda.empty_cache()

    t_commit_start = time.time()
    embeddings_lean_vol.commit()
    t_commit = time.time() - t_commit_start
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": t_load,
        "commit_s": t_commit,
        # Added for parity with embed_syngenome_lean / embed_syngenome_l26,
        # whose callers aggregate r.get("per_region_s").
        "per_region_s": per_region_times,
        "mean_per_region_s": (sum(per_region_times) / len(per_region_times)) if per_region_times else None,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim(),
    cpu=1,
    volumes={"/embeddings_lean": embeddings_lean_vol},
    timeout=1800,
)
def rename_qual_to_small() -> dict:
    """Move /embeddings_lean/qual to /embeddings_lean/small and commit the volume.

    Returns:
        {"renamed": False, "reason": ...} when the source is missing or the
        destination already exists; otherwise {"renamed": True, "src", "dst",
        "files_moved", "elapsed_s"}.
    """
    import os, shutil, time

    src = "/embeddings_lean/qual"
    dst = "/embeddings_lean/small"

    # Refuse to act unless the source exists and the destination is free.
    refusal = None
    if not os.path.exists(src):
        refusal = f"{src} does not exist"
    elif os.path.exists(dst):
        refusal = f"{dst} already exists"
    if refusal is not None:
        return {"renamed": False, "reason": refusal}

    files_moved = 0
    for _, _, names in os.walk(src):
        files_moved += len(names)

    started = time.time()
    shutil.move(src, dst)
    embeddings_lean_vol.commit()
    return {
        "renamed": True,
        "src": src,
        "dst": dst,
        "files_moved": files_moved,
        "elapsed_s": time.time() - started,
    }
|
|
|
|
@app.local_entrypoint()
def rename_qual_small():
    """Local entrypoint: run rename_qual_to_small remotely and print its result dict."""
    print(rename_qual_to_small.remote())
|
|
|
|
| |
| |
| |
| |
| |
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_lean": embeddings_lean_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_syngenome_lean(jsonl_rel_paths) -> dict:
    """Embed SynGenome records at the LEAN_LAYERS blocks (SynGenome twin of embed_vfdb_lean).

    Every record is written under /embeddings_lean/syngenome/AMR/<mag_id>/,
    i.e. all records are treated as AMR positives. Records whose output file
    already exists are skipped.

    Args:
        jsonl_rel_paths: one path or a list of paths, relative to /jsonl.

    Returns:
        dict with "n_jsonls", "n_missing_jsonl", "n_done", "n_skipped",
        "total_mb", "model_load_s", "commit_s", "per_region_s" and
        "mean_per_region_s" (None when nothing was processed).
    """
    import json, os, time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    load_started = time.time()
    evo2, device, module_dict = _get_evo2_only()
    model_load_s = time.time() - load_started
    layer_names = [f"blocks-{i}" for i in LEAN_LAYERS]

    # One forward hook per layer; each stores its layer's output under the layer name.
    captured: dict = {}
    hook_handles = []
    for layer in layer_names:
        def _capture(module, inp, out, _layer=layer):
            captured[_layer] = (out[0] if isinstance(out, tuple) else out).detach()
        hook_handles.append(module_dict[layer].register_forward_hook(_capture))

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    region_seconds: list[float] = []
    try:
        for rel_path in jsonl_rel_paths:
            jsonl_file = f"/jsonl/{rel_path}"
            if not os.path.exists(jsonl_file):
                n_missing_jsonl += 1
                continue
            with open(jsonl_file) as fh:
                records = [json.loads(raw) for raw in fh if raw.strip()]
            for record in records:
                drug_class = record.get("mag_id") or "UNKNOWN"
                region_id = record["region_id"]
                target_dir = f"/embeddings_lean/syngenome/AMR/{drug_class}"
                os.makedirs(target_dir, exist_ok=True)
                target = f"{target_dir}/{region_id}.npz"
                if os.path.exists(target):
                    n_skipped += 1
                    continue

                region_started = time.time()
                sequence = record["sequence"]
                captured.clear()
                token_ids = evo2.tokenizer.tokenize(sequence)
                input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)

                seq_len = captured[layer_names[0]].shape[1]
                hidden = evo2.model.config.hidden_size
                stacked = torch.zeros(len(layer_names), seq_len, hidden, dtype=torch.bfloat16)
                for row, layer in enumerate(layer_names):
                    stacked[row] = captured[layer][0].to(torch.bfloat16).cpu()
                    del captured[layer]
                torch.cuda.empty_cache()

                # bf16 is reinterpreted as uint16 before conversion to numpy.
                stacked_u16 = stacked.view(torch.uint16).numpy()
                metadata = {key: value for key, value in record.items() if key != "sequence"}
                np.savez(
                    target,
                    per_token_layer_activations_bf16=stacked_u16,
                    per_token_layer_activations_dtype="bfloat16",
                    layer_names=np.array(layer_names),
                    layer_indices=np.array(LEAN_LAYERS, dtype=np.int32),
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(metadata)),
                )
                total_mb += os.path.getsize(target) / 1e6
                n_done += 1
                region_seconds.append(time.time() - region_started)
                del stacked, stacked_u16, input_ids
                torch.cuda.empty_cache()
    finally:
        for handle in hook_handles:
            handle.remove()
        captured.clear()
        torch.cuda.empty_cache()

    commit_started = time.time()
    embeddings_lean_vol.commit()
    commit_s = time.time() - commit_started

    mean_region = (sum(region_seconds) / len(region_seconds)) if region_seconds else None
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": model_load_s,
        "commit_s": commit_s,
        "per_region_s": region_seconds,
        "mean_per_region_s": mean_region,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_lean(batch_size: int = 2) -> dict:
    """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_lean.

    Args:
        batch_size: number of JSONL paths handed to each embed call.

    Returns:
        dict with "jsonls", "batches", "regions_done", "regions_skipped",
        "errors" (failed batches), "total_mb", "mean_region_s" and
        "mean_commit_s" (means are 0 when no samples were collected).
    """
    import os
    paths = []
    for root, _, files in os.walk("/jsonl/syngenome"):
        for fname in files:
            if fname.endswith(".jsonl"):
                rel = os.path.relpath(os.path.join(root, fname), "/jsonl")
                paths.append(rel)
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-syngenome] {len(paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_times: list[float] = []
    commit_times: list[float] = []
    for i, r in enumerate(embed_syngenome_lean.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_times.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_times.append(r["commit_s"])
        mean_t = sum(region_times) / max(len(region_times), 1)
        print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.1f} GB mean_region_s={mean_t:.2f}")
    return {
        "jsonls": len(paths),
        "batches": len(batches),
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": sum(region_times) / max(len(region_times), 1),
        # commit_times was collected but never reported; expose it the same
        # way orchestrate_vfdb_lean does.
        "mean_commit_s": sum(commit_times) / max(len(commit_times), 1),
    }
|
|
|
|
| |
| |
# Volume that embed_syngenome_l26 writes to and commits; created on first use if absent.
embeddings_l26_syngenome_vol = modal.Volume.from_name("mgnify-embeddings-l26-syngenome", create_if_missing=True)
|
|
|
|
@app.function(
    image=image,
    gpu="H100",
    volumes={
        "/root/.cache/huggingface": weights_vol,
        "/embeddings_l26_syn": embeddings_l26_syngenome_vol,
        "/jsonl": jsonl_vol,
    },
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=7200,
    max_containers=16,
)
def embed_syngenome_l26(jsonl_rel_paths) -> dict:
    """Embed SynGenome records, hooking and saving only blocks-26.

    Same forward pass as the lean variant, but a single layer is captured, so
    each output file holds one layer instead of the full lean stack. Outputs go
    to /embeddings_l26_syn/<AMR|negative>/<mag_id>/<region_id>.npz; a record
    counts as "AMR" unless its "is_positive" field is present and falsy.

    Args:
        jsonl_rel_paths: one path or a list of paths, relative to /jsonl.

    Returns:
        dict with "n_jsonls", "n_missing_jsonl", "n_done", "n_skipped",
        "total_mb", "model_load_s", "commit_s", "per_region_s" and
        "mean_per_region_s" (None when nothing was processed).
    """
    import json, os, time
    import numpy as np
    import torch

    if isinstance(jsonl_rel_paths, str):
        jsonl_rel_paths = [jsonl_rel_paths]

    load_started = time.time()
    evo2, device, module_dict = _get_evo2_only()
    model_load_s = time.time() - load_started
    layer_name = "blocks-26"

    # Single forward hook: keeps the latest blocks-26 output.
    captured: dict = {}

    def _capture(module, inp, out):
        tensor = out[0] if isinstance(out, tuple) else out
        captured[layer_name] = tensor.detach()

    hook_handle = module_dict[layer_name].register_forward_hook(_capture)

    n_done = 0
    n_skipped = 0
    n_missing_jsonl = 0
    total_mb = 0.0
    region_seconds: list[float] = []
    try:
        for rel_path in jsonl_rel_paths:
            jsonl_file = f"/jsonl/{rel_path}"
            if not os.path.exists(jsonl_file):
                n_missing_jsonl += 1
                continue
            with open(jsonl_file) as fh:
                records = [json.loads(raw) for raw in fh if raw.strip()]
            for record in records:
                # Records without an explicit is_positive flag default to positives.
                if record.get("is_positive", True):
                    top = "AMR"
                else:
                    top = "negative"
                group = record.get("mag_id") or "UNKNOWN"
                region_id = record["region_id"]
                target_dir = f"/embeddings_l26_syn/{top}/{group}"
                os.makedirs(target_dir, exist_ok=True)
                target = f"{target_dir}/{region_id}.npz"
                if os.path.exists(target):
                    n_skipped += 1
                    continue

                region_started = time.time()
                sequence = record["sequence"]
                captured.clear()
                token_ids = evo2.tokenizer.tokenize(sequence)
                input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
                with torch.no_grad():
                    evo2.model(input_ids)

                seq_len = captured[layer_name].shape[1]
                hidden = evo2.model.config.hidden_size
                layer26 = captured[layer_name][0].to(torch.bfloat16).cpu()
                del captured[layer_name]
                torch.cuda.empty_cache()

                # bf16 is reinterpreted as uint16 before conversion to numpy.
                layer26_u16 = layer26.view(torch.uint16).numpy()
                metadata = {key: value for key, value in record.items() if key != "sequence"}
                np.savez(
                    target,
                    layer26_activations_bf16=layer26_u16,
                    layer26_dtype="bfloat16",
                    source_layer_index=np.int32(26),
                    source_layer_name="blocks-26",
                    seq_len=np.int32(seq_len),
                    hidden_size=np.int32(hidden),
                    model_name="evo2_7b_262k",
                    metadata_json=np.array(json.dumps(metadata)),
                )
                total_mb += os.path.getsize(target) / 1e6
                n_done += 1
                region_seconds.append(time.time() - region_started)
                del layer26, layer26_u16, input_ids
                torch.cuda.empty_cache()
    finally:
        hook_handle.remove()
        captured.clear()
        torch.cuda.empty_cache()

    commit_started = time.time()
    embeddings_l26_syngenome_vol.commit()
    commit_s = time.time() - commit_started

    mean_region = (sum(region_seconds) / len(region_seconds)) if region_seconds else None
    return {
        "n_jsonls": len(jsonl_rel_paths),
        "n_missing_jsonl": n_missing_jsonl,
        "n_done": n_done,
        "n_skipped": n_skipped,
        "total_mb": total_mb,
        "model_load_s": model_load_s,
        "commit_s": commit_s,
        "per_region_s": region_seconds,
        "mean_per_region_s": mean_region,
    }
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_l26(batch_size: int = 2) -> dict:
    """Walks /jsonl/syngenome/, batches, fans out to embed_syngenome_l26.

    Args:
        batch_size: number of JSONL paths handed to each embed call.

    Returns:
        dict with "jsonls", "batches", "regions_done", "regions_skipped",
        "errors" (failed batches), "total_mb", "mean_region_s" and
        "mean_commit_s" (means are 0 when no samples were collected).
    """
    import os
    paths = []
    for root, _, files in os.walk("/jsonl/syngenome"):
        for fname in files:
            if fname.endswith(".jsonl"):
                rel = os.path.relpath(os.path.join(root, fname), "/jsonl")
                paths.append(rel)
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-syngenome-l26] {len(paths)} JSONLs → {len(batches)} batches of up to {batch_size}")

    n_done = n_skipped = errors = 0
    total_mb = 0.0
    region_times: list[float] = []
    commit_times: list[float] = []
    for i, r in enumerate(embed_syngenome_l26.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            errors += 1
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        total_mb += r.get("total_mb", 0.0)
        if r.get("per_region_s"):
            region_times.extend(r["per_region_s"])
        if r.get("commit_s") is not None:
            commit_times.append(r["commit_s"])
        mean_t = sum(region_times) / max(len(region_times), 1)
        print(f" [{i+1}/{len(batches)}] done={n_done} skipped={n_skipped} errors={errors} "
              f"{total_mb/1024:.2f} GB mean_region_s={mean_t:.2f}")
    return {
        "jsonls": len(paths),
        "batches": len(batches),
        "regions_done": n_done,
        "regions_skipped": n_skipped,
        "errors": errors,
        "total_mb": total_mb,
        "mean_region_s": sum(region_times) / max(len(region_times), 1),
        # commit_times was collected but never reported; expose it the same
        # way orchestrate_vfdb_lean does.
        "mean_commit_s": sum(commit_times) / max(len(commit_times), 1),
    }
|
|
|
|
@app.local_entrypoint()
def run_syngenome_l26(batch_size: int = 2):
    """Upload the local SynGenome AMR JSONLs, then run the layer-26-only embed.

    Args:
        batch_size: forwarded to orchestrate_syngenome_l26.

    Raises:
        FileNotFoundError: if the hard-coded local JSONL directory is missing.
    """
    import os, time

    local_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(f"SynGenome JSONLs not found at {local_dir}")

    names = [entry for entry in os.listdir(local_dir) if entry.endswith(".jsonl")]
    names.sort()
    print(f"[run-syngenome-l26] uploading {len(names)} JSONLs ...")
    upload_started = time.time()
    # force=True overwrites files that already exist on the volume.
    with jsonl_vol.batch_upload(force=True) as uploader:
        for name in names:
            uploader.put_file(os.path.join(local_dir, name), f"syngenome/{name}")
    print(f" uploaded in {time.time()-upload_started:.0f} s")

    print(f"\n[run-syngenome-l26] orchestrator (batch_size={batch_size})")
    summary = orchestrate_syngenome_l26.remote(batch_size=batch_size)
    print(f"\n=== DONE ===")
    print(f" regions saved: {summary['regions_done']} "
          f"(skipped {summary['regions_skipped']}, errors {summary['errors']})")
    print(f" total size: {summary['total_mb']/1024:.2f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/jsonl": jsonl_vol},
    timeout=86400,
)
def orchestrate_syngenome_l26_neg(batch_size: int = 2) -> dict:
    """Fan the JSONLs under /jsonl/syngenome_neg/ out to embed_syngenome_l26.

    Args:
        batch_size: number of JSONL paths handed to each embed call.

    Returns:
        dict with "jsonls", "batches", "regions_done", "regions_skipped",
        "errors" (failed batches), "total_mb" and "mean_region_s" (0 when no
        timing samples were collected).
    """
    import os

    def _avg(samples):
        # max(..., 1) keeps the empty case at 0 instead of dividing by zero.
        return sum(samples) / max(len(samples), 1)

    paths = sorted(
        os.path.relpath(os.path.join(dirpath, name), "/jsonl")
        for dirpath, _, names in os.walk("/jsonl/syngenome_neg")
        for name in names
        if name.endswith(".jsonl")
    )
    batches = [paths[start:start + batch_size] for start in range(0, len(paths), batch_size)]
    n_batches = len(batches)
    print(f"[orchestrator-syngenome-l26-neg] {len(paths)} JSONLs → {n_batches} batches")

    regions_done = 0
    regions_skipped = 0
    failed_batches = 0
    written_mb = 0.0
    region_times: list[float] = []

    outcomes = embed_syngenome_l26.map(batches, return_exceptions=True)
    for position, outcome in enumerate(outcomes, start=1):
        if isinstance(outcome, Exception):
            failed_batches += 1
            print(f" [{position}/{n_batches}] BATCH ERROR: {outcome}")
            continue
        regions_done += outcome.get("n_done", 0)
        regions_skipped += outcome.get("n_skipped", 0)
        written_mb += outcome.get("total_mb", 0.0)
        region_times.extend(outcome.get("per_region_s") or [])
        print(f" [{position}/{n_batches}] done={regions_done} skipped={regions_skipped} errors={failed_batches} "
              f"{written_mb/1024:.2f} GB mean_region_s={_avg(region_times):.2f}")

    return {
        "jsonls": len(paths),
        "batches": n_batches,
        "regions_done": regions_done,
        "regions_skipped": regions_skipped,
        "errors": failed_batches,
        "total_mb": written_mb,
        "mean_region_s": _avg(region_times),
    }
|
|
|
|
@app.local_entrypoint()
def run_syngenome_l26_neg(batch_size: int = 2):
    """Upload the local SynGenome NEGATIVE JSONLs, then run the layer-26-only embed.

    Intended to be started only after run_syngenome_l26 has finished, to avoid
    GPU contention.

    Args:
        batch_size: forwarded to orchestrate_syngenome_l26_neg.

    Raises:
        FileNotFoundError: if the hard-coded local JSONL directory is missing.
    """
    import os, time

    local_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome_neg"
    if not os.path.isdir(local_dir):
        raise FileNotFoundError(f"SynGenome negative JSONLs not found at {local_dir}")

    names = [entry for entry in os.listdir(local_dir) if entry.endswith(".jsonl")]
    names.sort()
    print(f"[run-syngenome-l26-neg] uploading {len(names)} JSONLs ...")
    upload_started = time.time()
    # force=True overwrites files that already exist on the volume.
    with jsonl_vol.batch_upload(force=True) as uploader:
        for name in names:
            uploader.put_file(os.path.join(local_dir, name), f"syngenome_neg/{name}")
    print(f" uploaded in {time.time()-upload_started:.0f} s")

    print(f"\n[run-syngenome-l26-neg] orchestrator (batch_size={batch_size})")
    summary = orchestrate_syngenome_l26_neg.remote(batch_size=batch_size)
    print(f"\n=== DONE ===")
    print(f" regions saved: {summary['regions_done']} "
          f"(skipped {summary['regions_skipped']}, errors {summary['errors']})")
    print(f" total size: {summary['total_mb']/1024:.2f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
|
|
|
|
@app.local_entrypoint()
def pilot_syngenome_lean(target_records: int = 250, batch_size: int = 2):
    """Run a small SynGenome pilot and project the cost of the full run.

    Counts the records in each local drug-class JSONL, greedily picks the
    smallest files until roughly ``target_records`` records are selected,
    uploads them to ``jsonl_vol`` under ``syngenome/``, embeds them with
    ``embed_syngenome_lean`` and prints measured timings plus a projection.

        modal run modal/evo2_inference.py::pilot_syngenome_lean
        modal run modal/evo2_inference.py::pilot_syngenome_lean --target-records 400 --batch-size 4

    Args:
        target_records: Approximate number of records to include in the pilot.
        batch_size: Number of JSONL files handed to each remote embed call.

    Raises:
        FileNotFoundError: If the local SynGenome directory is missing or
            contains no ``.jsonl`` files.
    """
    import json
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"SynGenome JSONLs not found at {base}; run scripts/sample_syngenome_amr.py first.")

    # (record_count, filename) per JSONL; blank lines are not counted.
    species_files = []
    for fname in sorted(os.listdir(base)):
        if not fname.endswith(".jsonl"):
            continue
        with open(os.path.join(base, fname)) as f:
            n = sum(1 for line in f if line.strip())
        species_files.append((n, fname))
    if not species_files:
        # Previously this surfaced later as a bare IndexError on species_files[0].
        raise FileNotFoundError(f"No .jsonl files found in {base}; run scripts/sample_syngenome_amr.py first.")
    species_files.sort()  # smallest files first, so the pilot stays cheap

    # Greedy selection: take files in ascending size until the target is met,
    # skipping any file more than twice the target once something is chosen.
    chosen = []
    total = 0
    for n, fname in species_files:
        if total >= target_records:
            break
        if n > target_records * 2 and chosen:
            continue
        chosen.append((n, fname))
        total += n
    if not chosen:
        # Only reachable when target_records <= 0: fall back to the smallest
        # file and keep the reported total consistent with that choice.
        chosen = [species_files[0]]
        total = species_files[0][0]

    print(f"[pilot-syngenome] selected {len(chosen)} drug-class files:")
    for n, fname in chosen:
        print(f" {fname:30s} {n} records")
    print(f" total pilot records: {total}")

    print(f"\n[pilot-syngenome] uploading to mgnify-targeted-jsonl ...")
    t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for _, fname in chosen:
            batch.put_file(os.path.join(base, fname), f"syngenome/{fname}")
    print(f" uploaded in {time.time()-t0:.0f} s")

    rel_paths = [f"syngenome/{fname}" for _, fname in chosen]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[pilot-syngenome] {len(rel_paths)} JSONLs in {len(batches)} batch(es) of {batch_size}")

    t0 = time.time()
    results = list(embed_syngenome_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    # Surface failed batches instead of silently dropping them from the stats.
    ok = []
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
        else:
            ok.append(r)

    n_done = sum(r["n_done"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    region_times = [t for r in ok for t in r.get("per_region_s") or []]
    commit_times = [r["commit_s"] for r in ok if r.get("commit_s") is not None]
    load_times = [r["model_load_s"] for r in ok]

    if not region_times:
        print("ERROR: no records processed")
        return

    mean_region = sum(region_times) / len(region_times)
    mean_commit = sum(commit_times) / max(len(commit_times), 1)
    mean_load = sum(load_times) / max(len(load_times), 1)

    # Average sequence length of the pilot records, read from each record's
    # "cds_length" field (0 when the field is absent).
    seq_lens_local = []
    for _, fname in chosen:
        with open(os.path.join(base, fname)) as f:
            for line in f:
                if line.strip():
                    seq_lens_local.append(json.loads(line).get("cds_length", 0))
    pilot_avg_seqlen = sum(seq_lens_local) / max(len(seq_lens_local), 1)

    sorted_times = sorted(region_times)
    # int(len * 0.95) is always < len, so the index is in range.
    p95 = sorted_times[int(len(sorted_times) * 0.95)]

    print(f"\n=== SYNGENOME PILOT RESULTS ===")
    print(f" records processed: {n_done}")
    print(f" output size: {total_mb:.0f} MB ({total_mb/max(n_done,1):.1f} MB/record)")
    print(f" wall clock: {wall:.0f} s across {len(batches)} batch(es)")
    print(f" per-region inference: {mean_region:.2f} s avg "
          f"(min {sorted_times[0]:.2f}, max {sorted_times[-1]:.2f}, p95 {p95:.2f})")
    print(f" per-region seq len: {pilot_avg_seqlen:.0f} bp avg")
    print(f" per-batch commit: {mean_commit:.2f} s avg")
    print(f" per-call model load: {mean_load:.1f} s avg")

    # Hard-coded assumptions about the full run used for the projection.
    # NOTE(review): confirm these against the actual dataset and Modal pricing.
    full_records = 8000
    n_workers = 16
    h100_rate = 4.50
    full_n_jsonls = 13
    full_batches = (full_n_jsonls + batch_size - 1) // batch_size

    # Scale per-region time and size linearly by the ratio of the assumed
    # full-run average sequence length to the pilot's measured average.
    full_avg_seqlen = 5000
    length_factor = full_avg_seqlen / max(pilot_avg_seqlen, 1)
    inference_compute_s = full_records * mean_region * length_factor
    commit_compute_s = full_batches * mean_commit
    cold_start_s = mean_load * min(n_workers, full_batches)
    total_compute_s = inference_compute_s + commit_compute_s + cold_start_s
    wall_proj = total_compute_s / min(n_workers, full_batches)
    cost = (total_compute_s / 3600) * h100_rate
    output_size_gb = (total_mb / max(n_done, 1)) * full_records * length_factor / 1024

    print(f"\n PROJECTION ({full_records} records, {n_workers}× H100, batch_size={batch_size}, "
          f"H100=${h100_rate:.2f}/hr, length_factor={length_factor:.2f}):")
    print(f" inference compute: {inference_compute_s:7.0f} s")
    print(f" commit compute: {commit_compute_s:7.0f} s")
    print(f" cold-start total: {cold_start_s:7.0f} s")
    print(f" estimated wall clock: {wall_proj/60:5.1f} min")
    print(f" estimated cost: ${cost:.2f}")
    print(f" estimated total size: {output_size_gb:.1f} GB")
|
|
|
|
@app.local_entrypoint()
def run_syngenome_lean(batch_size: int = 2):
    """Push all local SynGenome JSONLs to the JSONL volume and run the lean embed.

    Every ``*.jsonl`` in the local ``syngenome`` directory is uploaded under
    ``syngenome/`` on ``jsonl_vol`` (overwriting existing copies), then the
    remote ``orchestrate_syngenome_lean`` function is invoked and its summary
    is printed.

    Args:
        batch_size: Forwarded unchanged to the remote orchestrator.

    Raises:
        FileNotFoundError: If the local SynGenome directory does not exist.
    """
    import os
    import time

    src_dir = "/home/ror25cal/MGnify/data/targeted_jsonl/syngenome"
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"SynGenome JSONLs not found at {src_dir}; run scripts/sample_syngenome_amr.py first.")

    names = [entry for entry in os.listdir(src_dir) if entry.endswith(".jsonl")]
    names.sort()
    print(f"[run-syngenome] uploading {len(names)} JSONLs ...")

    started = time.time()
    with jsonl_vol.batch_upload(force=True) as uploader:
        for name in names:
            local_path = os.path.join(src_dir, name)
            uploader.put_file(local_path, f"syngenome/{name}")
    elapsed = time.time() - started
    print(f" uploaded in {elapsed:.0f} s")

    print(f"\n[run-syngenome] orchestrator (batch_size={batch_size})")
    summary = orchestrate_syngenome_lean.remote(batch_size=batch_size)

    saved = summary['regions_done']
    skipped = summary['regions_skipped']
    failed = summary['errors']
    size_gb = summary['total_mb'] / 1024
    print("\n=== DONE ===")
    print(f" regions saved: {saved} (skipped {skipped}, errors {failed})")
    print(f" total size: {size_gb:.1f} GB")
    print(f" mean per-region: {summary['mean_region_s']:.2f} s")
|
|
|
|
| |
# Volume holding the layer-26-only slices: mounted at /out by
# slice_l26_small_batch and at /vol by upload_l26_small_to_hf.
# Created if it does not already exist.
embeddings_l26_small_vol = modal.Volume.from_name("mgnify-embeddings-l26-small", create_if_missing=True)
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("numpy"),
    cpu=2,
    volumes={
        "/in": embeddings_lean_vol,
        "/out": embeddings_l26_small_vol,
    },
    timeout=3600,
    max_containers=8,
)
def slice_l26_small_batch(rel_paths: list[str]) -> dict:
    """Extract the layer-26 slice from a batch of lean ``.npz`` files.

    For each relative path, reads ``/in/<rel>``, picks the entry of
    ``per_token_layer_activations_bf16`` whose position matches the value 26 in
    ``layer_indices``, and writes it with the pass-through metadata to
    ``/out/<rel>``. Outputs that already exist are skipped, so the job can be
    rerun. The output volume is committed once after the whole batch.

    Args:
        rel_paths: Paths relative to the ``/in`` mount (and mirrored under
            ``/out``).

    Returns:
        Dict with ``n_done``, ``n_skipped``, ``n_errors`` and ``total_mb``
        (bytes written / 1e6).
    """
    import os
    import numpy as np

    n_done = n_skipped = n_errors = 0
    total_mb = 0.0
    for rel in rel_paths:
        in_path = f"/in/{rel}"
        out_path = f"/out/{rel}"
        if os.path.exists(out_path):
            n_skipped += 1
            continue
        if not os.path.exists(in_path):
            print(f" ERROR on {rel}: input file not found")
            n_errors += 1
            continue
        try:
            with np.load(in_path, allow_pickle=False) as d:
                stack = d["per_token_layer_activations_bf16"]
                layer_indices = [int(x) for x in d["layer_indices"]]
                # ValueError here (layer 26 absent) is handled as a per-file error.
                pos = layer_indices.index(26)
                l26 = stack[pos].copy()
                passthrough = {
                    "seq_len": d["seq_len"],
                    "hidden_size": d["hidden_size"],
                    "model_name": d["model_name"],
                    "metadata_json": d["metadata_json"],
                }
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            np.savez(
                out_path,
                layer26_activations_bf16=l26,
                layer26_dtype="bfloat16",
                source_layer_index=np.int32(26),
                source_layer_name="blocks-26",
                **passthrough,
            )
            total_mb += os.path.getsize(out_path) / 1e6
            n_done += 1
        except Exception as e:
            # Per-file best effort: record the failure and keep going.
            print(f" ERROR on {rel}: {e}")
            n_errors += 1
            # out_path did not exist when this iteration started, so anything
            # there now is a partial write from this attempt. Remove it so the
            # skip-if-exists check above does not treat it as finished on rerun.
            try:
                os.remove(out_path)
            except FileNotFoundError:
                pass  # the failure happened before anything was written
            except OSError as cleanup_err:
                print(f" WARNING: could not remove partial {out_path}: {cleanup_err}")
    embeddings_l26_small_vol.commit()
    return {"n_done": n_done, "n_skipped": n_skipped, "n_errors": n_errors, "total_mb": total_mb}
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("modal"),
    cpu=1,
    volumes={"/in": embeddings_lean_vol},
    timeout=86400,
)
def orchestrate_l26_small_slice(batch_size: int = 100) -> dict:
    """Fan out layer-26 slicing over every ``.npz`` under ``/in/small``.

    Walks ``/in/small``, groups the sorted relative paths into batches of
    ``batch_size`` and maps them over ``slice_l26_small_batch``, accumulating
    the per-batch counters.

    Args:
        batch_size: Number of files per ``slice_l26_small_batch`` call; must be
            at least 1.

    Returns:
        Dict with ``files_total``, ``n_done``, ``n_skipped``, ``n_errors``,
        ``total_mb_out`` and ``n_batches_failed``. Files belonging to a batch
        that failed as a whole are counted in ``n_errors``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    import os

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    paths = []
    for root, _, files in os.walk("/in/small"):
        for fname in files:
            if fname.endswith(".npz"):
                rel = os.path.relpath(os.path.join(root, fname), "/in")
                paths.append(rel)
    paths.sort()
    batches = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)]
    print(f"[orchestrator-l26-small] {len(paths)} npz → {len(batches)} batches")

    n_done = n_skipped = n_errors = 0
    n_batches_failed = 0
    total_mb_out = 0.0
    for i, r in enumerate(slice_l26_small_batch.map(batches, return_exceptions=True)):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            # A crashed batch returns no counters; count its files as errors so
            # the summary cannot show zero errors after a failure. Results come
            # back in input order, so batches[i] is the batch that failed.
            n_batches_failed += 1
            n_errors += len(batches[i])
            continue
        n_done += r.get("n_done", 0)
        n_skipped += r.get("n_skipped", 0)
        n_errors += r.get("n_errors", 0)
        total_mb_out += r.get("total_mb", 0.0)
        print(f" done={n_done} skipped={n_skipped} errors={n_errors} {total_mb_out/1024:.2f} GB")
    return {"files_total": len(paths), "n_done": n_done, "n_skipped": n_skipped,
            "n_errors": n_errors, "total_mb_out": total_mb_out,
            "n_batches_failed": n_batches_failed}
|
|
|
|
@app.function(
    image=modal.Image.debian_slim().pip_install("huggingface_hub>=0.25"),
    cpu=4,
    volumes={"/vol": embeddings_l26_small_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=21600,
)
def upload_l26_small_to_hf(repo_name: str = "mgnify-evo2-l26-small-qual",
                           private: bool = False) -> dict:
    """Upload the contents of ``/vol`` to a Hugging Face dataset repo.

    Finds a Hugging Face token in the environment, creates
    ``<user>/<repo_name>`` as a dataset repo if needed and uploads the whole
    ``/vol`` folder with ``upload_large_folder``.

    Args:
        repo_name: Repo name without the user prefix; the authenticated user's
            name is prepended.
        private: Whether the repo is created as private.

    Returns:
        Dict with ``repo_url``, ``n_files`` and ``bytes_total`` (``.npz`` files
        under ``/vol`` only), ``elapsed_s`` (upload duration) and ``private``.

    Raises:
        RuntimeError: If no Hugging Face token is found in the environment.
    """
    import os
    import time
    from huggingface_hub import HfApi, login

    # Accept any env var whose name OR value looks like an HF token.
    # NOTE(review): the name check presumably covers a secret created with the
    # token in the key position — confirm it is still needed.
    token = None
    for k, v in os.environ.items():
        if k.startswith("hf_") and len(k) > 30:
            token = k
            break
        if v.startswith("hf_") and len(v) > 30:
            token = v
            break
    if not token:
        token = os.environ.get("HF_TOKEN")
    if not token:
        # login(token=None) would fall back to an interactive prompt, which
        # cannot work inside a container; fail with a clear message instead.
        raise RuntimeError(
            "No Hugging Face token found in the environment; "
            "check the 'huggingface' Modal secret (expected HF_TOKEN)."
        )

    login(token=token)
    api = HfApi()
    user = api.whoami()["name"]
    full_repo = f"{user}/{repo_name}"
    api.create_repo(full_repo, repo_type="dataset", private=private, exist_ok=True)
    print(f"[hf-push-small] uploading /vol → {full_repo} (private={private})")
    t0 = time.time()
    api.upload_large_folder(folder_path="/vol", repo_id=full_repo, repo_type="dataset")
    elapsed = time.time() - t0

    # The summary counts only .npz files, although the whole folder is uploaded.
    n_files = bytes_total = 0
    for root, _, files in os.walk("/vol"):
        for fname in files:
            if fname.endswith(".npz"):
                n_files += 1
                bytes_total += os.path.getsize(os.path.join(root, fname))
    return {"repo_url": f"https://huggingface.co/datasets/{full_repo}",
            "n_files": n_files, "bytes_total": bytes_total,
            "elapsed_s": elapsed, "private": private}
|
|
|
|
@app.local_entrypoint()
def push_l26_small(repo_name: str = "mgnify-evo2-l26-small-qual",
                   private: bool = False, batch_size: int = 100):
    """Slice layer 26 out of the ``small/`` lean embeddings, then publish to HF.

    Step 1 calls the remote ``orchestrate_l26_small_slice``; step 2 calls the
    remote ``upload_l26_small_to_hf``. Both summaries are printed.

    Args:
        repo_name: Dataset repo name passed to the upload step.
        private: Passed to the upload step.
        batch_size: Files per slicing batch, passed to the orchestrator.
    """
    print("[1/2] slicing layer 26 from small/ ...")
    sliced = orchestrate_l26_small_slice.remote(batch_size=batch_size)
    counts = (
        f" files: {sliced['files_total']}, done: {sliced['n_done']}, "
        f"skipped: {sliced['n_skipped']}, errors: {sliced['n_errors']}"
    )
    print(counts)
    sliced_gb = sliced['total_mb_out'] / 1024
    print(f" l26 size: {sliced_gb:.2f} GB")

    print("\n[2/2] pushing to HF ...")
    pushed = upload_l26_small_to_hf.remote(repo_name=repo_name, private=private)
    pushed_gb = pushed['bytes_total'] / 1e9
    print("\n=== DONE ===")
    print(f" repo: {pushed['repo_url']}")
    print(f" files: {pushed['n_files']}")
    print(f" size: {pushed_gb:.2f} GB")
    print(f" elapsed: {pushed['elapsed_s']:.0f} s")
|
|
|
|
@app.local_entrypoint()
def run_qual_lean(batch_size: int = 8):
    """Upload the qual JSONLs and embed them. Tiny job (~860 records, ~$0.30, ~2 min).

    Every ``*.jsonl`` in the local ``qual`` directory is uploaded under
    ``qual/`` on ``jsonl_vol`` (overwriting existing copies), then the files
    are mapped over ``embed_qual_lean`` in batches and a summary is printed.
    Batches that raised are reported rather than silently dropped.

    Args:
        batch_size: Number of JSONL files handed to each remote embed call.

    Raises:
        FileNotFoundError: If the local qual directory does not exist.
    """
    import os
    import time

    base = "/home/ror25cal/MGnify/data/targeted_jsonl/qual"
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Qual JSONLs not found at {base}; run scripts/sample_qual_jsonl.py first.")

    jsonls = sorted(f for f in os.listdir(base) if f.endswith(".jsonl"))
    print(f"[run-qual] uploading {len(jsonls)} JSONLs to mgnify-targeted-jsonl ...")
    t0 = time.time()
    with jsonl_vol.batch_upload(force=True) as batch:
        for fname in jsonls:
            batch.put_file(os.path.join(base, fname), f"qual/{fname}")
    print(f" uploaded in {time.time()-t0:.0f} s")

    rel_paths = [f"qual/{f}" for f in jsonls]
    batches = [rel_paths[i:i + batch_size] for i in range(0, len(rel_paths), batch_size)]
    print(f"\n[run-qual] {len(rel_paths)} JSONLs in {len(batches)} batches of up to {batch_size}")

    t0 = time.time()
    results = list(embed_qual_lean.map(batches, return_exceptions=True))
    wall = time.time() - t0

    # Separate successes from failures; print failures so a shortfall in the
    # summary below can be traced to the batch that raised.
    ok = []
    n_failed = 0
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f" [{i+1}/{len(batches)}] BATCH ERROR: {r}")
            n_failed += 1
        else:
            ok.append(r)

    n_done = sum(r["n_done"] for r in ok)
    n_skipped = sum(r["n_skipped"] for r in ok)
    total_mb = sum(r["total_mb"] for r in ok)
    print(f"\n=== QUAL DONE ===")
    print(f" records embedded: {n_done} (skipped {n_skipped})")
    print(f" output size: {total_mb:.0f} MB")
    print(f" wall clock: {wall:.0f} s")
    if n_failed:
        print(f" failed batches: {n_failed} of {len(batches)}")
|
|
|
|
@app.local_entrypoint()
def push_l26_vfdb(repo_name: str = "mgnify-evo2-l26-vfdb-virulence",
                  private: bool = False, batch_size: int = 100):
    """Slice layer 26 out of the VFDB lean embeddings, then publish to HF Datasets.

    Step 1 calls the remote ``orchestrate_l26_vfdb_slice``; step 2 calls the
    remote ``upload_l26_vfdb_to_hf``. Both summaries are printed, including the
    average upload throughput.

    Args:
        repo_name: Dataset repo name passed to the upload step.
        private: Passed to the upload step.
        batch_size: Files per slicing batch, passed to the orchestrator.
    """
    print("[1/2] slicing layer 26 from VFDB lean embeddings ...")
    sliced = orchestrate_l26_vfdb_slice.remote(batch_size=batch_size)
    print(f"\n files total: {sliced['files_total']}")
    for label, key in (("done", "n_done"), ("skipped", "n_skipped"), ("errors", "n_errors")):
        print(f" {label}: {sliced[key]}")
    sliced_gb = sliced['total_mb_out'] / 1024
    print(f" l26 vol size: {sliced_gb:.2f} GB")

    print("\n[2/2] pushing VFDB layer-26 volume to HF Datasets ...")
    pushed = upload_l26_vfdb_to_hf.remote(repo_name=repo_name, private=private)
    n_bytes = pushed['bytes_total']
    seconds = pushed['elapsed_s']
    # Clamp the divisor to 1 s so a near-instant upload cannot divide by zero.
    mb_per_s = n_bytes / 1e6 / max(seconds, 1)
    print("\n=== UPLOADED ===")
    print(f" repo: {pushed['repo_url']}")
    print(f" files: {pushed['n_files']}")
    print(f" size: {n_bytes/1e9:.2f} GB")
    print(f" elapsed: {seconds:.0f} s ({mb_per_s:.1f} MB/s)")
|
|