# mgnify-evo2-probes/code/share/evo2_layer26_extraction.py
"""
Evo 2 layer-26 extraction pipeline used to produce the layer-26 npz files
shared on HuggingFace as `JG1310/mgnify-evo2-l26-full` (and the 5-layer source
on Modal volume `mgnify-embeddings-lean`).
This file is written for someone debugging SAE reconstruction error. It lets
you verify:

  1. Which model variant we ran: `evo2_7b_262k`, NOT `evo2_7b`.
  2. Which hook we read from: the whole-block output of block 26
     (`blocks-26`), NOT `blocks.26.mlp.l3`.
  3. How activations were stored: bf16 bit-patterns in a uint16 numpy array.
  4. What's inside each .npz file: schema documented in `LOADER_EXAMPLE`.
  5. Reference SAE encode (BatchTopK): the pattern follows Arc's official
     notebook at notebooks/sparse_autoencoder/sparse_autoencoder.ipynb.

If reconstruction error is bad on the receiver's side but the saved activations
match the residual stream produced by Arc's own example notebook on the same
input, the bug is in the receiver's SAE encode/decode code. The most common
causes are a missing or wrongly scoped BatchTopK, mismatched dtypes in the
matmul, and decoding with W instead of W.T.

A reproducible smoke test is provided at the bottom. Run it on Modal with:

    modal run evo2_layer26_extraction.py::smoke_test
"""
import os
import json
import time
import modal
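# Sketch, not part of the original pipeline: numpy-only conversion between
# fp32 values and the bf16 bit-patterns stored as uint16 (item 3 of the module
# docstring). It is for receivers who want to decode without torch, and relies
# only on bfloat16 being the top 16 bits of an IEEE-754 fp32 value.
# Assumptions: the helper names are hypothetical, and the encoder rounds to
# nearest-even without special-casing NaN. Decoding is exact. Encoding is
# lossless only for values already representable in bf16.
def fp32_to_bf16_bits(x):
    """Round fp32 values to bf16 and return the raw bit-patterns as uint16."""
    import numpy as np
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Round-to-nearest-even on the 16 low bits being dropped (NaN not handled).
    lsb = (bits >> 16) & np.uint32(1)
    rounded = bits + np.uint32(0x7FFF) + lsb
    return (rounded >> 16).astype(np.uint16)


def bf16_bits_to_fp32(u16):
    """Exact decode: put the 16 stored bits in the high half of an fp32 word."""
    import numpy as np
    u16 = np.ascontiguousarray(u16, dtype=np.uint16)
    return (u16.astype(np.uint32) << 16).view(np.float32)

# Usage (path is a placeholder):
#   acts_fp32 = bf16_bits_to_fp32(np.load(path)["layer26_activations_bf16"])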
# =============================================================================
# Constants
# =============================================================================
MODEL_VARIANT = "evo2_7b_262k"  # 262k-context variant: Goodfire's SAE was
                                # trained against this, not the vanilla 7b.
TARGET_LAYER = "blocks-26"      # Whole-block output (residual stream after
                                # block 26). NOT blocks-26-mlp-l3: that is a
                                # sub-module's output and would give
                                # different activations.
HIDDEN = 4096                   # Evo 2 7b residual stream dim.
SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed"
SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt"
SAE_K = 64                      # BatchTopK budget: 64 active features per
                                # token on average, i.e. K * n_tokens kept
                                # across the batch (expansion=8, so
                                # d_sae = 8 * 4096 = 32768).
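# Sketch of a pre-flight check on the SAE checkpoint. It assumes the tied
# layout described in the "three places" notes further down: a single `W` of
# shape [hidden, d_sae], `b_enc` [d_sae], and an optional `b_dec` [hidden]. It
# also assumes the `_orig_mod.` / `module.` key prefixes left by torch.compile
# and DDP. With hidden = 4096 and expansion = 8 it expects d_sae = 32768.
# `check_sae_shapes` is a hypothetical helper. It only reads `.shape`, so it
# works on torch tensors and numpy arrays alike.
def check_sae_shapes(state_dict, hidden=4096, expansion=8):
    """Strip wrapper prefixes and verify tied-SAE parameter shapes."""
    d_sae = hidden * expansion
    sd = {k.replace("_orig_mod.", "").replace("module.", ""): v
          for k, v in state_dict.items()}
    w_shape = tuple(sd["W"].shape)
    if w_shape == (d_sae, hidden):
        # Stored transposed relative to the `acts @ W` encode convention.
        raise ValueError(f"W is {w_shape}; expected ({hidden}, {d_sae}), "
                         "so W and W.T would be swapped in encode/decode")
    if w_shape != (hidden, d_sae):
        raise ValueError(f"unexpected W shape {w_shape}")
    if tuple(sd["b_enc"].shape) != (d_sae,):
        raise ValueError(f"unexpected b_enc shape {tuple(sd['b_enc'].shape)}")
    if "b_dec" in sd and tuple(sd["b_dec"].shape) != (hidden,):
        raise ValueError(f"unexpected b_dec shape {tuple(sd['b_dec'].shape)}")
    return sd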
# =============================================================================
# Modal image (Arc Institute's official Evo 2 Dockerfile, translated to Modal)
# =============================================================================
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,  # the NGC base image already ships its own Python
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")  # pulls flash-attn + vortex-model + huggingface_hub
)
app = modal.App("evo2-layer26-extraction-share")
weights_vol = modal.Volume.from_name("evo2-7b-weights", create_if_missing=True)
# =============================================================================
# Helper: walk the StripedHyena module tree to find a hook target by name.
# StripedHyena's nesting structure means `blocks.26` is reached via
# `evo2.model.blocks[26]`, but its child names are ('mixer', 'mlp', etc.).
# We use `named_children()` and join with '-' so that:
# blocks-26 = block 26's container forward output
# blocks-26-mlp-l3 = block 26's MLP last-layer linear output
# =============================================================================
def build_module_dict(model):
    module_dict = {}

    def recurse(m, prefix=""):
        for name, child in m.named_children():
            module_dict[prefix + name] = child
            recurse(child, prefix + name + "-")

    recurse(model)
    return module_dict
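# Torch-free sketch of the names build_module_dict produces. It is useful for
# checking a hook target without loading the 7B model. `hyphen_join_names` is
# a hypothetical helper that walks a nested {name: subtree} dict. Only the
# '-'-joining rule mirrors build_module_dict. Children of an nn.ModuleList are
# named by their string index ("0", "1", ...), which is why block 26 shows up
# as `blocks-26`.
def hyphen_join_names(tree, prefix=""):
    """Return the '-'-joined path of every node in a nested dict, parents first."""
    names = []
    for name, child in tree.items():
        path = prefix + name
        names.append(path)
        names.extend(hyphen_join_names(child, path + "-"))
    return names

# Example with a hypothetical, heavily trimmed tree (real child names may differ):
#   hyphen_join_names({"blocks": {"26": {"mlp": {"l3": {}}}}})
#   -> ['blocks', 'blocks-26', 'blocks-26-mlp', 'blocks-26-mlp-l3']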
# =============================================================================
# The actual extraction function. This is what wrote each per-region npz.
#
# Important details to verify against your own pipeline:
# - Forward pass receives the full sequence (gene + 2 kb upstream + 2 kb
#   downstream flank). Evo 2 is causal, so the activation at each position
#   depends only on tokens to its left in the orientation fed (the genomic
#   forward strand). The left-hand flank supplies context for the gene, while
#   the right-hand flank only adds positions of its own.
# - Hook fires on `blocks-26.forward` and we capture `out[0]` if the output
#   is a tuple, else `out`. For StripedHyena blocks the first tuple element
#   is the residual-stream hidden state passed to block 27. This is what
#   Goodfire's SAE was trained on.
# - The captured tensor is bf16 on GPU. We keep it in bf16 and reinterpret
#   the bit-pattern as uint16 because numpy does not support bf16 natively.
#   This is a *bit-exact* reinterpretation, NOT a precision-losing cast.
#   Decode with `torch.from_numpy(arr).view(torch.bfloat16)`.
# - We do NOT compress (no gzip). Random-looking bf16 floats compress poorly,
#   and gzip was the dominant cost during a previous failed run.
# =============================================================================
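# Toy illustration of the context bullet above. In a causal convolution,
# output t depends only on inputs 0..t. Changing the right-hand flank (in the
# fed orientation) therefore cannot change the activations at gene positions.
# `causal_conv1d` is a hypothetical stand-in for Hyena's causal long
# convolution, not Evo 2 code.
def causal_conv1d(x, kernel):
    """y[t] = sum_j kernel[j] * x[t - j] over j >= 0, with x[t < 0] = 0."""
    import numpy as np
    x = np.asarray(x, dtype=np.float64)
    kernel = np.asarray(kernel, dtype=np.float64)
    # 'full' convolution computes exactly that sum. Keeping the first len(x)
    # outputs drops the tail that extends past the end of the input.
    return np.convolve(x, kernel, mode="full")[: len(x)]

# Usage: perturbing x[40:] leaves causal_conv1d(x, k)[:40] unchanged.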
@app.function(
image=image,
gpu="H100",
volumes={"/root/.cache/huggingface": weights_vol},
secrets=[modal.Secret.from_name("huggingface")],
timeout=3600,
)
def extract_layer26_for_sequence(sequence: str, region_metadata: dict) -> dict:
    """
    Run Evo 2 forward on `sequence`, capture layer-26 residual stream, return
    it as a bf16-as-uint16 numpy bit-pattern plus the metadata in JSON form.

    sequence:        DNA string, forward strand (e.g. "ATGAA...").
                     We do not reverse-complement minus-strand genes; we feed
                     the genomic forward strand as-is. Goodfire's reference
                     notebook also feeds raw forward-strand sequences.
    region_metadata: arbitrary dict (locus_tag, gene coords, label class,
                     etc.), passed through into the saved .npz so each file
                     is self-describing.
    """
    import numpy as np
    import torch
    from evo2 import Evo2

    # ----- load the 262k-context Evo 2 7B variant (one-time per container) ---
    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    # ----- register the hook -------------------------------------------------
    cache: dict = {}

    def hook_fn(_module, _inp, out):
        # StripedHyena blocks return a tuple where index 0 is the residual-
        # stream hidden state. Some sub-modules return just a tensor.
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()  # detach so we don't keep autograd graph

    handle = target_module.register_forward_hook(hook_fn)
    try:
        # ----- forward pass --------------------------------------------------
        # Tokenizer: each nucleotide gets one token id (Evo 2's tokenizer is
        # byte-level on ACGTN). Sequence length = len(sequence).
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(sequence),
            dtype=torch.long,
        ).unsqueeze(0).to(device)  # add batch dim, move to GPU
        with torch.no_grad():
            evo2.model(input_ids)
        # No need for output logits: we only care about the cached activation.
        acts_bf16 = cache["acts"][0]  # squeeze batch dim → [seq_len, HIDDEN]
        seq_len, hidden = acts_bf16.shape
        assert hidden == HIDDEN, f"unexpected hidden dim {hidden}"
    finally:
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    # ----- bf16 → uint16 bit-pattern (lossless) -----------------------------
    # `view(torch.uint16)` is a zero-copy reinterpretation of the same memory:
    # the bit-pattern of a bf16 float is read as the bit-pattern of a uint16.
    # No precision loss. Decode on the receiving side with the inverse.
    # `.to(torch.bfloat16)` is a no-op when the hook captured bf16 (the normal
    # case); if the model ever ran in fp32 it would be a lossy cast.
    acts_uint16_np = acts_bf16.to(torch.bfloat16).view(torch.uint16).cpu().numpy()
    return {
        "layer26_activations_bf16": acts_uint16_np,  # uint16 [seq_len, 4096]
        "layer26_dtype": "bfloat16",                 # marker for decode
        "source_layer_index": 26,
        "source_layer_name": TARGET_LAYER,
        "seq_len": int(seq_len),
        "hidden_size": int(hidden),
        "model_name": MODEL_VARIANT,
        "metadata_json": json.dumps(region_metadata),
    }
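# Minimal sketch of writing one result to disk. It assumes the
# {label}/{mag_id}/{region_id}.npz layout from ORIGINAL_PIPELINE_NOTES and
# uncompressed np.savez (gzip was deliberately avoided, per the comments
# above). `save_region_npz` is hypothetical; the original writer is not part
# of this file. Size check: at 2 bytes per value, an illustrative 5 kb input
# (1 kb gene plus 2 kb flank on each side) is 5000 × 4096 × 2 = 40,960,000
# bytes, about 41 MB per file.
def save_region_npz(result, out_root, label, mag_id, region_id):
    """Write the dict returned by extract_layer26_for_sequence as an .npz."""
    import os
    import numpy as np
    out_dir = os.path.join(out_root, label, mag_id)
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{region_id}.npz")
    # np.savez, not np.savez_compressed: each key becomes a raw .npy member.
    # str/int values are stored as 0-d arrays, loadable with allow_pickle=False.
    np.savez(path, **result)
    return path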
# =============================================================================
# Reference loader — exactly how to read one of our shared npz files back.
# This is what receivers should do; if they don't get the right shape/dtype,
# the bug is here, not upstream.
# =============================================================================
LOADER_EXAMPLE = '''
import numpy as np
import json
import torch
d = np.load("AMR/MGYG.../REGION_AMR.npz", allow_pickle=False)
# Schema (every shared file has these keys):
#   layer26_activations_bf16   uint16 array, shape [seq_len, 4096]
#                              (bit-pattern of bf16 stored as uint16)
#   layer26_dtype              literal string "bfloat16"
#   source_layer_index         int 26
#   source_layer_name          literal string "blocks-26"
#   seq_len, hidden_size       ints (match the array shape)
#   model_name                 literal string "evo2_7b_262k"
#   metadata_json              JSON-encoded dict with locus_tag, gene_symbol,
#                              label_class, label_subclass, gene_start/end,
#                              paired_with, etc.
# Decode bit-pattern to bf16, then to fp32 for downstream math:
acts_bf16 = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16)
acts_fp32 = acts_bf16.float() # shape [seq_len, 4096]
# Pull the per-region metadata:
meta = json.loads(str(d["metadata_json"]))
print(meta["gene_symbol"], meta["label_class"], meta["label_subclass"])
'''
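# Sketch of a schema check to run right after np.load. It is derived only from
# the key list in LOADER_EXAMPLE. `check_npz_schema` is hypothetical. It
# accepts an NpzFile or a plain dict of arrays and returns the decoded
# metadata dict.
def check_npz_schema(d, hidden=4096, layer_name="blocks-26",
                     model_name="evo2_7b_262k"):
    import json
    import numpy as np
    required = ["layer26_activations_bf16", "layer26_dtype", "source_layer_index",
                "source_layer_name", "seq_len", "hidden_size", "model_name",
                "metadata_json"]
    missing = [k for k in required if k not in d]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    acts = d["layer26_activations_bf16"]
    if acts.dtype != np.uint16:
        raise TypeError(f"expected uint16 bf16 bit-patterns, got {acts.dtype}")
    seq_len, hid = int(d["seq_len"]), int(d["hidden_size"])
    if hid != hidden or tuple(acts.shape) != (seq_len, hid):
        raise ValueError(f"array shape {acts.shape} vs seq_len={seq_len}, "
                         f"hidden_size={hid} (expected hidden {hidden})")
    if int(d["source_layer_index"]) != 26:
        raise ValueError(f"source_layer_index={int(d['source_layer_index'])}")
    expected = {"layer26_dtype": "bfloat16", "source_layer_name": layer_name,
                "model_name": model_name}
    for key, want in expected.items():
        if str(d[key]) != want:
            raise ValueError(f"{key}={str(d[key])!r}, expected {want!r}")
    return json.loads(str(d["metadata_json"]))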
# =============================================================================
# Reference SAE encode-and-decode that produced sane reconstruction in our
# CRISPR sanity test (5/57 of Goodfire's published features fired strongly on
# E. coli K12 CRISPR arrays). Use this to compare your own SAE handling.
#
# THE THREE PLACES WHERE PEOPLE GET THIS WRONG:
#
# 1. dtype: cast `W`/`b_enc`/`b_dec` AND `acts` to the SAME dtype (bf16 OR
#    fp32, but consistent) before the matmul. A plain mixed-dtype matmul
#    raises an error in PyTorch. Under autocast it silently runs in the
#    autocast dtype, which can change which features survive the top-k.
#
# 2. BatchTopK is *batch-wide*, not per-token. The top-k is computed across
#    the FLATTENED (seq_len * d_sae) tensor with k = K * seq_len. Per-token
#    `topk(k=64)` keeps the same total count but forces exactly 64 per token.
#    BatchTopK instead lets some tokens use more and others fewer. Taking
#    `topk(k=64)` over the flattened tensor (forgetting the `* seq_len`)
#    would be ~seq_len× sparser.
#
# 3. Reconstruction uses `W.T` (the transpose), not `W`. Goodfire's SAE has
#    tied encoder/decoder weights, so the state dict holds a single `W`
#    matrix of shape [4096, 32768]:
#      encode = ReLU(acts @ W + b_enc), followed by BatchTopK
#      decode = features @ W.T + b_dec
# =============================================================================
def reference_encode_and_reconstruct(acts_fp32, sae_state_dict, K=SAE_K):
    """Reference SAE encode → BatchTopK → decode.

    acts_fp32:      [seq_len, 4096] activations (fp32 or bf16)
    sae_state_dict: loaded from `Goodfire/Evo-2-Layer-26-Mixed`
                    via huggingface_hub.hf_hub_download
    K:              BatchTopK budget as an average number of active
                    features per token (default 64)
    Returns: (sparse_features, reconstructed_acts)
    """
    import torch

    # The official Goodfire checkpoint was saved with torch.compile + DDP
    # prefixes; strip them when loading:
    sae = {k.replace("_orig_mod.", "").replace("module.", ""): v
           for k, v in sae_state_dict.items()}
    W = sae["W"]                                       # [4096, 32768]
    b_enc = sae["b_enc"]                               # [32768]
    b_dec = sae.get("b_dec", torch.zeros(W.shape[0]))  # [4096]; some checkpoints omit it

    # Match dtypes carefully (see "place 1" above):
    dtype = acts_fp32.dtype
    device = acts_fp32.device
    W = W.to(device=device, dtype=dtype)
    b_enc = b_enc.to(device=device, dtype=dtype)
    b_dec = b_dec.to(device=device, dtype=dtype)

    # ----- encode (same as Arc's notebook) -----------------------------------
    pre = torch.relu(acts_fp32 @ W + b_enc)  # [seq_len, 32768]

    # BatchTopK across the WHOLE [seq_len * d_sae] flattened tensor (place 2):
    seq_len, d_sae = pre.shape
    flat = pre.flatten()
    numel = K * seq_len  # total non-zero budget
    top = torch.topk(flat, numel, dim=-1)
    sparse_flat = torch.zeros_like(flat).scatter(-1, top.indices, top.values)
    features = sparse_flat.reshape(pre.shape)  # [seq_len, 32768], sparse

    # ----- decode using W.T (place 3) ----------------------------------------
    reconstructed = features @ W.T + b_dec  # [seq_len, 4096]
    return features, reconstructed
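# Toy numpy comparison of three top-k scopes, to make "place 2" concrete.
# `topk_mask` is hypothetical and is not Goodfire's implementation. It uses
# plain top-k by value with no tie-breaking, and it marks selected entries
# even if they are zero after ReLU.
def topk_mask(pre, K, mode):
    """Boolean mask of kept entries for `pre` of shape [seq_len, d_sae].

    mode="batch":     top (K * seq_len) over the flattened tensor (BatchTopK)
    mode="per_token": top K within each row
    mode="flat_k":    top K over the flattened tensor (the common bug)
    """
    import numpy as np
    seq_len, d_sae = pre.shape
    mask = np.zeros(pre.shape, dtype=bool)
    if mode == "per_token":
        idx = np.argpartition(-pre, K - 1, axis=1)[:, :K]
        np.put_along_axis(mask, idx, True, axis=1)
        return mask
    budget = K * seq_len if mode == "batch" else K
    flat_idx = np.argpartition(-pre.ravel(), budget - 1)[:budget]
    np.put(mask, flat_idx, True)  # flat indices into the C-contiguous mask
    return mask

# With pre.shape == (4, 32) and K = 3, "batch" and "per_token" both keep 12
# entries, but "batch" may hand all 12 to a single dominant token. "flat_k"
# keeps only 3, i.e. seq_len times fewer.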
# =============================================================================
# Standalone smoke test you can run to verify the full pipeline end-to-end
# on a known input. If this gives weird reconstruction, the issue is upstream;
# if reconstruction is clean here but bad in your pipeline, it's downstream.
#
# Usage:
# modal run evo2_layer26_extraction.py::smoke_test
# =============================================================================
@app.function(
image=image,
gpu="H100",
volumes={"/root/.cache/huggingface": weights_vol},
secrets=[modal.Secret.from_name("huggingface")],
timeout=1800,
)
def smoke_test():
    """Forward pass on a 1 kb synthetic (repetitive) DNA string, capture
    layer 26, run SAE encode-decode, report reconstruction stats."""
    import numpy as np
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    # 1 kb synthetic sequence (a fixed motif repeated, then truncated); same
    # scale as Goodfire's chr17 example.
    seq = "ATGAACAACGTACTGAGCGAATTCAGCAATGGCAATCGGGCTAGCTAGCTAGCTGCATGCATGCATGCATGCATGCATGCATGCAT" * 12
    seq = seq[:1000]
    print(f"smoke_test sequence length: {len(seq)} bp")

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    target_module = module_dict[TARGET_LAYER]

    cache = {}

    def hook(_, __, out):
        cache["acts"] = (out[0] if isinstance(out, tuple) else out).detach()

    handle = target_module.register_forward_hook(hook)
    try:
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(seq), dtype=torch.long
        ).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        handle.remove()

    acts = cache["acts"][0]  # [seq_len, 4096] bf16
    print(f"layer-26 activations: shape={tuple(acts.shape)} dtype={acts.dtype} "
          f"abs_max={acts.abs().max().item():.2f} std={acts.float().std().item():.4f}")

    # Load SAE and run encode + decode
    sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE)
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)
    features, recon = reference_encode_and_reconstruct(acts.float(), sae_sd, K=SAE_K)

    # ---- reconstruction metrics ---------------------------------------------
    orig = acts.float()
    err = orig - recon
    mse = (err ** 2).mean().item()
    # Per-dimension variance (each hidden dim centred on its own mean over
    # tokens), averaged over dims. orig.var() over the flattened tensor would
    # also count the spread of per-dim means and inflate explained variance.
    var = ((orig - orig.mean(dim=0)) ** 2).mean().item()
    explained_variance = 1.0 - mse / max(var, 1e-9)
    cosine_per_token = torch.nn.functional.cosine_similarity(orig, recon, dim=1).mean().item()
    sparsity = (features != 0).float().mean().item()
    print("\nSAE reconstruction:")
    print(f"  MSE:                   {mse:.5f}")
    print(f"  per-dim variance:      {var:.5f}")
    print(f"  explained variance:    {explained_variance:.4f}  (closer to 1.0 is better)")
    print(f"  mean per-token cosine: {cosine_per_token:.4f}  (closer to 1.0 is better)")
    print(f"  feature sparsity:      {sparsity:.4f}  (k/d_sae = {SAE_K/32768:.4f})")
    return {
        "mse": mse,
        "var": var,
        "explained_variance": explained_variance,
        "cosine": cosine_per_token,
        "sparsity": sparsity,
    }
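# Numpy sketch of two ways to normalise reconstruction MSE. Dividing by the
# variance of the whole flattened tensor (one global mean) also counts the
# spread of per-dimension means as "variance". That inflates explained
# variance when some residual-stream dimensions carry large constant offsets.
# Normalising by per-dimension variance (each dimension centred on its own
# mean over tokens) avoids this. `explained_variance_np` is a hypothetical
# helper for illustration.
def explained_variance_np(orig, recon, per_dim=True):
    """1 - MSE / variance, with variance either per-dimension or global."""
    import numpy as np
    orig = np.asarray(orig, dtype=np.float64)
    recon = np.asarray(recon, dtype=np.float64)
    mse = np.mean((orig - recon) ** 2)
    var = orig.var(axis=0).mean() if per_dim else orig.var()
    return 1.0 - mse / max(var, 1e-12)

# A "reconstruction" that only predicts each dimension's mean scores ~0 with
# per_dim=True. With per_dim=False it can score close to 1 if one dimension
# has a large offset.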
@app.local_entrypoint()
def main():
    """Run the smoke test and dump the reconstruction stats."""
    r = smoke_test.remote()
    print(json.dumps(r, indent=2))
# =============================================================================
# Quick reference: the original orchestrator used to extract every region.
# Each region's record was a dict with keys (`sequence`, `mag_id`,
# `locus_tag`, `region_id`, `is_positive`, `label`, `label_class`, etc.) — the
# same dict is JSON-encoded into `metadata_json` in each saved npz.
# =============================================================================
ORIGINAL_PIPELINE_NOTES = """
Source data: targeted JSONL files extracted with scripts/extract_targeted.py.
Each JSONL line is one record. Fields:

  sequence              DNA, forward strand, gene + 2 kb upstream + 2 kb
                        downstream flank
  mag_id, locus_tag     Prodigal IDs from MGnify master GFF
  region_id             f"{locus_tag}_{label}", unique per record
  is_positive           True for AMR/STRESS/VIRULENCE positives,
                        False for matched negatives
  label                 "AMR" | "STRESS" | "VIRULENCE" | "negative"
  label_class           AMRFinderPlus class (e.g. "BETA-LACTAM", "MACROLIDE")
  label_subclass        AMRFinderPlus subclass
  gene_symbol           e.g. "blaOXA", "catA"
  pct_identity_to_ref   AMRFinderPlus protein identity to reference seq
                        (proxy for memorisation: < 80% suggests novel allele)
  paired_with           locus_tag of the matched partner (positive ↔ negative)
  gene_start, gene_end, strand, contig, ext_start, ext_end
  gc_content, cds_in_mobilome, negative_pool_fallback

For each record we ran `extract_layer26_for_sequence(record["sequence"], record)`
and saved the result to {label}/{mag_id}/{region_id}.npz.

Layout on the HF dataset `JG1310/mgnify-evo2-l26-full`:

  AMR/{mag_id}/{region_id}.npz         AMR positive
  STRESS/{mag_id}/{region_id}.npz      stress-resistance positive
  VIRULENCE/{mag_id}/{region_id}.npz   virulence positive
  MISC/{mag_id}/{region_id}.npz        matched-CDS negatives (the paired_with
                                       field links to the positive in AMR/,
                                       STRESS/, etc.)
"""