"""
Evo 2 layer-26 extraction pipeline used to produce the layer-26 npz files
shared on HuggingFace as `JG1310/mgnify-evo2-l26-full` (and the 5-layer source
on Modal volume `mgnify-embeddings-lean`).

This file is written for someone debugging SAE reconstruction error to verify:
  1. Which model variant we ran           - `evo2_7b_262k`, NOT `evo2_7b`.
  2. Which hook we read from              - whole-block output of block 26
                                            (`blocks-26`), NOT `blocks-26-mlp-l3`.
  3. How activations were stored          - bf16 bit-pattern in uint16 numpy array.
  4. What's inside each .npz file         - schema documented in `LOADER_EXAMPLE`.
  5. Reference SAE encode (BatchTopK)     - pattern follows Arc's official notebook
                                            at notebooks/sparse_autoencoder/
                                            sparse_autoencoder.ipynb.

If reconstruction error is bad on the receiver's side but the saved activations
match the residual stream produced by Arc's own example notebook on the same
input, the bug is in their SAE-encode/decode code (most common: missing
BatchTopK normalization, wrong dtype on matmul, wrong W vs W.T on decode).

A reproducible smoke test is provided at the bottom: run on Modal with
    modal run evo2_layer26_extraction.py::smoke_test
"""
import json

import modal


# =============================================================================
# Constants
# =============================================================================
MODEL_VARIANT = "evo2_7b_262k"  # 262k-context variant; Goodfire's SAE was
                                # trained against this, not the vanilla 7b.
TARGET_LAYER = "blocks-26"      # whole-block output (residual stream after
                                # block 26). NOT blocks-26-mlp-l3, which
                                # would be a sub-module's output and would
                                # give different activations.
HIDDEN = 4096                   # Evo 2 7b residual stream dim.
SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed"
SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt"
SAE_K = 64                      # BatchTopK budget per token-batch
                                # (k=64, expansion=8 ⇒ d_sae = 8 * 4096 = 32768).
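

# =============================================================================
# Sanity-check sketch (not part of the original pipeline): cross-check the SAE
# width against the checkpoint filename. This assumes the filename keeps the
# `expansion_<E>-k_<K>` pattern seen in SAE_FILE; the helper name
# `sae_dims_from_filename` is hypothetical.
# =============================================================================
def sae_dims_from_filename(filename: str, hidden: int = 4096) -> dict:
    """Parse expansion factor and k from an SAE filename and derive d_sae."""
    import re

    match = re.search(r"expansion_(\d+)-k_(\d+)", filename)
    if match is None:
        raise ValueError(f"no expansion/k fields found in {filename!r}")
    expansion, k = int(match.group(1)), int(match.group(2))
    # d_sae is the dictionary width: expansion factor times residual dim.
    return {"expansion": expansion, "k": k, "d_sae": expansion * hidden}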


# =============================================================================
# Modal image (Arc Institute's official Evo 2 Dockerfile, translated to Modal)
# =============================================================================
image = (
    modal.Image.from_registry(
        "nvcr.io/nvidia/pytorch:25.04-py3",
        add_python=None,    # base image already ships its own Python
    )
    .apt_install("git", "python3-pip", "python3-tomli")
    .pip_install("evo2")    # pulls flash-attn + vortex-model + huggingface_hub
)

app = modal.App("evo2-layer26-extraction-share")
weights_vol = modal.Volume.from_name("evo2-7b-weights", create_if_missing=True)


# =============================================================================
# Helper: walk the StripedHyena module tree to find a hook target by name.
# StripedHyena's nesting structure means `blocks.26` is reached via
# `evo2.model.blocks[26]`, but its child names are ('mixer', 'mlp', etc.).
# We use `named_children()` and join with '-' so that:
#     blocks-26              = block 26's container forward output
#     blocks-26-mlp-l3       = block 26's MLP last-layer linear output
# =============================================================================
def build_module_dict(model):
    module_dict = {}
    def recurse(m, prefix=""):
        for name, child in m.named_children():
            module_dict[prefix + name] = child
            recurse(child, prefix + name + "-")
    recurse(model)
    return module_dict
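

# =============================================================================
# Usage sketch (hypothetical helper, not in the original pipeline): list every
# hook name under one block, to confirm that `blocks-26` itself is a key and to
# see which sub-module names (e.g. `blocks-26-mlp-l3`) sit below it. It relies
# only on the '-'-joined key scheme produced by build_module_dict above, e.g.
#     hook_names_for_block(build_module_dict(evo2.model), 26)
# =============================================================================
def hook_names_for_block(module_names, block_index: int) -> list:
    """Return the block's own hook name plus all of its descendants' names."""
    root = f"blocks-{block_index}"
    # Match the block itself or anything nested under it; the trailing '-'
    # keeps `blocks-2` from also matching `blocks-26`.
    return sorted(
        name for name in module_names
        if name == root or name.startswith(root + "-")
    )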


# =============================================================================
# The actual extraction function. This is what wrote each per-region npz.
#
# Important details to verify against your own pipeline:
#   - Forward pass receives the full sequence (gene + 2 kb upstream + 2 kb
#     downstream flank) so causal Hyena convolution gets context.
#   - Hook fires on `blocks-26.forward` and we capture `out[0]` if the output
#     is a tuple, else `out`. For StripedHyena blocks the first tuple element
#     is the residual-stream hidden state passed to block 27; this is what
#     Goodfire's SAE was trained on.
#   - The captured tensor is bf16 on GPU. We keep it in bf16 and reinterpret
#     the bit-pattern as uint16 because numpy does not support bf16 natively.
#     This is a *bit-exact* reinterpretation, NOT a precision-losing cast;
#     decode with `torch.from_numpy(arr).view(torch.bfloat16)`.
#   - We do NOT compress (no gzip): random-looking bf16 floats compress poorly
#     and gzip was the dominant cost during a previous failed run.
# =============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=3600,
)
def extract_layer26_for_sequence(sequence: str, region_metadata: dict) -> dict:
    """
    Run Evo 2 forward on `sequence`, capture layer-26 residual stream, return
    it as a bf16-as-uint16 numpy bit-pattern plus the metadata in JSON form.

    sequence:        DNA string, forward strand (e.g. "ATGAA...").
                     We do not reverse-complement minus-strand genes; we feed
                     the genomic forward strand as-is. Goodfire's reference
                     notebook also feeds raw forward-strand sequences.
    region_metadata: arbitrary dict (locus_tag, gene coords, label class,
                     etc.) passed through into the saved .npz so each file
                     is self-describing.
    """
    import numpy as np
    import torch
    from evo2 import Evo2

    # ----- load the 262k-context Evo 2 7B variant (one-time per container) ---
    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    if TARGET_LAYER not in module_dict:
        raise RuntimeError(f"hook target {TARGET_LAYER} not found in module tree")
    target_module = module_dict[TARGET_LAYER]

    # ----- register the hook -------------------------------------------------
    cache: dict = {}
    def hook_fn(_module, _inp, out):
        # StripedHyena blocks return a tuple where index 0 is the residual-
        # stream hidden state. Some sub-modules return just a tensor.
        acts = out[0] if isinstance(out, tuple) else out
        cache["acts"] = acts.detach()    # detach so we don't keep autograd graph
    handle = target_module.register_forward_hook(hook_fn)

    try:
        # ----- forward pass --------------------------------------------------
        # Tokenizer: each nucleotide gets one token id (Evo 2's tokenizer is
        # byte-level on ACGTN). Sequence length = len(sequence).
        input_ids = torch.tensor(
            evo2.tokenizer.tokenize(sequence),
            dtype=torch.long,
        ).unsqueeze(0).to(device)   # add batch dim, move to GPU

        with torch.no_grad():
            evo2.model(input_ids)
        # No need for output logits; we only care about the cached activation.

        acts_bf16 = cache["acts"][0]  # squeeze batch dim → [seq_len, HIDDEN]
        seq_len, hidden = acts_bf16.shape
        assert hidden == HIDDEN, f"unexpected hidden dim {hidden}"
    finally:
        handle.remove()
        cache.clear()
        torch.cuda.empty_cache()

    # ----- bf16 → uint16 bit-pattern (lossless) -----------------------------
    # `view(torch.uint16)` is a zero-copy reinterpretation of the same memory:
    # the bit-pattern of a bf16 float is read as the bit-pattern of a uint16.
    # No precision loss. Decode on the receiving side with the inverse.
    acts_uint16_np = acts_bf16.to(torch.bfloat16).view(torch.uint16).cpu().numpy()

    return {
        "layer26_activations_bf16": acts_uint16_np,    # uint16 [seq_len, 4096]
        "layer26_dtype": "bfloat16",                   # marker for decode
        "source_layer_index": 26,
        "source_layer_name": TARGET_LAYER,
        "seq_len": int(seq_len),
        "hidden_size": int(hidden),
        "model_name": MODEL_VARIANT,
        "metadata_json": json.dumps(region_metadata),
    }
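

# =============================================================================
# Sketch of the save step. This is an assumption: the original writer is not
# shown in this file. `np.savez` stores arrays uncompressed, which matches the
# "no gzip" note above. `save_region_npz` is a hypothetical name.
# =============================================================================
def save_region_npz(result: dict, out_path: str) -> str:
    """Write the dict returned by extract_layer26_for_sequence to `out_path`."""
    import os
    import numpy as np

    if not out_path.endswith(".npz"):
        # np.savez would silently append ".npz"; require it up front instead.
        raise ValueError("out_path must end with .npz")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    # np.savez (not savez_compressed) writes an uncompressed archive. Strings
    # and ints become 0-d arrays, so the file loads with allow_pickle=False.
    np.savez(out_path, **result)
    return out_path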


# =============================================================================
# Reference loader: exactly how to read one of our shared npz files back.
# This is what receivers should do; if they don't get the right shape/dtype,
# the bug is in their loading code, not upstream.
# =============================================================================
LOADER_EXAMPLE = '''
import numpy as np
import json
import torch

d = np.load("AMR/MGYG.../REGION_AMR.npz", allow_pickle=False)

# Schema (every shared file has these keys):
#   layer26_activations_bf16  uint16 array, shape [seq_len, 4096]
#                             (bit-pattern of bf16 stored as uint16)
#   layer26_dtype             literal string "bfloat16"
#   source_layer_index        int 26
#   source_layer_name         literal string "blocks-26"
#   seq_len, hidden_size      ints (matches array shape)
#   model_name                literal string "evo2_7b_262k"
#   metadata_json             JSON-encoded dict with locus_tag, gene_symbol,
#                             label_class, label_subclass, gene_start/end,
#                             paired_with, etc.

# Decode bit-pattern to bf16, then to fp32 for downstream math:
acts_bf16 = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16)
acts_fp32 = acts_bf16.float()  # shape [seq_len, 4096]

# Pull the per-region metadata:
meta = json.loads(str(d["metadata_json"]))
print(meta["gene_symbol"], meta["label_class"], meta["label_subclass"])
'''
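

# =============================================================================
# NumPy-only decode sketch, for receivers whose torch build cannot convert a
# uint16 array (an assumption about older torch versions; check your own). It
# relies only on bf16 being the upper 16 bits of an IEEE fp32 value.
# `bf16_bits_to_float32` is a hypothetical helper name.
# =============================================================================
def bf16_bits_to_float32(bits):
    """Expand uint16 bf16 bit-patterns to float32 without torch."""
    import numpy as np

    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    # Shift the 16 stored bits into the high half of a 32-bit word. The low
    # 16 mantissa bits become zero, which is exactly the bf16 value in fp32.
    return (bits.astype(np.uint32) << 16).view(np.float32)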


# =============================================================================
# Reference SAE encode-and-decode that produced sane reconstruction in our
# CRISPR sanity test (5/57 of Goodfire's published features fired strongly on
# E. coli K12 CRISPR arrays). Use this to compare your own SAE handling.
#
# THE THREE PLACES WHERE PEOPLE GET THIS WRONG:
#
#   1. dtype: cast both `W_enc`/`b_enc` AND `acts` to the SAME dtype (bf16 OR
#      fp32, but consistent) before the matmul. Outside autocast PyTorch
#      rejects mixed-dtype matmuls; inside autocast the operands are silently
#      downcast, which changes the numbers.
#
#   2. BatchTopK is *batch-wide*, not per-token. The top-K is computed across
#      the FLATTENED (seq_len * d_sae) tensor with k = K * seq_len, NOT
#      `topk(k=64)` per token. Both keep K * seq_len values in total, but
#      per-token topk forces exactly K on every token, whereas BatchTopK lets
#      some tokens keep more and others fewer. (Calling `topk(k=64)` on the
#      flattened tensor instead would be ~seq_len× too sparse.)
#
#   3. Reconstruction uses `W.T` (the transpose) not `W`. Goodfire's SAE has
#      tied encoder/decoder weights, so there is a single `W` matrix in the
#      state dict:
#      encode = ReLU(acts @ W + b_enc); decode = features @ W.T + b_dec.
# =============================================================================
def reference_encode_and_reconstruct(acts_fp32, sae_state_dict, K=SAE_K):
    """Reference SAE encode → BatchTopK → decode.

    acts_fp32:        [seq_len, 4096] activations (fp32 or bf16)
    sae_state_dict:   loaded from `Goodfire/Evo-2-Layer-26-Mixed`
                      via huggingface_hub.hf_hub_download
    K:                BatchTopK budget per token (default 64)

    Returns: (sparse_features, reconstructed_acts)
    """
    import torch

    # The official Goodfire checkpoint was saved with torch.compile + DDP
    # prefixes; strip them when loading:
    sae = {k.replace("_orig_mod.", "").replace("module.", ""): v
           for k, v in sae_state_dict.items()}
    W = sae["W"]                                         # [4096, 32768]
    b_enc = sae["b_enc"]                                 # [32768]
    b_dec = sae.get("b_dec", torch.zeros(W.shape[0]))    # [4096]; some checkpoints omit

    # Match dtypes carefully (see "place 1" above):
    dtype = acts_fp32.dtype
    device = acts_fp32.device
    W = W.to(device=device, dtype=dtype)
    b_enc = b_enc.to(device=device, dtype=dtype)
    b_dec = b_dec.to(device=device, dtype=dtype)

    # ----- encode (same as Arc's notebook) -----------------------------------
    pre = torch.relu(acts_fp32 @ W + b_enc)              # [seq_len, 32768]

    # BatchTopK across the WHOLE [seq_len * d_sae] flattened tensor (place 2):
    seq_len, d_sae = pre.shape
    flat = pre.flatten()
    numel = K * seq_len                                  # total non-zero budget
    top = torch.topk(flat, numel, dim=-1)
    sparse_flat = torch.zeros_like(flat).scatter(-1, top.indices, top.values)
    features = sparse_flat.reshape(pre.shape)            # [seq_len, 32768], sparse

    # ----- decode using W.T (place 3) ----------------------------------------
    reconstructed = features @ W.T + b_dec               # [seq_len, 4096]

    return features, reconstructed
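

# =============================================================================
# Toy illustration of "place 2" above, in NumPy so it runs without a GPU. It
# is a sketch, not Goodfire's code: it counts active features per token under
# batch-wide top-k versus per-token top-k. `topk_counts_batch_vs_per_token`
# is a hypothetical helper name.
# =============================================================================
def topk_counts_batch_vs_per_token(pre, k: int):
    """Return (batch-wide, per-token) active-feature counts for each token."""
    import numpy as np

    pre = np.asarray(pre, dtype=np.float32)
    seq_len, _ = pre.shape

    # Batch-wide: keep the k * seq_len largest entries of the flattened matrix.
    flat = pre.ravel()
    keep = np.argsort(flat)[::-1][: k * seq_len]
    batch = np.zeros_like(flat)
    batch[keep] = flat[keep]
    batch = batch.reshape(pre.shape)

    # Per-token: keep the k largest entries of every row independently.
    cols = np.argsort(pre, axis=1)[:, ::-1][:, :k]
    per_token = np.zeros_like(pre)
    np.put_along_axis(
        per_token, cols, np.take_along_axis(pre, cols, axis=1), axis=1
    )

    batch_counts = (batch != 0).sum(axis=1).tolist()
    per_token_counts = (per_token != 0).sum(axis=1).tolist()
    return batch_counts, per_token_counts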


# =============================================================================
# Standalone smoke test you can run to verify the full pipeline end-to-end
# on a known input. If this gives weird reconstruction, the issue is upstream;
# if reconstruction is clean here but bad in your pipeline, it's downstream.
#
# Usage:
#   modal run evo2_layer26_extraction.py::smoke_test
# =============================================================================
@app.function(
    image=image,
    gpu="H100",
    volumes={"/root/.cache/huggingface": weights_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def smoke_test():
    """Forward pass on a 1 kb random-ish DNA string, capture layer 26, run
    SAE encode-decode, report reconstruction stats."""
    import numpy as np
    import torch
    from evo2 import Evo2
    from huggingface_hub import hf_hub_download

    # 1 kb synthetic sequence (a repeated motif); same scale as Goodfire's chr17 example
    seq = "ATGAACAACGTACTGAGCGAATTCAGCAATGGCAATCGGGCTAGCTAGCTAGCTGCATGCATGCATGCATGCATGCATGCATGCAT" * 12
    seq = seq[:1000]
    print(f"smoke_test sequence length: {len(seq)} bp")

    evo2 = Evo2(MODEL_VARIANT)
    device = next(evo2.model.parameters()).device
    module_dict = build_module_dict(evo2.model)
    target_module = module_dict[TARGET_LAYER]

    cache = {}
    def hook(_, __, out):
        cache["acts"] = (out[0] if isinstance(out, tuple) else out).detach()
    handle = target_module.register_forward_hook(hook)
    try:
        input_ids = torch.tensor(evo2.tokenizer.tokenize(seq), dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            evo2.model(input_ids)
    finally:
        handle.remove()
    acts = cache["acts"][0]          # [seq_len, 4096] bf16
    print(f"layer-26 activations: shape={tuple(acts.shape)} dtype={acts.dtype}  "
          f"abs_max={acts.abs().max().item():.2f}  std={acts.float().std().item():.4f}")

    # Load SAE and run encode + decode
    sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE)
    sae_sd = torch.load(sae_path, map_location=device, weights_only=True)

    features, recon = reference_encode_and_reconstruct(acts.float(), sae_sd, K=SAE_K)

    # ---- reconstruction metrics ---------------------------------------------
    orig = acts.float()
    err = orig - recon
    mse = (err ** 2).mean().item()
    var = orig.var().item()
    explained_variance = 1.0 - mse / max(var, 1e-9)
    cosine_per_token = torch.nn.functional.cosine_similarity(orig, recon, dim=1).mean().item()
    sparsity = (features != 0).float().mean().item()
    print(f"\nSAE reconstruction:")
    print(f"  MSE:                   {mse:.5f}")
    print(f"  variance:              {var:.5f}")
    print(f"  explained variance:    {explained_variance:.4f}  (closer to 1.0 is better)")
    print(f"  mean per-token cosine: {cosine_per_token:.4f}  (closer to 1.0 is better)")
    print(f"  feature sparsity:      {sparsity:.4f}  (k/d_sae = {SAE_K/32768:.4f})")
    return {
        "mse": mse,
        "var": var,
        "explained_variance": explained_variance,
        "cosine": cosine_per_token,
        "sparsity": sparsity,
    }
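

# =============================================================================
# CPU-side sketch of the same reconstruction metrics in NumPy, for receivers
# who want to compare numbers without torch. It mirrors the formulas in
# smoke_test (global variance, mean per-token cosine) as an assumption of
# equivalence, not a tested drop-in. `reconstruction_metrics_np` is a
# hypothetical helper name.
# =============================================================================
def reconstruction_metrics_np(orig, recon) -> dict:
    """MSE, variance, explained variance and mean per-token cosine."""
    import numpy as np

    orig = np.asarray(orig, dtype=np.float64)
    recon = np.asarray(recon, dtype=np.float64)
    mse = float(((orig - recon) ** 2).mean())
    # ddof=1 matches torch.Tensor.var(), which is unbiased by default.
    var = float(orig.var(ddof=1))
    dots = (orig * recon).sum(axis=1)
    norms = np.linalg.norm(orig, axis=1) * np.linalg.norm(recon, axis=1)
    cosine = float((dots / np.maximum(norms, 1e-8)).mean())
    return {
        "mse": mse,
        "var": var,
        "explained_variance": 1.0 - mse / max(var, 1e-9),
        "cosine": cosine,
    }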


@app.local_entrypoint()
def main():
    """Run the smoke test and dump the reconstruction stats."""
    r = smoke_test.remote()
    print(json.dumps(r, indent=2))
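

# =============================================================================
# Path sketch for locating a record's npz in the shared dataset. It assumes
# the layout described in ORIGINAL_PIPELINE_NOTES: positives sit under their
# label and matched negatives sit under MISC/. The negative-to-MISC mapping is
# inferred from those notes, and `region_npz_relpath` is a hypothetical name.
# =============================================================================
def region_npz_relpath(record: dict) -> str:
    """Relative path of a record's npz under the dataset root."""
    # Assumption: every non-positive record (label "negative") is stored
    # under MISC/ rather than under a directory named after its label.
    top_dir = record["label"] if record["is_positive"] else "MISC"
    return f"{top_dir}/{record['mag_id']}/{record['region_id']}.npz"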


# =============================================================================
# Quick reference: the original orchestrator used to extract every region.
# Each region's record was a dict with keys (`sequence`, `mag_id`,
# `locus_tag`, `region_id`, `is_positive`, `label`, `label_class`, etc.); the
# same dict is JSON-encoded into `metadata_json` in each saved npz.
# =============================================================================
ORIGINAL_PIPELINE_NOTES = """
Source data: targeted JSONL files extracted with scripts/extract_targeted.py
  Each JSONL line is one record. Fields:
      sequence            DNA, forward strand, gene + 2 kb upstream + 2 kb downstream flank
      mag_id, locus_tag   Prodigal IDs from MGnify master GFF
      region_id           f"{locus_tag}_{label}", unique per record
      is_positive         True for AMR/STRESS/VIRULENCE positives,
                          False for matched negatives
      label               "AMR" | "STRESS" | "VIRULENCE" | "negative"
      label_class         AMRFinderPlus class (e.g. "BETA-LACTAM", "MACROLIDE")
      label_subclass      AMRFinderPlus subclass
      gene_symbol         e.g. "blaOXA", "catA"
      pct_identity_to_ref AMRFinderPlus protein identity to reference seq
                          (proxy for memorisation: < 80% suggests novel allele)
      paired_with         locus_tag of the matched partner (positive ↔ negative)
      gene_start, gene_end, strand, contig, ext_start, ext_end
      gc_content, cds_in_mobilome, negative_pool_fallback

For each record we ran `extract_layer26_for_sequence(record["sequence"], record)`
and saved the result to {label}/{mag_id}/{region_id}.npz.

Layout on the HF dataset `JG1310/mgnify-evo2-l26-full`:
  AMR/{mag_id}/{region_id}.npz         - AMR positive
  STRESS/{mag_id}/{region_id}.npz      - stress-resistance positive
  VIRULENCE/{mag_id}/{region_id}.npz   - virulence positive
  MISC/{mag_id}/{region_id}.npz        - matched-CDS negatives
                                         (paired_with field links to the
                                          positive in AMR/, STRESS/, etc.)
"""