""" Evo 2 layer-26 extraction pipeline used to produce the layer-26 npz files shared on HuggingFace as `JG1310/mgnify-evo2-l26-full` (and the 5-layer source on Modal volume `mgnify-embeddings-lean`). This file is written for someone debugging SAE reconstruction error to verify: 1. Which model variant we ran — `evo2_7b_262k`, NOT `evo2_7b`. 2. Which hook we read from — whole-block output of block 26 (`blocks-26`), NOT `blocks.26.mlp.l3`. 3. How activations were stored — bf16 bit-pattern in uint16 numpy array. 4. What's inside each .npz file — schema documented in `LOADER_EXAMPLE`. 5. Reference SAE encode (BatchTopK) — pattern follows Arc's official notebook at notebooks/sparse_autoencoder/ sparse_autoencoder.ipynb. If reconstruction error is bad on the receiver's side but the saved activations match the residual stream produced by Arc's own example notebook on the same input, the bug is in their SAE-encode/decode code (most common: missing BatchTopK normalization, wrong dtype on matmul, wrong W vs W.T on decode). A reproducible smoke test is provided at the bottom: run on Modal with modal run evo2_layer26_extraction.py::smoke_test """ import os import json import time import modal # ============================================================================= # Constants # ============================================================================= MODEL_VARIANT = "evo2_7b_262k" # 262k-context variant — Goodfire's SAE was # trained against this, not the vanilla 7b. TARGET_LAYER = "blocks-26" # whole-block output (residual stream after # block 26). NOT blocks-26-mlp-l3 — that # would be a sub-module's output and would # give different activations. HIDDEN = 4096 # Evo 2 7b residual stream dim. SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed" SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt" SAE_K = 64 # BatchTopK budget per token-batch # (k=64, expansion=8 ⇒ d_sae = 32768). 
# =============================================================================
# Modal image (Arc Institute's official Evo 2 Dockerfile, translated to Modal)
# =============================================================================

image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,  # the base image already ships its own Python
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")  # pulls flash-attn + vortex-model + huggingface_hub
)

app = modal.App("evo2-layer26-extraction-share")
weights_vol = modal.Volume.from_name("evo2-7b-weights", create_if_missing=True)


# =============================================================================
# Helper: walk the StripedHyena module tree to find a hook target by name.
# StripedHyena's nesting structure means `blocks.26` is reached via
# `evo2.model.blocks[26]`, but its child names are ('mixer', 'mlp', etc.).
# We use `named_children()` and join with '-' so that:
#     blocks-26         = block 26's container forward output
#     blocks-26-mlp-l3  = block 26's MLP last-layer linear output
# =============================================================================

def build_module_dict(model):
    """Map '-'-joined child paths to sub-modules for every descendant of `model`.

    Keys are built from `named_children()` names, parent first, joined with
    '-' (e.g. "blocks-26-mlp"). Entries appear in depth-first pre-order: a
    module is inserted before any of its own children. `model` itself is not
    included.
    """
    def walk(module, prefix):
        # Yield this level's children, descending into each one right after
        # it has been yielded so the overall order is pre-order.
        for child_name, child in module.named_children():
            path = prefix + child_name
            yield path, child
            yield from walk(child, path + "-")

    return dict(walk(model, ""))


# =============================================================================
# The actual extraction function. This is what wrote each per-region npz.
#
# Important details to verify against your own pipeline:
#   - Forward pass receives the full sequence (gene + 2 kb upstream + 2 kb
#     downstream flank) so causal Hyena convolution gets context.
#   - Hook fires on `blocks-26.forward` and we capture `out[0]` if the output
#     is a tuple, else `out`. For StripedHyena blocks the first tuple element
#     is the residual-stream hidden state passed to block 27 — this is what
#     Goodfire's SAE was trained on.
#   - The captured tensor is bf16 on GPU. We keep it in bf16 and reinterpret
#     the bit-pattern as uint16 because numpy does not support bf16 natively.
#     This is a *bit-exact* reinterpretation, NOT a precision-losing cast —
#     decode with `torch.from_numpy(arr).view(torch.bfloat16)`.
#   - We do NOT compress (no gzip) — random-looking bf16 floats compress poorly
#     and gzip was the dominant cost during a previous failed run.
# =============================================================================

# Process-wide cache of loaded models, keyed by variant name. A container that
# serves several calls in a row builds the model once instead of once per call.
_EVO2_MODELS: dict = {}


def _get_evo2(variant: str):
    """Return the cached `Evo2` instance for `variant`, building it on first use."""
    model = _EVO2_MODELS.get(variant)
    if model is None:
        # Function-scope import, matching how the rest of this file imports evo2.
        from evo2 import Evo2

        model = Evo2(variant)
        _EVO2_MODELS[variant] = model
    return model


@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def extract_layer26_for_sequence(sequence: str, region_metadata: dict) -> dict:
    """
    Run Evo 2 forward on `sequence`, capture layer-26 residual stream, return
    it as a bf16-as-uint16 numpy bit-pattern plus the metadata in JSON form.

    sequence: DNA string, forward strand (e.g. "ATGAA..."). We do not
        reverse-complement minus-strand genes — we feed the genomic forward
        strand as-is. Goodfire's reference notebook also feeds raw
        forward-strand sequences.

    region_metadata: arbitrary JSON-serialisable dict — locus_tag, gene
        coords, label class, etc. — passed through into the saved .npz so each
        file is self-describing.

    Returns a dict with the keys `layer26_activations_bf16` (uint16 array of
    shape [seq_len, HIDDEN]), `layer26_dtype`, `source_layer_index`,
    `source_layer_name`, `seq_len`, `hidden_size`, `model_name` and
    `metadata_json`.

    Raises:
        ValueError: `sequence` is empty.
        TypeError / ValueError: `region_metadata` is not JSON-serialisable
            (raised by `json.dumps`, before any GPU work is done).
        RuntimeError: the hook target is missing from the module tree, the
            hook never fired, or the captured hidden size is not `HIDDEN`.
    """
    import torch

    if not sequence:
        raise ValueError("sequence is empty — nothing to run the forward pass on")

    # Serialise the metadata up front so a non-serialisable record fails here
    # rather than after the forward pass has already been paid for.
    metadata_json = json.dumps(region_metadata)

    # ----- load the 262k-context Evo 2 7B variant (one-time per container) ---
    evo2 = _get_evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device

    module_dict = build_module_dict(evo2.model)
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    # ----- register the hook -------------------------------------------------
    cache: dict = {}

    def hook_fn(_module, _inp, out):
        # StripedHyena blocks return a tuple where index 0 is the residual-
        # stream hidden state. Some sub-modules return just a tensor.
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()  # detach so we don't keep autograd graph

    handle = target_module.register_forward_hook(hook_fn)

    try:
        # ----- forward pass --------------------------------------------------
        # Tokenizer: each nucleotide gets one token id (Evo 2's tokenizer is
        # byte-level on ACGTN). Sequence length = len(sequence).
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(sequence),
            dtype=torch.long,
        ).unsqueeze(0).to(device)  # add batch dim, move to GPU

        with torch.no_grad():
            evo2.model(input_ids)
        # No need for output logits — we only care about the cached activation.

        if "acts" not in cache:
            raise RuntimeError(
                f"forward hook on {TARGET_LAYER} never fired during the forward pass"
            )
        acts_bf16 = cache["acts"][0]  # squeeze batch dim → [seq_len, HIDDEN]
        seq_len, hidden = acts_bf16.shape
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if hidden != HIDDEN:
            raise RuntimeError(f"unexpected hidden dim {hidden}")
    finally:
        # Always unhook: the model object outlives this call, so a leaked hook
        # would keep firing on later forward passes.
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    # ----- bf16 → uint16 bit-pattern (lossless) -----------------------------
    # `view(torch.uint16)` is a zero-copy reinterpretation of the same memory:
    # the bit-pattern of a bf16 float is read as the bit-pattern of a uint16.
    # No precision loss. Decode on the receiving side with the inverse.
    acts_uint16_np = acts_bf16.to(torch.bfloat16).view(torch.uint16).cpu().numpy()

    return {
        "layer26_activations_bf16": acts_uint16_np,  # uint16 [seq_len, 4096]
        "layer26_dtype": "bfloat16",                 # marker for decode
        "source_layer_index": 26,
        "source_layer_name": TARGET_LAYER,
        "seq_len": int(seq_len),
        "hidden_size": int(hidden),
        "model_name": MODEL_VARIANT,
        "metadata_json": metadata_json,
    }


# =============================================================================
# Reference loader — exactly how to read one of our shared npz files back.
# This is what receivers should do; if they don't get the right shape/dtype,
# the bug is here, not upstream.
# =============================================================================

LOADER_EXAMPLE = '''
import numpy as np
import json
import torch

d = np.load("AMR/MGYG.../REGION_AMR.npz", allow_pickle=False)

# Schema (every shared file has these keys):
#   layer26_activations_bf16   uint16 array, shape [seq_len, 4096]
#                              (bit-pattern of bf16 stored as uint16)
#   layer26_dtype              literal string "bfloat16"
#   source_layer_index         int 26
#   source_layer_name          literal string "blocks-26"
#   seq_len, hidden_size       ints (matches array shape)
#   model_name                 literal string "evo2_7b_262k"
#   metadata_json              JSON-encoded dict with locus_tag, gene_symbol,
#                              label_class, label_subclass, gene_start/end,
#                              paired_with, etc.

# Decode bit-pattern to bf16, then to fp32 for downstream math:
acts_bf16 = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16)
acts_fp32 = acts_bf16.float()                       # shape [seq_len, 4096]

# Pull the per-region metadata:
meta = json.loads(str(d["metadata_json"]))
print(meta["gene_symbol"], meta["label_class"], meta["label_subclass"])
'''


# =============================================================================
# Reference SAE encode-and-decode that produced sane reconstruction in our
# CRISPR sanity test (5/57 of Goodfire's published features fired strongly on
# E. coli K12 CRISPR arrays). Use this to compare your own SAE handling.
#
# THE THREE PLACES WHERE PEOPLE GET THIS WRONG:
#
# 1. dtype: cast both `W_enc`/`b_enc` AND `acts` to the SAME dtype (bf16 OR
#    fp32, but consistent) before the matmul. Mixed-dtype matmuls silently
#    downcast in unexpected ways on some GPU paths.
#
# 2. BatchTopK is *batch-wide*, not per-token. The top-K is computed across
#    the FLATTENED (seq_len * d_sae) tensor with k = K * seq_len, NOT
#    `topk(k=64)` per token. Both keep the same TOTAL number of non-zeros
#    (K * seq_len), but per-token topk forces exactly K on every token,
#    whereas BatchTopK lets individual tokens keep more or fewer than K — so
#    the two select different features.
#
# 3. Reconstruction uses `W.T` (the transpose) not `W`. Goodfire's SAE has
#    tied encoder/decoder weights, so a single `W` matrix in the state dict.
#    encode = ReLU(acts @ W + b_enc); decode = features @ W.T + b_dec.
# =============================================================================

def reference_encode_and_reconstruct(acts_fp32, sae_state_dict, K=SAE_K):
    """Reference SAE encode → BatchTopK → decode.

    acts_fp32:       [seq_len, 4096] activations (fp32 or bf16). Must be a
                     floating-point tensor — decode the stored uint16
                     bit-pattern with `.view(torch.bfloat16)` first.
    sae_state_dict:  loaded from `Goodfire/Evo-2-Layer-26-Mixed` via
                     huggingface_hub.hf_hub_download
    K:               BatchTopK budget per token (default 64)

    Returns: (sparse_features, reconstructed_acts), shaped
             [seq_len, d_sae] and [seq_len, 4096], in the dtype and on the
             device of `acts_fp32`.

    Raises:
        TypeError:  `acts_fp32` is not a floating-point tensor.
        ValueError: `acts_fp32` is not 2-D.
        KeyError:   the state dict has no `W` or `b_enc` entry.
    """
    import torch

    # An integer tensor (typically the raw uint16 bit-pattern from the npz)
    # would otherwise drag W and the biases into that integer dtype below and
    # produce meaningless features without any error.
    if not acts_fp32.is_floating_point():
        raise TypeError(
            f"acts_fp32 must be a floating-point tensor, got {acts_fp32.dtype}; "
            "decode the stored uint16 bit-pattern with .view(torch.bfloat16) first"
        )
    if acts_fp32.dim() != 2:
        raise ValueError(
            f"acts_fp32 must be 2-D [seq_len, hidden], got shape {tuple(acts_fp32.shape)}"
        )

    # The official Goodfire checkpoint was saved with torch.compile + DDP
    # prefixes — strip them when loading:
    sae = {k.replace("_orig_mod.", "").replace("module.", ""): v
           for k, v in sae_state_dict.items()}

    W = sae["W"]          # [4096, 32768]
    b_enc = sae["b_enc"]  # [32768]
    b_dec = sae.get("b_dec")  # [4096]; some checkpoints omit
    if b_dec is None:
        b_dec = torch.zeros(W.shape[0])

    # Match dtypes carefully: weights and biases follow the activations.
    dtype = acts_fp32.dtype
    device = acts_fp32.device
    W = W.to(device=device, dtype=dtype)
    b_enc = b_enc.to(device=device, dtype=dtype)
    b_dec = b_dec.to(device=device, dtype=dtype)

    # ----- encode (same as Arc's notebook) -----------------------------------
    pre = torch.relu(acts_fp32 @ W + b_enc)  # [seq_len, 32768]

    # BatchTopK across the WHOLE [seq_len * d_sae] flattened tensor:
    seq_len = pre.shape[0]
    flat = pre.flatten()
    # Total non-zero budget, clamped so K > d_sae cannot ask topk for more
    # elements than exist.
    numel = min(K * seq_len, flat.numel())
    top = torch.topk(flat, numel, dim=-1)
    sparse_flat = torch.zeros_like(flat).scatter(-1, top.indices, top.values)
    features = sparse_flat.reshape(pre.shape)  # [seq_len, 32768], sparse

    # ----- decode using W.T (tied weights) -----------------------------------
    reconstructed = features @ W.T + b_dec  # [seq_len, 4096]

    return features, reconstructed


# =============================================================================
# Standalone smoke test you can run to verify the full pipeline end-to-end
# on a known input.
# If this gives weird reconstruction, the issue is upstream;
# if reconstruction is clean here but bad in your pipeline, it's downstream.
#
# Usage:
#     modal run evo2_layer26_extraction.py::smoke_test
# =============================================================================

@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def smoke_test():
    """Forward pass on a 1 kb random-ish DNA string, capture layer 26, run SAE
    encode-decode, report reconstruction stats.

    Returns a dict with `mse`, `var`, `explained_variance`, `cosine` (mean
    per-token cosine similarity) and `sparsity` (fraction of non-zero SAE
    features), all as Python floats.

    Raises:
        RuntimeError: the hook target is missing from the module tree, or the
            hook never fired during the forward pass.
    """
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    # 1 kb random-looking sequence — same scale as Goodfire's chr17 example
    seq = "ATGAACAACGTACTGAGCGAATTCAGCAATGGCAATCGGGCTAGCTAGCTAGCTGCATGCATGCATGCATGCATGCATGCATGCAT" * 12
    seq = seq[:1000]
    print(f"smoke_test sequence length: {len(seq)} bp")

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    # Same guard as the extraction function, so a renamed layer gives a clear
    # message instead of a bare KeyError.
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]
    cache = {}

    def hook(_, __, out):
        # Keep element 0 of a tuple output, or the tensor itself otherwise.
        cache["acts"] = (out[0] if isinstance(out, tuple) else out).detach()

    handle = target_module.register_forward_hook(hook)
    try:
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(seq), dtype=torch.long
        ).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        handle.remove()

    if "acts" not in cache:
        raise RuntimeError(
            f"forward hook on {TARGET_LAYER} never fired during the forward pass"
        )
    acts = cache["acts"][0]  # [seq_len, 4096] bf16
    print(f"layer-26 activations: shape={tuple(acts.shape)} dtype={acts.dtype} "
          f"abs_max={acts.abs().max().item():.2f} std={acts.float().std().item():.4f}")

    # Load SAE and run encode + decode
    sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE)
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)
    orig = acts.float()
    features, recon = reference_encode_and_reconstruct(orig, sae_sd, K=SAE_K)

    # ---- reconstruction metrics ---------------------------------------------
    err = orig - recon
    mse = (err ** 2).mean().item()
    var = orig.var().item()
    explained_variance = 1.0 - mse / max(var, 1e-9)
    cosine_per_token = torch.nn.functional.cosine_similarity(orig, recon, dim=1).mean().item()
    sparsity = (features != 0).float().mean().item()
    # Take the dictionary size from the tensor rather than hard-coding 32768,
    # so the reported ratio stays right for any SAE checkpoint.
    d_sae = features.shape[-1]

    print("\nSAE reconstruction:")
    print(f" MSE: {mse:.5f}")
    print(f" variance: {var:.5f}")
    print(f" explained variance: {explained_variance:.4f} (closer to 1.0 is better)")
    print(f" mean per-token cosine: {cosine_per_token:.4f} (closer to 1.0 is better)")
    print(f" feature sparsity: {sparsity:.4f} (k/d_sae = {SAE_K/d_sae:.4f})")

    return {
        "mse": mse,
        "var": var,
        "explained_variance": explained_variance,
        "cosine": cosine_per_token,
        "sparsity": sparsity,
    }


@app.local_entrypoint()
def main():
    """Run the smoke test and dump the reconstruction stats."""
    r = smoke_test.remote()
    print(json.dumps(r, indent=2))


# =============================================================================
# Quick reference: the original orchestrator used to extract every region.
# Each region's record was a dict with keys (`sequence`, `mag_id`,
# `locus_tag`, `region_id`, `is_positive`, `label`, `label_class`, etc.) — the
# same dict is JSON-encoded into `metadata_json` in each saved npz.
# =============================================================================

ORIGINAL_PIPELINE_NOTES = """
Source data: targeted JSONL files extracted with scripts/extract_targeted.py
Each JSONL line is one record. Fields:
    sequence              DNA, forward strand, gene + 2 kb upstream + 2 kb downstream flank
    mag_id, locus_tag     Prodigal IDs from MGnify master GFF
    region_id             f"{locus_tag}_{label}" — unique per record
    is_positive           True for AMR/STRESS/VIRULENCE positives, False for matched negatives
    label                 "AMR" | "STRESS" | "VIRULENCE" | "negative"
    label_class           AMRFinderPlus class (e.g. "BETA-LACTAM", "MACROLIDE")
    label_subclass        AMRFinderPlus subclass
    gene_symbol           e.g. "blaOXA", "catA"
    pct_identity_to_ref   AMRFinderPlus protein identity to reference seq
                          (proxy for memorisation: < 80% suggests novel allele)
    paired_with           locus_tag of the matched partner (positive ↔ negative)
    gene_start, gene_end, strand, contig, ext_start, ext_end
    gc_content, cds_in_mobilome, negative_pool_fallback

For each record we ran `extract_layer26_for_sequence(record["sequence"], record)`
and saved the result to {label}/{mag_id}/{region_id}.npz.

Layout on the HF dataset `JG1310/mgnify-evo2-l26-full`:
    AMR/{mag_id}/{region_id}.npz        — AMR positive
    STRESS/{mag_id}/{region_id}.npz     — stress-resistance positive
    VIRULENCE/{mag_id}/{region_id}.npz  — virulence positive
    MISC/{mag_id}/{region_id}.npz       — matched-CDS negatives (paired_with field
                                          links to the positive in AMR/, STRESS/, etc.)
"""