mgnify-evo2-probes / code /modal /push_probe_share.py
JG1310's picture
Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
eb69de4 verified
"""
Push the probe artifacts (code + plots + manifests + scores + summaries +
docs + tiny checkpoint weights) to a HuggingFace model repo for teammates.
Excludes raw embeddings (huge, on Modal volumes) and FASTQ zips.
Usage:
modal run modal/push_probe_share.py::main
modal run modal/push_probe_share.py::main --repo-name foo --private
"""
from __future__ import annotations
import json
import os
import shutil
import time
from pathlib import Path
import modal
# Local source locations on the developer machine. Every path hangs off one
# checkout root; the image definition below attaches each of them to the
# container under /local.
_MGNIFY_ROOT = "/home/ror25cal/MGnify"

# Directories
LOCAL_PROBES = f"{_MGNIFY_ROOT}/probes"
LOCAL_SCRIPTS = f"{_MGNIFY_ROOT}/scripts"
LOCAL_SHARE = f"{_MGNIFY_ROOT}/share"
LOCAL_MODAL = f"{_MGNIFY_ROOT}/modal"
LOCAL_TARGETED_JSONL = f"{_MGNIFY_ROOT}/data/targeted_jsonl"

# Single files
LOCAL_JOURNAL = f"{_MGNIFY_ROOT}/JOURNAL.md"
LOCAL_STATUS = f"{_MGNIFY_ROOT}/HACKATHON_STATUS.md"
# Container image: slim Debian with the HuggingFace client installed, then the
# local project sources attached under /local. Directories are added first and
# the two standalone markdown files last, in the order listed below.
_base_image = modal.Image.debian_slim().pip_install("huggingface_hub>=0.25")

_DIR_MOUNTS = (
    (LOCAL_PROBES, "/local/probes"),
    (LOCAL_SCRIPTS, "/local/scripts"),
    (LOCAL_SHARE, "/local/share"),
    (LOCAL_MODAL, "/local/modal"),
    (LOCAL_TARGETED_JSONL, "/local/data/targeted_jsonl"),
)
_FILE_MOUNTS = (
    (LOCAL_JOURNAL, "/local/JOURNAL.md"),
    (LOCAL_STATUS, "/local/HACKATHON_STATUS.md"),
)

image = _base_image
for _local_dir, _remote_dir in _DIR_MOUNTS:
    image = image.add_local_dir(_local_dir, remote_path=_remote_dir)
for _local_file, _remote_file in _FILE_MOUNTS:
    image = image.add_local_file(_local_file, remote_path=_remote_file)

# Existing volume holding the probe training outputs; it is looked up by name
# and never created from here (create_if_missing=False).
results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=False)

# Modal app that owns the remote function and the local entrypoint.
app = modal.App("mgnify-probe-share-push")
@app.function(
    image=image,
    cpu=2,
    memory=4 * 1024,
    volumes={"/results": results_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def push_share(repo_name: str = "mgnify-evo2-probes", private: bool = False) -> dict:
    """Stage the shareable probe artifacts and upload them to a HuggingFace model repo.

    Local sources are read from ``/local`` (attached to the image) and probe
    checkpoints from the volume mounted at ``/results``. Everything is copied
    into a fresh ``/tmp/repo_stage`` directory, a README is rendered from
    ``README_TEMPLATE``, and the whole directory is uploaded in one commit.

    Args:
        repo_name: Repository name; the repo id is ``<authenticated user>/<repo_name>``.
        private: Forwarded to ``create_repo``. The repo is created with
            ``exist_ok=True``, so an already existing repo is reused.

    Returns:
        A dict with ``repo_url``, ``n_files``, ``total_mb`` and ``elapsed_s``.

    Raises:
        RuntimeError: If no HuggingFace token can be found in the environment.
    """
    import glob

    from huggingface_hub import HfApi

    # ---- HF auth (workaround for swapped-key Modal secret) ----
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        # Swapped secret: the variable *name* is the token ("hf_...") and its
        # value is the literal string "HF_TOKEN".
        for k, v in os.environ.items():
            if k.startswith("hf_") and v == "HF_TOKEN":
                token = k
                break
    if not token:
        raise RuntimeError("HF_TOKEN missing")

    api = HfApi(token=token)
    me = api.whoami()
    user = me["name"]
    repo_id = f"{user}/{repo_name}"
    print(f"[push] target: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)

    # ---- Stage everything into /tmp/repo_stage ----
    stage = "/tmp/repo_stage"
    if os.path.exists(stage):
        # Start from a clean directory so stale files are never uploaded.
        shutil.rmtree(stage)
    os.makedirs(stage)

    def copy_glob(src_glob: str, dst_dir: str) -> None:
        """Copy every file matching src_glob flat into dst_dir (basename only)."""
        os.makedirs(dst_dir, exist_ok=True)
        for src in glob.glob(src_glob, recursive=True):
            if os.path.isfile(src):
                shutil.copy(src, f"{dst_dir}/{os.path.basename(src)}")

    def copy_tree_filtered(
        src_root: str,
        dst_root: str,
        include_globs: list[str],
        exclude_substrings: list[str] | None = None,
    ) -> None:
        """Copy files matching include_globs, preserving relative path.

        A file is skipped when any of exclude_substrings occurs in its path
        relative to src_root.
        """
        os.makedirs(dst_root, exist_ok=True)
        exclude_substrings = exclude_substrings or []
        for inc in include_globs:
            for src in glob.glob(f"{src_root}/{inc}", recursive=True):
                if not os.path.isfile(src):
                    continue
                rel = os.path.relpath(src, src_root)
                if any(es in rel for es in exclude_substrings):
                    continue
                dst = f"{dst_root}/{rel}"
                os.makedirs(os.path.dirname(dst), exist_ok=True)
                shutil.copy(src, dst)

    print("[stage] code …")
    copy_tree_filtered("/local/probes", f"{stage}/code/probes",
                       ["*.py", "*.md"],
                       exclude_substrings=["__pycache__"])
    copy_tree_filtered("/local/scripts", f"{stage}/code/scripts",
                       ["*.py", "*.md", "*.sh"],
                       exclude_substrings=["__pycache__"])
    copy_tree_filtered("/local/share", f"{stage}/code/share", ["*.py", "*.md"])
    copy_tree_filtered("/local/modal", f"{stage}/code/modal",
                       ["*.py", "*.md"], exclude_substrings=["__pycache__"])

    print("[stage] manifests …")
    copy_glob("/local/probes/splits/*.json", f"{stage}/manifests")

    print("[stage] plots …")
    copy_glob("/local/probes/results/*.png", f"{stage}/plots")

    print("[stage] scores + summaries + training metrics …")
    copy_glob("/local/probes/results/*.scores.jsonl", f"{stage}/scores")
    # "*.json" already matches every "*.summary.json", so one pass is enough.
    copy_glob("/local/probes/results/*.json", f"{stage}/summaries")
    # also pull v1_*/metrics.json (and any other JSON directly inside v1_* dirs)
    copy_tree_filtered("/local/probes/results", f"{stage}/training_metrics",
                       ["v1_*/*.json"])

    print("[stage] extraction JSONLs …")
    copy_tree_filtered("/local/data/targeted_jsonl", f"{stage}/data/targeted_jsonl",
                       ["**/*.jsonl"])

    print("[stage] docs …")
    shutil.copy("/local/JOURNAL.md", f"{stage}/JOURNAL.md")
    shutil.copy("/local/HACKATHON_STATUS.md", f"{stage}/HACKATHON_STATUS.md")

    print("[stage] probe checkpoints from Modal volume …")
    ckpt_dir = f"{stage}/checkpoints"
    # Final checkpoints, the per-epoch checkpoints (for VFDB), the
    # diff_of_means.npy vectors and the per-run metric files, each kept at its
    # path relative to the volume root.
    volume_artifacts = [
        "**/checkpoint.pt",
        "**/checkpoint_epoch*.pt",
        "**/diff_of_means.npy",
        "**/metrics.json",
        "**/val_history.json",
    ]
    copy_tree_filtered("/results", ckpt_dir, volume_artifacts)

    # ---- README ----
    print("[stage] README …")
    readme = README_TEMPLATE.format(repo_id=repo_id)
    # The README contains non-ASCII characters; do not depend on the locale.
    Path(f"{stage}/README.md").write_text(readme, encoding="utf-8")

    # ---- Inventory + size ----
    total = 0
    n_files = 0
    for root, _, files in os.walk(stage):
        for f in files:
            total += os.path.getsize(os.path.join(root, f))
            n_files += 1
    print(f"[stage] {n_files} files, {total/1e6:.1f} MB")

    # ---- Push ----
    print(f"[push] uploading {n_files} files to {repo_id} …")
    t0 = time.time()
    api.upload_folder(
        repo_id=repo_id,
        repo_type="model",
        folder_path=stage,
        commit_message="Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs",
    )
    elapsed = time.time() - t0
    print(f"[push] done in {elapsed:.0f}s")

    return {
        "repo_url": f"https://huggingface.co/{repo_id}",
        "n_files": n_files,
        "total_mb": total / 1e6,
        "elapsed_s": elapsed,
    }
# Model-card text written to README.md in the staged repo. Rendered with
# str.format(repo_id=...), so literal braces in the markdown are doubled.
README_TEMPLATE = """\
---
license: mit
tags:
- evo2
- bioinformatics
- mgnify
- amr
- probe
- interpretability
---

# MGnify × Evo 2 layer-26 probes

Linear and attention probes on the residual-stream of the Evo 2 7B-262k DNA
foundation model (layer 26), trained to detect AMR genes, virulence factors,
and AMR subclasses in MGnify metagenomes.

Repo at {repo_id}.

## What's here

- **`code/`** — all the Python / Modal pipeline code (extraction, embedding,
  probe training, plot scripts).
- **`manifests/`** — train / val / test split JSONs (MAG-level for AMR,
  species-level for VFDB virulence).
- **`checkpoints/`** — trained probe weights (each is tiny — `Linear(4096, 1)`
  is 4097 params). Saved as PyTorch `.pt` state-dicts.
  - `amr_binary_v1/linear/{{run_id}}/checkpoint.pt` — primary AMR probe
  - `amr_binary_v1/attention/{{run_id}}/checkpoint.pt` — attention variant
  - `amr_class5_v1/linear/{{CLASS}}/{{run_id}}/checkpoint.pt` — 5 per-class probes
  - `vfdb_virulence_v1/linear/{{run_id}}/checkpoint.pt` — virulence probe
- **`plots/`** — score-distribution histograms, sanity plots, attention
  visualisations (PNG).
- **`scores/`** — per-region / per-read raw probe logits (JSONL) — useful for
  reformatting plots without retraining.
- **`summaries/`** — AUC + best-F1 + thresholds per probe (JSON).
- **`training_metrics/`** — full per-epoch val history (JSON).
- **`data/targeted_jsonl/`** — extraction outputs (gene + 2 kb flank
  sequences, paired-with negatives, etc.) — the *inputs* to the embedding
  pipeline. Big text files (~70 MB total).
- **`JOURNAL.md`** — chronological project log.
- **`HACKATHON_STATUS.md`** — high-level status doc.

## What's NOT here

- Raw Evo 2 embeddings (the per-token layer-26 activations, several hundred
  GB total). These live on Modal Volumes:
  - `mgnify-embeddings-l26-lean` (5483 records × 37 MB = ~208 GB)
  - `mgnify-embeddings-l26-vfdb` (~150 GB)
  - `mgnify-embeddings-l26-human-viral` (~60 GB)
  - A subset is mirrored as a public HF dataset:
    `JG1310/mgnify-evo2-l26-full` — layer-26 npz for all MGnify-extracted regions.
- FASTQ files, raw input xlsx, etc.

## Loading a probe checkpoint

```python
import torch
import torch.nn as nn

# AMR linear probe
sd = torch.load("checkpoints/amr_binary_v1/linear/<run_id>/checkpoint.pt", weights_only=True)
probe = nn.Linear(4096, 1)
probe.load_state_dict(sd)
probe.eval()

# Apply to a per-token activation tensor [seq_len, 4096]
# (e.g. from a layer-26-extracted .npz file)
import numpy as np
d = np.load("some_region.npz", allow_pickle=False)
acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float()
with torch.no_grad():
    logits = probe(acts).squeeze(-1)  # [seq_len] per-token logits
print(f"max-pool logit: {{logits.max():.2f}}")
print(f"mean-pool logit: {{logits.mean():.2f}}")
```

## Reading the manifests

```python
import json
m = json.load(open("manifests/amr_binary_v1.json"))
# m['region_split'][region_id] -> "train" / "val" / "test"
# m['labels_per_region'][region_id] -> 0 or 1
# m['gene_coords'][region_id] -> [gene_start, gene_end, ext_start, ext_end, strand]
# m['pair_partner'][region_id] -> matched-pair region_id
```

## Headline results (test sets)

| Probe | Eval | AUC |
|---|---|---|
| Linear, AMR-vs-MISC | 672 regions, MAG-level held out | 0.949 (region max-pool) |
| Attention, AMR-vs-MISC | 672 regions | 0.977 |
| Linear, AMR class-specific (top 5) | within-AMR class | 0.989 - 0.998 |
| Linear, VFDB virulence | 336 regions, species-level held out | 0.833 (region mean-pool) |
| Linear AMR probe → multi-org short reads (301 bp) | 1340 reads | 0.898 (per-read mean) / 0.921 (per-CDS) |

See `JOURNAL.md` for the full story; `summaries/` for raw numbers.
"""
@app.local_entrypoint()
def main(repo_name: str = "mgnify-evo2-probes", private: bool = False):
    """Local entrypoint: run ``push_share`` remotely and print its summary.

    Both arguments are forwarded unchanged to ``push_share``; the returned
    dict supplies the URL, file count, size and elapsed time printed below.
    """
    print(f"[local] launching push to HF repo: {repo_name} (private={private})")

    outcome = push_share.remote(repo_name=repo_name, private=private)

    print("\n=== PUSHED ===")
    print(f" URL: {outcome['repo_url']}")
    print(f" files: {outcome['n_files']}")
    print(f" size: {outcome['total_mb']:.1f} MB")
    print(f" elapsed: {outcome['elapsed_s']:.0f} s")