"""
Evo 2 layer-26 extraction pipeline used to produce the layer-26 .npz files
shared on HuggingFace as `JG1310/mgnify-evo2-l26-full` (and the 5-layer source
files on the Modal volume `mgnify-embeddings-lean`).

This file is written for someone debugging SAE reconstruction error, so they
can verify:
  1. Which model variant we ran: `evo2_7b_262k`, NOT `evo2_7b`.
  2. Which hook we read from: the whole-block output of block 26
     (`blocks-26`), NOT the MLP sub-module `blocks.26.mlp.l3`
     (`blocks-26-mlp-l3` in this file's naming).
  3. How activations were stored: the bf16 bit-pattern in a uint16 numpy array.
  4. What is inside each .npz file: the schema is documented in `LOADER_EXAMPLE`.
  5. The reference SAE encode (BatchTopK): the pattern follows Arc's official
     notebook at notebooks/sparse_autoencoder/sparse_autoencoder.ipynb.

If reconstruction error is bad on the receiver's side but the saved activations
match the residual stream produced by Arc's own example notebook on the same
input, the bug is in the receiver's SAE encode/decode code. The most common
causes are a missing or per-token (instead of batch-wide) BatchTopK step,
mismatched dtypes in the matmul, and decoding with W instead of W.T.

A reproducible smoke test is provided at the bottom; run it on Modal with

    modal run evo2_layer26_extraction.py::smoke_test
"""
import os
import json
import time
import modal
# =============================================================================
# Constants
# =============================================================================
MODEL_VARIANT = "evo2_7b_262k"  # 262k-context variant: Goodfire's SAE was
                                # trained against this, not the vanilla 7b.
TARGET_LAYER = "blocks-26"      # Whole-block output (residual stream after
                                # block 26). NOT blocks-26-mlp-l3: that is a
                                # sub-module's output and would give
                                # different activations.
HIDDEN = 4096                   # Evo 2 7b residual-stream dim.
SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed"
SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt"
SAE_K = 64                      # BatchTopK budget: k active features per token,
                                # on average across the batch
                                # (k=64, expansion=8 -> d_sae = 4096 * 8 = 32768).
# =============================================================================
# Modal image (Arc Institute's official Evo 2 Dockerfile, translated to Modal)
# =============================================================================
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,  # the NGC base image already ships its own Python 3
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")  # pulls flash-attn + vortex-model + huggingface_hub
)
app = modal.App("evo2-layer26-extraction-share")
weights_vol = modal.Volume.from_name("evo2-7b-weights", create_if_missing=True)
# =============================================================================
# Helper: walk the StripedHyena module tree to find a hook target by name.
# StripedHyena's nesting structure means `blocks.26` is reached via
# `evo2.model.blocks[26]`, but its child names are ('mixer', 'mlp', etc.).
# We use `named_children()` and join with '-' so that:
# blocks-26 = block 26's container forward output
# blocks-26-mlp-l3 = block 26's MLP last-layer linear output
# =============================================================================
def build_module_dict(model):
    module_dict = {}

    def recurse(m, prefix=""):
        for name, child in m.named_children():
            module_dict[prefix + name] = child
            recurse(child, prefix + name + "-")

    recurse(model)
    return module_dict
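
# Illustrative sketch (hypothetical helpers, not used by the pipeline): the
# same recursion as `build_module_dict`, run on a stand-in class that only
# implements `named_children()`. It shows that the '-'-joined key for block 26
# ("blocks-26") names the whole block, while the MLP's last linear layer gets
# its own key ("blocks-26-mlp-l3"). With sep="." the keys follow PyTorch's
# dotted naming. The toy tree shape is an assumption, not the real
# StripedHyena layout.
class _ToyModule:
    """Hypothetical stand-in for torch.nn.Module (children only)."""

    def __init__(self, **children):
        self._children = children

    def named_children(self):
        return iter(self._children.items())


def _flatten_child_names(model, sep="-"):
    """Return every nested child name, joined with `sep`, in visit order."""
    names = []

    def recurse(m, prefix=""):
        for name, child in m.named_children():
            names.append(prefix + name)
            recurse(child, prefix + name + sep)

    recurse(model)
    return names


# Usage:
#   toy = _ToyModule(blocks=_ToyModule(**{"26": _ToyModule(
#       mixer=_ToyModule(), mlp=_ToyModule(l3=_ToyModule()))}))
#   _flatten_child_names(toy)
#   # ['blocks', 'blocks-26', 'blocks-26-mixer', 'blocks-26-mlp', 'blocks-26-mlp-l3']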
# =============================================================================
# The actual extraction function. This is what wrote each per-region npz.
#
# Important details to verify against your own pipeline:
# - Forward pass receives the full sequence (gene + 2 kb upstream + 2 kb
# downstream flank) so causal Hyena convolution gets context.
# - The hook fires on `blocks-26.forward`; we capture `out[0]` if the output
#   is a tuple, else `out`. For StripedHyena blocks the first tuple element
#   is the residual-stream hidden state passed to block 27, which is what
#   Goodfire's SAE was trained on.
# - The captured tensor is bf16 on GPU. We keep it in bf16 and reinterpret
#   the bit-pattern as uint16 because numpy has no native bf16 dtype.
#   This is a *bit-exact* reinterpretation, NOT a precision-losing cast.
#   Decode with `torch.from_numpy(arr).view(torch.bfloat16)`. Older PyTorch
#   releases reject uint16 arrays in `torch.from_numpy`; calling
#   `arr.view(np.int16)` first keeps the bits identical and works everywhere.
# - We do NOT compress (no gzip) β random-looking bf16 floats compress poorly
# and gzip was the dominant cost during a previous failed run.
# =============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def extract_layer26_for_sequence(sequence: str, region_metadata: dict) -> dict:
    """
    Run Evo 2 forward on `sequence`, capture the layer-26 residual stream and
    return it as a bf16-as-uint16 numpy bit-pattern plus the metadata as JSON.

    sequence:        DNA string, forward strand (e.g. "ATGAA...").
                     We do not reverse-complement minus-strand genes; we feed
                     the genomic forward strand as-is. Goodfire's reference
                     notebook also feeds raw forward-strand sequences.
    region_metadata: arbitrary dict (locus_tag, gene coords, label class,
                     etc.), passed through into the saved .npz so each file
                     is self-describing.
    """
    import numpy as np
    import torch
    from evo2 import Evo2

    # ----- load the 262k-context Evo 2 7B variant --------------------------
    # Note: this runs on every call, not once per container, so bulk
    # extraction pays the model-load cost for each region.
    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device

    module_dict = build_module_dict(evo2.model)
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    # ----- register the hook -------------------------------------------------
    cache: dict = {}

    def hook_fn(_module, _inp, out):
        # StripedHyena blocks return a tuple whose index 0 is the residual-
        # stream hidden state. Some sub-modules return just a tensor.
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()  # detach so we don't keep an autograd graph

    handle = target_module.register_forward_hook(hook_fn)
    try:
        # ----- forward pass --------------------------------------------------
        # Tokenizer: each nucleotide gets one token id (Evo 2's tokenizer is
        # byte-level, one token per character), so the token count equals
        # len(sequence).
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(sequence),
            dtype=torch.long,
        ).unsqueeze(0).to(device)  # add batch dim, move to GPU
        with torch.no_grad():
            evo2.model(input_ids)
        # The output logits are discarded; we only need the cached activation.
        acts_bf16 = cache["acts"][0]  # drop batch dim -> [seq_len, HIDDEN]
        seq_len, hidden = acts_bf16.shape
        assert hidden == HIDDEN, f"unexpected hidden dim {hidden}"
    finally:
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    # ----- bf16 -> uint16 bit-pattern (lossless) -----------------------------
    # `view(torch.uint16)` is a zero-copy reinterpretation of the same memory:
    # the bit-pattern of each bf16 value is read as a uint16. There is no
    # precision loss as long as the hook output is already bf16 (the `.to()`
    # is then a no-op). Decode on the receiving side with the inverse view.
    acts_uint16_np = acts_bf16.to(torch.bfloat16).view(torch.uint16).cpu().numpy()

    return {
        "layer26_activations_bf16": acts_uint16_np,  # uint16 [seq_len, 4096]
        "layer26_dtype": "bfloat16",                 # marker for decode
        "source_layer_index": 26,
        "source_layer_name": TARGET_LAYER,
        "seq_len": int(seq_len),
        "hidden_size": int(hidden),
        "model_name": MODEL_VARIANT,
        "metadata_json": json.dumps(region_metadata),
    }
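
# Illustrative sketch (hypothetical helpers, stdlib only, not used by the
# pipeline): what the bf16 <-> uint16 step does at the bit level. A bf16
# value is the top 16 bits of an IEEE float32, so storing those 16 bits in a
# uint16 and shifting them back up is exactly lossless. Rounding happens only
# when an fp32 value is first converted to bf16. Here that conversion uses
# round-to-nearest-even and ignores NaN payloads (a simplifying assumption,
# not a copy of PyTorch's kernel).
import struct


def _fp32_to_bf16_bits(x: float) -> int:
    """Round a float32 to bf16 (nearest-even) and return its 16 raw bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even result
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def _bf16_bits_to_fp32(bits: int) -> float:
    """Decode 16 raw bf16 bits (as stored in the uint16 array) to a float."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


# Usage:
#   _fp32_to_bf16_bits(1.0)                              # 0x3F80
#   _bf16_bits_to_fp32(0x3F80)                           # 1.0
#   _bf16_bits_to_fp32(_fp32_to_bf16_bits(1.01171875))   # 1.015625 (rounded)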
# =============================================================================
# Reference loader: exactly how to read one of our shared npz files back.
# This is what receivers should do; if they don't get the right shape/dtype,
# the bug is here, not upstream.
# =============================================================================
LOADER_EXAMPLE = '''
import json

import numpy as np
import torch

d = np.load("AMR/MGYG.../REGION_AMR.npz", allow_pickle=False)

# Schema (every shared file has these keys):
#   layer26_activations_bf16   uint16 array, shape [seq_len, 4096]
#                              (bit-pattern of bf16 stored as uint16)
#   layer26_dtype              literal string "bfloat16"
#   source_layer_index         int 26
#   source_layer_name          literal string "blocks-26"
#   seq_len, hidden_size       ints (match the array shape)
#   model_name                 literal string "evo2_7b_262k"
#   metadata_json              JSON-encoded dict with locus_tag, gene_symbol,
#                              label_class, label_subclass, gene_start/end,
#                              paired_with, etc.

# Decode the bit-pattern to bf16, then to fp32 for downstream math.
# Viewing as int16 first keeps the bits identical and also works on PyTorch
# versions whose torch.from_numpy rejects uint16 arrays.
raw = d["layer26_activations_bf16"].view(np.int16)
acts_bf16 = torch.from_numpy(raw).view(torch.bfloat16)
acts_fp32 = acts_bf16.float()  # shape [seq_len, 4096]

# Pull the per-region metadata:
meta = json.loads(str(d["metadata_json"]))
print(meta["gene_symbol"], meta["label_class"], meta["label_subclass"])
'''
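
# Illustrative sketch (hypothetical helper, not the original orchestrator):
# writing one result dict from `extract_layer26_for_sequence` to disk so that
# `LOADER_EXAMPLE` can read it back. `np.savez` (not `np.savez_compressed`)
# matches the "no gzip" choice described above. Strings and ints become 0-d
# arrays, which load fine with allow_pickle=False. The helper name and the
# directory handling are assumptions.
import os


def _save_region_npz(result: dict, out_path: str) -> None:
    """Write `result` as an uncompressed .npz, creating parent dirs."""
    import numpy as np  # imported lazily, like in the Modal functions above

    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.savez(out_path, **result)


# Usage (sketch of one orchestrator step):
#   out = extract_layer26_for_sequence.remote(record["sequence"], record)
#   _save_region_npz(out, out_path)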
# =============================================================================
# Reference SAE encode-and-decode that produced sane reconstruction in our
# CRISPR sanity test (5/57 of Goodfire's published features fired strongly on
# E. coli K12 CRISPR arrays). Use this to compare your own SAE handling.
#
# THE THREE PLACES WHERE PEOPLE GET THIS WRONG:
#
# 1. dtype: cast `W`/`b_enc`/`b_dec` AND `acts` to the SAME dtype (bf16 OR
#    fp32, but consistent) before the matmul. Outside autocast, PyTorch
#    raises an error on mixed-dtype matmuls; under autocast they silently
#    run in reduced precision.
#
# 2. BatchTopK is *batch-wide*, not per-token. The top-k is taken over the
#    FLATTENED [seq_len * d_sae] tensor with k = K * seq_len. Calling
#    `topk(k=64)` on the flattened tensor keeps only 64 non-zeros for the
#    whole sequence (~seq_len x sparser than intended). Per-token
#    `topk(k=64)` keeps the same total budget but spreads it evenly across
#    tokens, which differs from the batch-wide selection used here.
#
# 3. Reconstruction uses `W.T` (the transpose), not `W`. Goodfire's SAE has
#    tied encoder/decoder weights, i.e. a single `W` matrix in the state dict:
#      encode = BatchTopK(ReLU(acts @ W + b_enc))
#      decode = features @ W.T + b_dec
# =============================================================================
def reference_encode_and_reconstruct(acts_fp32, sae_state_dict, K=SAE_K):
    """Reference SAE encode -> BatchTopK -> decode.

    acts_fp32:      [seq_len, 4096] activations (fp32 or bf16)
    sae_state_dict: loaded from `Goodfire/Evo-2-Layer-26-Mixed`
                    via huggingface_hub.hf_hub_download
    K:              BatchTopK budget per token (default 64)

    Returns: (sparse_features, reconstructed_acts)
    """
    import torch

    # The official Goodfire checkpoint was saved with torch.compile + DDP
    # prefixes; strip them when loading:
    sae = {k.replace("_orig_mod.", "").replace("module.", ""): v
           for k, v in sae_state_dict.items()}
    W = sae["W"]                                        # [4096, 32768]
    b_enc = sae["b_enc"]                                # [32768]
    b_dec = sae.get("b_dec", torch.zeros(W.shape[0]))   # [4096]; some checkpoints omit it

    # Match dtypes carefully (see "place 1" above):
    dtype = acts_fp32.dtype
    device = acts_fp32.device
    W = W.to(device=device, dtype=dtype)
    b_enc = b_enc.to(device=device, dtype=dtype)
    b_dec = b_dec.to(device=device, dtype=dtype)

    # ----- encode (same as Arc's notebook) -----------------------------------
    pre = torch.relu(acts_fp32 @ W + b_enc)             # [seq_len, 32768]

    # BatchTopK across the WHOLE flattened [seq_len * d_sae] tensor (place 2):
    seq_len, d_sae = pre.shape
    flat = pre.flatten()
    numel = K * seq_len                                 # total non-zero budget
    top = torch.topk(flat, numel, dim=-1)
    sparse_flat = torch.zeros_like(flat).scatter(-1, top.indices, top.values)
    features = sparse_flat.reshape(pre.shape)           # [seq_len, 32768], sparse

    # ----- decode using W.T (place 3) ----------------------------------------
    reconstructed = features @ W.T + b_dec              # [seq_len, 4096]
    return features, reconstructed
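
# Illustrative sketch (hypothetical helpers, numpy instead of torch, not used
# by the pipeline): the difference behind "place 2". Batch-wide BatchTopK and
# per-token top-k keep the same TOTAL number of non-zeros (K * seq_len), but
# batch-wide selection lets strongly activated tokens take more of the budget
# than weak ones. The per-token version is shown only for contrast.
def _batch_topk_np(pre, k_per_token):
    """Keep the k_per_token * n_tokens largest entries of `pre` overall."""
    import numpy as np

    flat = pre.ravel()
    budget = k_per_token * pre.shape[0]
    keep = np.argpartition(flat, -budget)[-budget:]  # indices of the largest
    out = np.zeros_like(flat)
    out[keep] = flat[keep]
    return out.reshape(pre.shape)


def _per_token_topk_np(pre, k):
    """Keep the k largest entries in each row (for contrast only)."""
    import numpy as np

    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    out = np.zeros_like(pre)
    np.put_along_axis(out, idx, np.take_along_axis(pre, idx, axis=1), axis=1)
    return out


# Usage: two tokens, four features, K=2.
#   pre = np.array([[10., 9., 8., 7.], [1., .5, .2, .1]])
#   (_batch_topk_np(pre, 2) != 0).sum(axis=1)      # [4, 0]
#   (_per_token_topk_np(pre, 2) != 0).sum(axis=1)  # [2, 2]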
# =============================================================================
# Standalone smoke test you can run to verify the full pipeline end-to-end
# on a known input. If this gives weird reconstruction, the issue is upstream;
# if reconstruction is clean here but bad in your pipeline, it's downstream.
#
# Usage:
# modal run evo2_layer26_extraction.py::smoke_test
# =============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def smoke_test():
    """Forward pass on a ~1 kb synthetic DNA string, capture layer 26, run
    SAE encode-decode, report reconstruction stats."""
    import numpy as np
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    # ~1 kb synthetic test sequence (a short motif tiled, so repetitive rather
    # than random); same scale as Goodfire's chr17 example.
    seq = "ATGAACAACGTACTGAGCGAATTCAGCAATGGCAATCGGGCTAGCTAGCTAGCTGCATGCATGCATGCATGCATGCATGCATGCAT" * 12
    seq = seq[:1000]
    print(f"smoke_test sequence length: {len(seq)} bp")

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    target_module = module_dict[TARGET_LAYER]

    cache = {}

    def hook(_, __, out):
        cache["acts"] = (out[0] if isinstance(out, tuple) else out).detach()

    handle = target_module.register_forward_hook(hook)
    try:
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(seq), dtype=torch.long
        ).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        handle.remove()

    acts = cache["acts"][0]  # [seq_len, 4096] bf16
    print(f"layer-26 activations: shape={tuple(acts.shape)} dtype={acts.dtype} "
          f"abs_max={acts.abs().max().item():.2f} std={acts.float().std().item():.4f}")

    # Load the SAE and run encode + decode
    sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE)
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)
    features, recon = reference_encode_and_reconstruct(acts.float(), sae_sd, K=SAE_K)

    # ---- reconstruction metrics ---------------------------------------------
    orig = acts.float()
    err = orig - recon
    mse = (err ** 2).mean().item()
    var = orig.var().item()  # variance over all elements (one global mean)
    explained_variance = 1.0 - mse / max(var, 1e-9)
    cosine_per_token = torch.nn.functional.cosine_similarity(orig, recon, dim=1).mean().item()
    sparsity = (features != 0).float().mean().item()  # fraction of non-zero features

    print("\nSAE reconstruction:")
    print(f"  MSE:                   {mse:.5f}")
    print(f"  variance:              {var:.5f}")
    print(f"  explained variance:    {explained_variance:.4f}  (closer to 1.0 is better)")
    print(f"  mean per-token cosine: {cosine_per_token:.4f}  (closer to 1.0 is better)")
    print(f"  feature density:       {sparsity:.4f}  (fraction non-zero; k/d_sae = {SAE_K / 32768:.4f})")

    return {
        "mse": mse,
        "var": var,
        "explained_variance": explained_variance,
        "cosine": cosine_per_token,
        "sparsity": sparsity,
    }
@app.local_entrypoint()
def main():
    """Run the smoke test and dump the reconstruction stats."""
    r = smoke_test.remote()
    print(json.dumps(r, indent=2))
# =============================================================================
# Quick reference: notes on the original orchestrator that extracted every
# region. Each region's record was a dict with keys (`sequence`, `mag_id`,
# `locus_tag`, `region_id`, `is_positive`, `label`, `label_class`, etc.); the
# same dict is JSON-encoded into `metadata_json` in each saved npz.
# =============================================================================
ORIGINAL_PIPELINE_NOTES = """
Source data: targeted JSONL files extracted with scripts/extract_targeted.py.
Each JSONL line is one record. Fields:

  sequence                DNA, forward strand: gene + 2 kb upstream + 2 kb
                          downstream flank
  mag_id, locus_tag       Prodigal IDs from the MGnify master GFF
  region_id               f"{locus_tag}_{label}", unique per record
  is_positive             True for AMR/STRESS/VIRULENCE positives,
                          False for matched negatives
  label                   "AMR" | "STRESS" | "VIRULENCE" | "negative"
  label_class             AMRFinderPlus class (e.g. "BETA-LACTAM", "MACROLIDE")
  label_subclass          AMRFinderPlus subclass
  gene_symbol             e.g. "blaOXA", "catA"
  pct_identity_to_ref     AMRFinderPlus protein identity to the reference seq
                          (proxy for memorisation: < 80% suggests a novel allele)
  paired_with             locus_tag of the matched partner (positive <-> negative)
  gene_start, gene_end, strand, contig, ext_start, ext_end
  gc_content, cds_in_mobilome, negative_pool_fallback

For each record we ran `extract_layer26_for_sequence(record["sequence"], record)`
and saved the result to {label}/{mag_id}/{region_id}.npz (on the HF dataset,
the negatives live under MISC/; see the layout below).

Layout on the HF dataset `JG1310/mgnify-evo2-l26-full`:
  AMR/{mag_id}/{region_id}.npz         AMR positive
  STRESS/{mag_id}/{region_id}.npz      stress-resistance positive
  VIRULENCE/{mag_id}/{region_id}.npz   virulence positive
  MISC/{mag_id}/{region_id}.npz        matched-CDS negatives
                                       (the paired_with field links to the
                                       positive in AMR/, STRESS/, etc.)
"""
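
# Illustrative sketch (hypothetical helper, stdlib only, not part of the
# original pipeline): checks two invariants stated in ORIGINAL_PIPELINE_NOTES
# on a batch of JSONL lines. The first is region_id == f"{locus_tag}_{label}".
# The second is that every `paired_with` names a locus_tag present in the same
# batch, which assumes positives and their matched negatives are loaded
# together. Returning a list of problems instead of raising is a design choice
# for batch checks.
import json


def _check_records(jsonl_lines):
    """Return human-readable problems found in an iterable of JSONL lines."""
    records = [json.loads(line) for line in jsonl_lines if line.strip()]
    known_tags = {r["locus_tag"] for r in records}
    problems = []
    for r in records:
        expected_id = f"{r['locus_tag']}_{r['label']}"
        if r.get("region_id") != expected_id:
            problems.append(f"{r['locus_tag']}: region_id != {expected_id!r}")
        partner = r.get("paired_with")
        if partner is not None and partner not in known_tags:
            problems.append(f"{r['locus_tag']}: paired_with {partner!r} not found")
    return problems


# Usage:
#   with open("targeted.jsonl") as fh:   # hypothetical path
#       for problem in _check_records(fh):
#           print(problem)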