| """ |
| Evo 2 layer-26 extraction pipeline used to produce the layer-26 npz files |
| shared on HuggingFace as `JG1310/mgnify-evo2-l26-full` (and the 5-layer source |
| on Modal volume `mgnify-embeddings-lean`). |
| |
| This file is written for someone debugging SAE reconstruction error to verify: |
| 1. Which model variant we ran — `evo2_7b_262k`, NOT `evo2_7b`. |
| 2. Which hook we read from — whole-block output of block 26 |
| (`blocks-26`), NOT `blocks.26.mlp.l3`. |
| 3. How activations were stored — bf16 bit-pattern in uint16 numpy array. |
| 4. What's inside each .npz file — schema documented in `LOADER_EXAMPLE`. |
| 5. Reference SAE encode (BatchTopK) — pattern follows Arc's official notebook |
| at notebooks/sparse_autoencoder/sparse_autoencoder.ipynb. |
| |
| If reconstruction error is bad on the receiver's side but the saved activations |
| match the residual stream produced by Arc's own example notebook on the same |
| input, the bug is in their SAE-encode/decode code (most common: missing |
| BatchTopK normalization, wrong dtype on matmul, wrong W vs W.T on decode). |
| |
| A reproducible smoke test is provided at the bottom: run on Modal with |
| modal run evo2_layer26_extraction.py::smoke_test |
| """ |
| import os |
| import json |
| import time |
|
|
| import modal |
|
|
|
|
| |
| |
| |
# Evo 2 checkpoint name passed to `Evo2(...)`. Deliberately the `_262k`
# checkpoint and NOT the plain `evo2_7b`.
MODEL_VARIANT = "evo2_7b_262k"

# Key into the dict returned by `build_module_dict` (child names joined with
# "-"). Selects the whole-block output of block 26, NOT `blocks.26.mlp.l3`.
TARGET_LAYER = "blocks-26"

# Expected width of the captured activations; the extraction function checks
# the hooked output against this value.
HIDDEN = 4096
# HuggingFace repo id and filename of the SAE checkpoint fetched by the smoke
# test via `hf_hub_download`.
SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed"
SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt"
# Default BatchTopK budget per token used by `reference_encode_and_reconstruct`.
SAE_K = 64
| |
|
|
|
|
| |
| |
| |
# Container image shared by the remote functions in this file: NVIDIA's PyTorch
# registry image, plus git / pip / tomli installed through apt, plus the `evo2`
# package installed through pip.
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        # NOTE(review): presumably `add_python=None` means the Python already
        # inside the registry image is used — confirm against the Modal docs.
        add_python=None,
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")
)


# Modal app that the decorated functions and the local entrypoint attach to.
app = modal.App("evo2-layer26-extraction-share")
# Named Modal volume (created if it does not exist yet). The remote functions
# mount it at /root/.cache/huggingface.
weights_vol = modal.Volume.from_name("evo2-7b-weights", create_if_missing=True)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
def build_module_dict(model):
    """Flatten a module tree into a ``{path: submodule}`` dict.

    Every descendant reachable through ``named_children()`` gets one entry.
    Its key is the chain of child names from the root joined with ``"-"``
    (e.g. ``"blocks-26"``). The root itself has no entry. Entries appear in
    depth-first pre-order: a parent is listed before its own children.
    """
    def walk(module, prefix):
        # Yield each direct child first, then descend into it, so the
        # resulting dict keeps parent-before-children ordering.
        for child_name, child in module.named_children():
            path = prefix + child_name
            yield path, child
            yield from walk(child, path + "-")

    return dict(walk(model, ""))
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def extract_layer26_for_sequence(sequence: str, region_metadata: dict) -> dict:
    """
    Run Evo 2 forward on `sequence`, capture layer-26 residual stream, return
    it as a bf16-as-uint16 numpy bit-pattern plus the metadata in JSON form.

    The model is constructed inside this function, so every call loads
    `MODEL_VARIANT` again.

    Args:
        sequence: DNA string, forward strand (e.g. "ATGAA...").
            We do not reverse-complement minus-strand genes — we feed
            the genomic forward strand as-is. Goodfire's reference
            notebook also feeds raw forward-strand sequences.
        region_metadata: arbitrary dict — locus_tag, gene coords, label class,
            etc. — passed through (JSON-encoded) into the result so each
            saved file is self-describing. Must be JSON-serialisable.

    Returns:
        dict with keys `layer26_activations_bf16` (uint16 numpy array of shape
        [seq_len, hidden] holding the bf16 bit-pattern), `layer26_dtype`,
        `source_layer_index`, `source_layer_name`, `seq_len`, `hidden_size`,
        `model_name` and `metadata_json`.

    Raises:
        RuntimeError: if `TARGET_LAYER` is not in the module tree, if the
            forward hook never fired, or if the captured hidden size differs
            from `HIDDEN`.
    """
    import torch
    from evo2 import Evo2

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    cache: dict = {}

    def hook_fn(_module, _inp, out):
        # The hooked module may return a tuple; the activations are its
        # first element in that case.
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()

    handle = target_module.register_forward_hook(hook_fn)

    try:
        # Tokenise and add a leading batch dimension of size 1.
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(sequence),
            dtype=torch.long,
        ).unsqueeze(0).to(device)

        with torch.no_grad():
            evo2.model(input_ids)

        # Fail with a clear message instead of an opaque KeyError if the
        # hooked module was never executed during the forward pass.
        if "acts" not in cache:
            raise RuntimeError(
                f"forward hook on {TARGET_LAYER} did not capture any activations"
            )

        # Drop the batch dimension: [1, seq_len, hidden] -> [seq_len, hidden].
        acts_bf16 = cache["acts"][0]
        seq_len, hidden = acts_bf16.shape
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if hidden != HIDDEN:
            raise RuntimeError(f"unexpected hidden dim {hidden}")
    finally:
        # Always detach the hook and release cached GPU memory, even on error.
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    # numpy has no bfloat16 dtype, so reinterpret the 16-bit pattern as uint16
    # (no value conversion) before moving to host memory.
    acts_uint16_np = acts_bf16.to(torch.bfloat16).view(torch.uint16).cpu().numpy()

    return {
        "layer26_activations_bf16": acts_uint16_np,
        "layer26_dtype": "bfloat16",
        "source_layer_index": 26,
        "source_layer_name": TARGET_LAYER,
        "seq_len": int(seq_len),
        "hidden_size": int(hidden),
        "model_name": MODEL_VARIANT,
        "metadata_json": json.dumps(region_metadata),
    }
|
|
|
|
| |
| |
| |
| |
| |
# Reference snippet, kept as a plain string, showing how a receiver can read one
# shared .npz file. It is not executed anywhere in the visible part of this
# file. The schema it lists mirrors the keys of the dict returned by
# `extract_layer26_for_sequence`.
# NOTE(review): the snippet hands a uint16 array to `torch.from_numpy`, which
# presumably requires a PyTorch build with uint16 support — confirm the minimum
# version receivers need.
LOADER_EXAMPLE = '''
import numpy as np
import json
import torch

d = np.load("AMR/MGYG.../REGION_AMR.npz", allow_pickle=False)

# Schema (every shared file has these keys):
# layer26_activations_bf16 uint16 array, shape [seq_len, 4096]
# (bit-pattern of bf16 stored as uint16)
# layer26_dtype literal string "bfloat16"
# source_layer_index int 26
# source_layer_name literal string "blocks-26"
# seq_len, hidden_size ints (matches array shape)
# model_name literal string "evo2_7b_262k"
# metadata_json JSON-encoded dict with locus_tag, gene_symbol,
# label_class, label_subclass, gene_start/end,
# paired_with, etc.

# Decode bit-pattern to bf16, then to fp32 for downstream math:
acts_bf16 = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16)
acts_fp32 = acts_bf16.float() # shape [seq_len, 4096]

# Pull the per-region metadata:
meta = json.loads(str(d["metadata_json"]))
print(meta["gene_symbol"], meta["label_class"], meta["label_subclass"])
'''
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def reference_encode_and_reconstruct(acts_fp32, sae_state_dict, K=SAE_K):
    """Reference SAE encode → BatchTopK → decode.

    Args:
        acts_fp32: activations of shape [..., d_in] (typically
            [seq_len, 4096]), fp32 or bf16. All SAE weights are cast to this
            tensor's dtype and device, so pass fp32 for fp32 math.
        sae_state_dict: loaded from `Goodfire/Evo-2-Layer-26-Mixed`
            via huggingface_hub.hf_hub_download. Must contain `W`
            (used as `acts @ W`, i.e. [d_in, d_sae]) and `b_enc`;
            `b_dec` is optional and defaults to zeros.
        K: BatchTopK budget per token (default 64). The budget is pooled:
            the K * n_tokens largest post-ReLU latents across ALL tokens
            are kept, not exactly K per token.

    Returns: (sparse_features, reconstructed_acts), shaped [..., d_sae] and
        [..., d_in] respectively.
    """
    import math

    import torch

    # Checkpoints saved from torch.compile / DataParallel wrappers carry
    # "_orig_mod." / "module." key prefixes; strip them.
    sae = {k.replace("_orig_mod.", "").replace("module.", ""): v
           for k, v in sae_state_dict.items()}
    W = sae["W"]
    b_enc = sae["b_enc"]
    b_dec = sae.get("b_dec")
    if b_dec is None:
        b_dec = torch.zeros(W.shape[0])

    # Run every matmul in the dtype/device of the incoming activations.
    dtype = acts_fp32.dtype
    device = acts_fp32.device
    W = W.to(device=device, dtype=dtype)
    b_enc = b_enc.to(device=device, dtype=dtype)
    b_dec = b_dec.to(device=device, dtype=dtype)

    # Encoder pre-activations after ReLU: [..., d_sae].
    pre = torch.relu(acts_fp32 @ W + b_enc)

    # BatchTopK: keep the K * n_tokens largest latents over the whole batch.
    # n_tokens counts every leading dimension, so 1-D (single token) and
    # batched inputs work as well as the usual [seq_len, d_in].
    n_tokens = math.prod(pre.shape[:-1])
    flat = pre.flatten()
    # torch.topk raises if k exceeds the number of elements; when the budget
    # is larger than the latent count, "top-k" simply keeps everything.
    numel = min(K * n_tokens, flat.numel())
    top = torch.topk(flat, numel, dim=-1)
    sparse_flat = torch.zeros_like(flat).scatter(-1, top.indices, top.values)
    features = sparse_flat.reshape(pre.shape)

    # Tied decoder: the same W, transposed.
    reconstructed = features @ W.T + b_dec

    return features, reconstructed
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def smoke_test():
    """Forward pass on a 1 kb random-ish DNA string, capture layer 26, run
    SAE encode-decode, report reconstruction stats.

    Returns:
        dict with float entries `mse`, `var`, `explained_variance`, `cosine`
        (mean per-token cosine similarity) and `sparsity` (fraction of
        non-zero SAE features).

    Raises:
        RuntimeError: if `TARGET_LAYER` is not in the module tree or the
            forward hook never captured activations.
    """
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    # Fixed synthetic sequence, repeated and truncated to 1000 bp.
    seq = "ATGAACAACGTACTGAGCGAATTCAGCAATGGCAATCGGGCTAGCTAGCTAGCTGCATGCATGCATGCATGCATGCATGCATGCAT" * 12
    seq = seq[:1000]
    print(f"smoke_test sequence length: {len(seq)} bp")

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    # Same guard as the extraction function, instead of a bare KeyError.
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    cache = {}

    def hook(_, __, out):
        # The hooked module may return a tuple; activations are element 0.
        cache["acts"] = (out[0] if isinstance(out, tuple) else out).detach()

    handle = target_module.register_forward_hook(hook)
    try:
        input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        handle.remove()
    if "acts" not in cache:
        raise RuntimeError(
            f"forward hook on {TARGET_LAYER} did not capture any activations"
        )
    # Drop the batch dimension: [1, seq_len, hidden] -> [seq_len, hidden].
    acts = cache["acts"][0]
    print(f"layer-26 activations: shape={tuple(acts.shape)} dtype={acts.dtype} "
          f"abs_max={acts.abs().max().item():.2f} std={acts.float().std().item():.4f}")

    # Fetch the SAE checkpoint (tensors only, no arbitrary unpickling).
    sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE)
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)

    features, recon = reference_encode_and_reconstruct(acts.float(), sae_sd, K=SAE_K)

    # Reconstruction metrics, all computed in fp32.
    orig = acts.float()
    err = orig - recon
    mse = (err ** 2).mean().item()
    var = orig.var().item()
    # Guard the division in case the activations have ~zero variance.
    explained_variance = 1.0 - mse / max(var, 1e-9)
    cosine_per_token = torch.nn.functional.cosine_similarity(orig, recon, dim=1).mean().item()
    sparsity = (features != 0).float().mean().item()
    # Use the loaded SAE's real dictionary width rather than a hard-coded
    # 32768, so the printed target ratio stays right for other checkpoints.
    d_sae = features.shape[-1]
    print("\nSAE reconstruction:")
    print(f" MSE: {mse:.5f}")
    print(f" variance: {var:.5f}")
    print(f" explained variance: {explained_variance:.4f} (closer to 1.0 is better)")
    print(f" mean per-token cosine: {cosine_per_token:.4f} (closer to 1.0 is better)")
    print(f" feature sparsity: {sparsity:.4f} (k/d_sae = {SAE_K / d_sae:.4f})")
    return {
        "mse": mse,
        "var": var,
        "explained_variance": explained_variance,
        "cosine": cosine_per_token,
        "sparsity": sparsity,
    }
|
|
|
|
@app.local_entrypoint()
def main():
    """Local entrypoint: run `smoke_test` remotely and print its stats as JSON."""
    stats = smoke_test.remote()
    report = json.dumps(stats, indent=2)
    print(report)
|
|
|
|
| |
| |
| |
| |
| |
| |
# Free-text provenance notes kept as a string: the fields of the source JSONL
# records, how each record was passed to `extract_layer26_for_sequence`, and the
# directory layout of the shared HF dataset.
# NOTE(review): the text says results are saved under `{label}/...` and lists
# "negative" as a label value, but the layout section places negatives under
# `MISC/` — confirm which directory name the real pipeline used.
ORIGINAL_PIPELINE_NOTES = """
Source data: targeted JSONL files extracted with scripts/extract_targeted.py
Each JSONL line is one record. Fields:
sequence DNA, forward strand, gene + 2 kb upstream + 2 kb downstream flank
mag_id, locus_tag Prodigal IDs from MGnify master GFF
region_id f"{locus_tag}_{label}" — unique per record
is_positive True for AMR/STRESS/VIRULENCE positives,
False for matched negatives
label "AMR" | "STRESS" | "VIRULENCE" | "negative"
label_class AMRFinderPlus class (e.g. "BETA-LACTAM", "MACROLIDE")
label_subclass AMRFinderPlus subclass
gene_symbol e.g. "blaOXA", "catA"
pct_identity_to_ref AMRFinderPlus protein identity to reference seq
(proxy for memorisation: < 80% suggests novel allele)
paired_with locus_tag of the matched partner (positive ↔ negative)
gene_start, gene_end, strand, contig, ext_start, ext_end
gc_content, cds_in_mobilome, negative_pool_fallback

For each record we ran `extract_layer26_for_sequence(record["sequence"], record)`
and saved the result to {label}/{mag_id}/{region_id}.npz.

Layout on the HF dataset `JG1310/mgnify-evo2-l26-full`:
AMR/{mag_id}/{region_id}.npz — AMR positive
STRESS/{mag_id}/{region_id}.npz — stress-resistance positive
VIRULENCE/{mag_id}/{region_id}.npz — virulence positive
MISC/{mag_id}/{region_id}.npz — matched-CDS negatives
(paired_with field links to the
positive in AMR/, STRESS/, etc.)
"""
|
|