# MGnify × Evo 2 layer-26 probes
Linear and attention probes on the residual stream of the Evo 2 7B-262k DNA foundation model (layer 26), trained to detect AMR genes, virulence factors, and AMR subclasses in MGnify metagenomes.

Repo: `JG1310/mgnify-evo2-probes`.
## What's here
- `code/` – all the Python / Modal pipeline code (extraction, embedding, probe training, plot scripts).
- `manifests/` – train / val / test split JSONs (MAG-level for AMR, species-level for VFDB virulence).
- `checkpoints/` – trained probe weights, saved as PyTorch `.pt` state-dicts (each is tiny: `Linear(4096, 1)` is 4097 params). The per-class probes can be loaded in bulk; see the sketch after this list.
  - `amr_binary_v1/linear/{run_id}/checkpoint.pt` – primary AMR probe
  - `amr_binary_v1/attention/{run_id}/checkpoint.pt` – attention variant
  - `amr_class5_v1/linear/{CLASS}/{run_id}/checkpoint.pt` – 5 per-class probes
  - `vfdb_virulence_v1/linear/{run_id}/checkpoint.pt` – virulence probe
- `plots/` – score-distribution histograms, sanity plots, attention visualisations (PNG).
- `scores/` – per-region / per-read raw probe logits (JSONL); useful for regenerating plots without retraining.
- `summaries/` – AUC, best-F1, and thresholds per probe (JSON).
- `training_metrics/` – full per-epoch val history (JSON).
- `data/targeted_jsonl/` – extraction outputs (gene + 2 kb flank sequences, matched-pair negatives, etc.); the inputs to the embedding pipeline. Big text files (~70 MB total).
- `JOURNAL.md` – chronological project log.
- `HACKATHON_STATUS.md` – high-level status doc.
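As a quick illustration of the `checkpoints/` layout, here is a minimal sketch that loads all five per-class AMR probes by globbing the `{CLASS}/{run_id}` directories. It assumes every checkpoint is a plain `nn.Linear(4096, 1)` state-dict as described above; the class names and run IDs are whatever your clone contains.

```python
from pathlib import Path

import torch
import torch.nn as nn

# Walk checkpoints/amr_class5_v1/linear/{CLASS}/{run_id}/checkpoint.pt
probes = {}
for ckpt in Path("checkpoints/amr_class5_v1/linear").glob("*/*/checkpoint.pt"):
    amr_class = ckpt.parent.parent.name  # the {CLASS} path component
    probe = nn.Linear(4096, 1)
    probe.load_state_dict(torch.load(ckpt, weights_only=True))
    probe.eval()
    probes[amr_class] = probe

print(sorted(probes))  # should list the 5 AMR classes
```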
## What's NOT here
- Raw Evo 2 embeddings (the per-token layer-26 activations, several hundred GB total). These live on Modal Volumes:
  - `mgnify-embeddings-l26-lean` (5483 records × 37 MB ≈ 208 GB)
  - `mgnify-embeddings-l26-vfdb` (~150 GB)
  - `mgnify-embeddings-l26-human-viral` (~60 GB)
- A subset is mirrored as a public HF dataset: `JG1310/mgnify-evo2-l26-full` – layer-26 `.npz` files for all MGnify-extracted regions.
- FASTQ files, raw input xlsx, etc.
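If you just want the mirrored embeddings, a minimal sketch using `huggingface_hub` is below. The `allow_patterns` filter assumes the dataset repo stores one `.npz` per region; adjust it (or drop it) to match the actual layout.

```python
from huggingface_hub import snapshot_download

# Pull the public layer-26 embedding mirror (filtered to .npz files).
local_dir = snapshot_download(
    repo_id="JG1310/mgnify-evo2-l26-full",
    repo_type="dataset",
    allow_patterns=["*.npz"],  # assumption: one .npz per extracted region
)
print(local_dir)
```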
## Loading a probe checkpoint
```python
import numpy as np
import torch
import torch.nn as nn

# AMR linear probe
sd = torch.load("checkpoints/amr_binary_v1/linear/<run_id>/checkpoint.pt", weights_only=True)
probe = nn.Linear(4096, 1)
probe.load_state_dict(sd)
probe.eval()

# Apply to a per-token activation tensor [seq_len, 4096]
# (e.g. from a layer-26-extracted .npz file)
d = np.load("some_region.npz", allow_pickle=False)
acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float()
with torch.no_grad():
    logits = probe(acts).squeeze(-1)  # [seq_len] per-token logits

print(f"max-pool logit: {logits.max():.2f}")
print(f"mean-pool logit: {logits.mean():.2f}")
```
## Reading the manifests
```python
import json

m = json.load(open("manifests/amr_binary_v1.json"))
# m['region_split'][region_id]      -> "train" / "val" / "test"
# m['labels_per_region'][region_id] -> 0 or 1
# m['gene_coords'][region_id]       -> [gene_start, gene_end, ext_start, ext_end, strand]
# m['pair_partner'][region_id]      -> matched-pair region_id
```
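Putting the manifest and a probe together, below is a sketch of re-scoring the held-out test split with region-level max-pooling (the pooling behind the headline AMR number). It assumes activations are stored as `<region_id>.npz` under a local directory, which may not match the real file naming, and reuses `probe` and `m` from the blocks above.

```python
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

test_ids = [r for r, s in m["region_split"].items() if s == "test"]

labels, scores = [], []
for region_id in test_ids:
    d = np.load(f"embeddings/{region_id}.npz")  # assumed path convention
    acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float()
    with torch.no_grad():
        per_token = probe(acts).squeeze(-1)
    scores.append(per_token.max().item())  # region max-pool
    labels.append(m["labels_per_region"][region_id])

print("region max-pool AUC:", roc_auc_score(labels, scores))
```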
## Headline results (test sets)
| Probe | Eval | AUC |
|---|---|---|
| Linear, AMR-vs-MISC | 672 regions, MAG-level held out | 0.949 (region max-pool) |
| Attention, AMR-vs-MISC | 672 regions | 0.977 |
| Linear, AMR class-specific (top 5) | within-AMR class | 0.989–0.998 |
| Linear, VFDB virulence | 336 regions, species-level held out | 0.833 (region mean-pool) |
| Linear AMR probe, multi-org short reads (301 bp) | 1340 reads | 0.898 (per-read mean) / 0.921 (per-CDS) |
See `JOURNAL.md` for the full story; `summaries/` for raw numbers.