MGnify × Evo 2 layer-26 probes

Linear and attention probes on the layer-26 residual stream of the Evo 2 7B-262k DNA foundation model, trained to detect AMR genes, virulence factors, and AMR subclasses in MGnify metagenomes.

Repo at JG1310/mgnify-evo2-probes.

What's here

  • code/ — all the Python / Modal pipeline code (extraction, embedding, probe training, plot scripts).
  • manifests/ — train / val / test split JSONs (MAG-level for AMR, species-level for VFDB virulence).
  • checkpoints/ — trained probe weights (each is tiny — Linear(4096, 1) is 4097 params). Saved as PyTorch .pt state-dicts; a generic loader sketch follows this list.
    • amr_binary_v1/linear/{run_id}/checkpoint.pt — primary AMR probe
    • amr_binary_v1/attention/{run_id}/checkpoint.pt — attention variant
    • amr_class5_v1/linear/{CLASS}/{run_id}/checkpoint.pt — 5 per-class probes
    • vfdb_virulence_v1/linear/{run_id}/checkpoint.pt — virulence probe
  • plots/ — score-distribution histograms, sanity plots, attention visualisations (PNG).
  • scores/ — per-region / per-read raw probe logits (JSONL) — useful for reformatting plots without retraining.
  • summaries/ — AUC + best-F1 + thresholds per probe (JSON).
  • training_metrics/ — full per-epoch val history (JSON).
  • data/targeted_jsonl/ — extraction outputs (gene + 2 kb flank sequences, matched-pair negatives, etc.) — the inputs to the embedding pipeline. Big text files (~70 MB total).
  • JOURNAL.md — chronological project log.
  • HACKATHON_STATUS.md — high-level status doc.
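
Every linear checkpoint is the same 4097-parameter Linear(4096, 1) head, so the linear probes can be discovered and loaded generically. A minimal sketch using pathlib; run IDs vary, hence the glob, and the attention checkpoint (different state dict) is skipped:

from pathlib import Path

import torch
import torch.nn as nn

# Discover every linear probe under checkpoints/ and load each into an
# identical Linear(4096, 1) head. The ** matches both the {run_id}/ and
# {CLASS}/{run_id}/ layouts.
probes = {}
for ckpt in Path("checkpoints").glob("*/linear/**/checkpoint.pt"):
    head = nn.Linear(4096, 1)
    head.load_state_dict(torch.load(ckpt, weights_only=True))
    head.eval()
    probes[str(ckpt.parent.relative_to("checkpoints"))] = head

print(sorted(probes))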

What's NOT here

  • Raw Evo 2 embeddings (the per-token layer-26 activations, several hundred GB total). These live on Modal Volumes:
    • mgnify-embeddings-l26-lean (5483 records × 37 MB = ~208 GB)
    • mgnify-embeddings-l26-vfdb (~150 GB)
    • mgnify-embeddings-l26-human-viral (~60 GB)
  • A subset is mirrored as a public HF dataset: JG1310/mgnify-evo2-l26-full — layer-26 npz for all MGnify-extracted regions (see the download sketch after this list).
  • FASTQ files, raw input xlsx, etc.
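
Those mirrored activations can be pulled without Modal access via huggingface_hub. The per-file naming scheme inside the dataset isn't documented in this card, so list the repo first rather than guessing filenames:

from huggingface_hub import hf_hub_download, list_repo_files

# See which per-region .npz files the mirror actually contains
# (the naming scheme is not documented here).
files = list_repo_files("JG1310/mgnify-evo2-l26-full", repo_type="dataset")
print(files[:5])

# Fetch one region's layer-26 activations into the local HF cache
path = hf_hub_download(
    "JG1310/mgnify-evo2-l26-full",
    filename=files[0],
    repo_type="dataset",
)
print(path)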

Loading a probe checkpoint

import numpy as np
import torch
import torch.nn as nn

# AMR linear probe: a single Linear(4096, 1) head over layer-26 activations
sd = torch.load("checkpoints/amr_binary_v1/linear/<run_id>/checkpoint.pt", weights_only=True)
probe = nn.Linear(4096, 1)
probe.load_state_dict(sd)
probe.eval()

# Apply to a per-token activation tensor [seq_len, 4096]
# (e.g. from a layer-26-extracted .npz file; bf16 is stored as raw 16-bit
# words, so reinterpret with .view(torch.bfloat16) rather than casting)
d = np.load("some_region.npz", allow_pickle=False)
acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float()
with torch.no_grad():
    logits = probe(acts).squeeze(-1)         # [seq_len] per-token logits
print(f"max-pool logit: {logits.max():.2f}")
print(f"mean-pool logit: {logits.mean():.2f}")

Reading the manifests

import json

with open("manifests/amr_binary_v1.json") as f:
    m = json.load(f)
# m['region_split'][region_id]      -> "train" / "val" / "test"
# m['labels_per_region'][region_id] -> 0 or 1
# m['gene_coords'][region_id]       -> [gene_start, gene_end, ext_start, ext_end, strand]
# m['pair_partner'][region_id]      -> matched-pair region_id
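
For example, collecting the held-out test regions and their labels:

test_ids = [r for r, split in m["region_split"].items() if split == "test"]
y_true = [m["labels_per_region"][r] for r in test_ids]
print(f"{len(test_ids)} test regions, {sum(y_true)} positives")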

Headline results (test sets)

| Probe | Eval | AUC |
| --- | --- | --- |
| Linear, AMR-vs-MISC | 672 regions, MAG-level held out | 0.949 (region max-pool) |
| Attention, AMR-vs-MISC | 672 regions | 0.977 |
| Linear, AMR class-specific (top 5) | within-AMR class | 0.989–0.998 |
| Linear, VFDB virulence | 336 regions, species-level held out | 0.833 (region mean-pool) |
| Linear AMR probe → multi-org short reads (301 bp) | 1340 reads | 0.898 (per-read mean) / 0.921 (per-CDS) |

See JOURNAL.md for the full story; summaries/ for raw numbers.
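
The raw logits in scores/ are enough to recompute these numbers without retraining. A sketch, assuming one JSONL record per region; the filename and the "region_id" / "max_logit" field names are guesses, so check a line of the real file for the schema:

import json
from sklearn.metrics import roc_auc_score

with open("manifests/amr_binary_v1.json") as f:
    m = json.load(f)

# Hypothetical filename and field names; inspect the actual JSONL for the schema.
scores = {}
with open("scores/amr_binary_v1_linear.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        scores[rec["region_id"]] = rec["max_logit"]

test_ids = [r for r in scores if m["region_split"].get(r) == "test"]
auc = roc_auc_score(
    [m["labels_per_region"][r] for r in test_ids],
    [scores[r] for r in test_ids],
)
print(f"test AUC (region max-pool): {auc:.3f}")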
