"""
Push the probe artifacts (code + plots + manifests + scores + summaries +
extraction JSONLs + docs + tiny checkpoint weights) to a HuggingFace model
repo for teammates.

Excludes raw embeddings (huge, on Modal volumes) and FASTQ zips.

Usage:
    modal run modal/push_probe_share.py::main
    modal run modal/push_probe_share.py::main --repo-name foo --private
"""
| from __future__ import annotations |
|
|
| import json |
| import os |
| import shutil |
| import time |
| from pathlib import Path |
|
|
| import modal |
|
|
|
|
| |
# Root of the local project checkout. Defaults to the original hard-coded
# location; set MGNIFY_ROOT to run this script from another checkout.
# An empty MGNIFY_ROOT falls back to the default.
_LOCAL_ROOT = (os.environ.get("MGNIFY_ROOT") or "/home/ror25cal/MGnify").rstrip("/")

# Local directories and files baked into the container image.
LOCAL_PROBES = f"{_LOCAL_ROOT}/probes"
LOCAL_SCRIPTS = f"{_LOCAL_ROOT}/scripts"
LOCAL_SHARE = f"{_LOCAL_ROOT}/share"
LOCAL_MODAL = f"{_LOCAL_ROOT}/modal"
LOCAL_TARGETED_JSONL = f"{_LOCAL_ROOT}/data/targeted_jsonl"
LOCAL_JOURNAL = f"{_LOCAL_ROOT}/JOURNAL.md"
LOCAL_STATUS = f"{_LOCAL_ROOT}/HACKATHON_STATUS.md"
|
|
|
|
# (local source, path inside the image) pairs, added in this order.
_IMAGE_DIRS = (
    (LOCAL_PROBES, "/local/probes"),
    (LOCAL_SCRIPTS, "/local/scripts"),
    (LOCAL_SHARE, "/local/share"),
    (LOCAL_MODAL, "/local/modal"),
    (LOCAL_TARGETED_JSONL, "/local/data/targeted_jsonl"),
)
_IMAGE_FILES = (
    (LOCAL_JOURNAL, "/local/JOURNAL.md"),
    (LOCAL_STATUS, "/local/HACKATHON_STATUS.md"),
)

# Container image: Debian slim with huggingface_hub installed, then the local
# project directories and the two top-level docs placed under /local.
image = modal.Image.debian_slim().pip_install("huggingface_hub>=0.25")
for _local_path, _remote_path in _IMAGE_DIRS:
    image = image.add_local_dir(_local_path, remote_path=_remote_path)
for _local_path, _remote_path in _IMAGE_FILES:
    image = image.add_local_file(_local_path, remote_path=_remote_path)

# Existing volume looked up by name; it is not created when missing.
results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=False)

app = modal.App("mgnify-probe-share-push")
|
|
|
|
def _resolve_hf_token() -> str:
    """Return the Hugging Face token found in the process environment.

    Raises:
        RuntimeError: if no token can be found.
    """
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token:
        return token
    # Fallback preserved from the original code: accept an entry whose *name*
    # starts with "hf_" and whose *value* is the literal "HF_TOKEN", and use
    # the name as the token.
    # NOTE(review): this looks like a workaround for a secret that was stored
    # with key and value swapped — confirm, and drop it once the secret is fixed.
    for key, value in os.environ.items():
        if key.startswith("hf_") and value == "HF_TOKEN":
            return key
    raise RuntimeError("HF_TOKEN missing")


def _copy_flat(src_glob: str, dst_dir: str) -> None:
    """Copy every file matching ``src_glob`` into ``dst_dir`` by basename.

    Directory structure is not preserved, so two matches with the same
    basename overwrite each other.
    """
    import glob

    os.makedirs(dst_dir, exist_ok=True)
    for src in glob.glob(src_glob, recursive=True):
        if os.path.isfile(src):
            shutil.copy(src, os.path.join(dst_dir, os.path.basename(src)))


def _copy_tree_filtered(
    src_root: str,
    dst_root: str,
    include_globs: list[str],
    exclude_substrings: list[str] | None = None,
) -> None:
    """Copy files under ``src_root`` matching ``include_globs``, keeping relative paths.

    Args:
        src_root: directory the glob patterns are resolved against.
        dst_root: destination directory; created if missing.
        include_globs: patterns relative to ``src_root``. A pattern without
            ``**`` only matches files directly under ``src_root``.
        exclude_substrings: a file is skipped when any of these strings occurs
            in its path relative to ``src_root``.
    """
    import glob

    os.makedirs(dst_root, exist_ok=True)
    excluded = tuple(exclude_substrings or ())
    for pattern in include_globs:
        for src in glob.glob(os.path.join(src_root, pattern), recursive=True):
            if not os.path.isfile(src):
                continue
            rel = os.path.relpath(src, src_root)
            if any(fragment in rel for fragment in excluded):
                continue
            dst = os.path.join(dst_root, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.copy(src, dst)


@app.function(
    image=image,
    cpu=2,
    memory=4 * 1024,
    volumes={"/results": results_vol},
    secrets=[modal.Secret.from_name("huggingface")],
    timeout=1800,
)
def push_share(repo_name: str = "mgnify-evo2-probes", private: bool = False) -> dict:
    """Stage the probe artifacts in a scratch directory and upload them to the Hub.

    The repo is created as ``<whoami name>/<repo_name>`` (type "model") if it
    does not exist yet. Everything is first copied into ``/tmp/repo_stage``
    and then uploaded in one ``upload_folder`` call.

    Args:
        repo_name: repository name, without the user prefix.
        private: passed to ``create_repo``.

    Returns:
        Dict with ``repo_url``, ``n_files``, ``total_mb`` and ``elapsed_s``
        (upload time only).

    Raises:
        RuntimeError: if no Hugging Face token is available.
    """
    from huggingface_hub import HfApi

    api = HfApi(token=_resolve_hf_token())
    user = api.whoami()["name"]
    repo_id = f"{user}/{repo_name}"
    print(f"[push] target: {repo_id} (private={private})")
    api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)

    # Always start from an empty staging directory.
    stage = "/tmp/repo_stage"
    if os.path.exists(stage):
        shutil.rmtree(stage)
    os.makedirs(stage)

    print("[stage] code …")
    # NOTE(review): these patterns contain no "**", so only files directly
    # under each root are staged (nested packages are skipped) — confirm that
    # this is intended.
    _copy_tree_filtered("/local/probes", f"{stage}/code/probes",
                        ["*.py", "*.md"],
                        exclude_substrings=["__pycache__"])
    _copy_tree_filtered("/local/scripts", f"{stage}/code/scripts",
                        ["*.py", "*.md", "*.sh"],
                        exclude_substrings=["__pycache__"])
    _copy_tree_filtered("/local/share", f"{stage}/code/share", ["*.py", "*.md"])
    _copy_tree_filtered("/local/modal", f"{stage}/code/modal",
                        ["*.py", "*.md"],
                        exclude_substrings=["__pycache__"])

    print("[stage] manifests …")
    _copy_flat("/local/probes/splits/*.json", f"{stage}/manifests")

    print("[stage] plots …")
    _copy_flat("/local/probes/results/*.png", f"{stage}/plots")

    print("[stage] scores + summaries + training metrics …")
    _copy_flat("/local/probes/results/*.scores.jsonl", f"{stage}/scores")
    # "*.json" already covers "*.summary.json", so one pass is enough.
    _copy_flat("/local/probes/results/*.json", f"{stage}/summaries")
    # Per-run JSONs keep their "v1_*/<file>" relative path.
    _copy_tree_filtered("/local/probes/results", f"{stage}/training_metrics",
                        ["v1_*/*.json"])

    print("[stage] extraction JSONLs …")
    _copy_tree_filtered("/local/data/targeted_jsonl", f"{stage}/data/targeted_jsonl",
                        ["**/*.jsonl"])

    print("[stage] docs …")
    shutil.copy("/local/JOURNAL.md", f"{stage}/JOURNAL.md")
    shutil.copy("/local/HACKATHON_STATUS.md", f"{stage}/HACKATHON_STATUS.md")

    print("[stage] probe checkpoints from Modal volume …")
    # Only these named files are taken from the mounted volume; anything else
    # under /results is left out of the upload.
    _copy_tree_filtered(
        "/results",
        f"{stage}/checkpoints",
        [
            "**/checkpoint.pt",
            "**/checkpoint_epoch*.pt",
            "**/diff_of_means.npy",
            "**/metrics.json",
            "**/val_history.json",
        ],
    )

    print("[stage] README …")
    readme = README_TEMPLATE.format(repo_id=repo_id)
    # The template contains non-ASCII characters; do not rely on the locale.
    Path(f"{stage}/README.md").write_text(readme, encoding="utf-8")

    # Tally what was staged, for the log line and the returned summary.
    total = 0
    n_files = 0
    for root, _, files in os.walk(stage):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
            n_files += 1
    print(f"[stage] {n_files} files, {total/1e6:.1f} MB")

    print(f"[push] uploading {n_files} files to {repo_id} …")
    t0 = time.time()
    api.upload_folder(
        repo_id=repo_id,
        repo_type="model",
        folder_path=stage,
        commit_message="Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs",
    )
    elapsed = time.time() - t0
    print(f"[push] done in {elapsed:.0f}s")
    return {
        "repo_url": f"https://huggingface.co/{repo_id}",
        "n_files": n_files,
        "total_mb": total / 1e6,
        "elapsed_s": elapsed,
    }
|
|
|
|
# Model-card text written to README.md in the uploaded repo. Rendered with
# str.format(repo_id=...), so every literal brace in the text is doubled.
README_TEMPLATE = """\
---
license: mit
tags:
- evo2
- bioinformatics
- mgnify
- amr
- probe
- interpretability
---

# MGnify × Evo 2 layer-26 probes

Linear and attention probes on the residual-stream of the Evo 2 7B-262k DNA
foundation model (layer 26), trained to detect AMR genes, virulence factors,
and AMR subclasses in MGnify metagenomes.

Repo at {repo_id}.

## What's here

- **`code/`** — all the Python / Modal pipeline code (extraction, embedding,
  probe training, plot scripts).
- **`manifests/`** — train / val / test split JSONs (MAG-level for AMR,
  species-level for VFDB virulence).
- **`checkpoints/`** — trained probe weights (each is tiny — `Linear(4096, 1)`
  is 4097 params). Saved as PyTorch `.pt` state-dicts.
  - `amr_binary_v1/linear/{{run_id}}/checkpoint.pt` — primary AMR probe
  - `amr_binary_v1/attention/{{run_id}}/checkpoint.pt` — attention variant
  - `amr_class5_v1/linear/{{CLASS}}/{{run_id}}/checkpoint.pt` — 5 per-class probes
  - `vfdb_virulence_v1/linear/{{run_id}}/checkpoint.pt` — virulence probe
- **`plots/`** — score-distribution histograms, sanity plots, attention
  visualisations (PNG).
- **`scores/`** — per-region / per-read raw probe logits (JSONL) — useful for
  reformatting plots without retraining.
- **`summaries/`** — AUC + best-F1 + thresholds per probe (JSON).
- **`training_metrics/`** — full per-epoch val history (JSON).
- **`data/targeted_jsonl/`** — extraction outputs (gene + 2 kb flank
  sequences, paired-with negatives, etc.) — the *inputs* to the embedding
  pipeline. Big text files (~70 MB total).
- **`JOURNAL.md`** — chronological project log.
- **`HACKATHON_STATUS.md`** — high-level status doc.

## What's NOT here

- Raw Evo 2 embeddings (the per-token layer-26 activations, several hundred
  GB total). These live on Modal Volumes:
  - `mgnify-embeddings-l26-lean` (5483 records × 37 MB = ~208 GB)
  - `mgnify-embeddings-l26-vfdb` (~150 GB)
  - `mgnify-embeddings-l26-human-viral` (~60 GB)
  - A subset is mirrored as a public HF dataset:
    `JG1310/mgnify-evo2-l26-full` — layer-26 npz for all MGnify-extracted regions.
- FASTQ files, raw input xlsx, etc.

## Loading a probe checkpoint

```python
import torch
import torch.nn as nn

# AMR linear probe
sd = torch.load("checkpoints/amr_binary_v1/linear/<run_id>/checkpoint.pt", weights_only=True)
probe = nn.Linear(4096, 1)
probe.load_state_dict(sd)
probe.eval()

# Apply to a per-token activation tensor [seq_len, 4096]
# (e.g. from a layer-26-extracted .npz file)
import numpy as np
d = np.load("some_region.npz", allow_pickle=False)
acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float()
with torch.no_grad():
    logits = probe(acts).squeeze(-1)  # [seq_len] per-token logits
print(f"max-pool logit: {{logits.max():.2f}}")
print(f"mean-pool logit: {{logits.mean():.2f}}")
```

## Reading the manifests

```python
import json
m = json.load(open("manifests/amr_binary_v1.json"))
# m['region_split'][region_id] -> "train" / "val" / "test"
# m['labels_per_region'][region_id] -> 0 or 1
# m['gene_coords'][region_id] -> [gene_start, gene_end, ext_start, ext_end, strand]
# m['pair_partner'][region_id] -> matched-pair region_id
```

## Headline results (test sets)

| Probe | Eval | AUC |
|---|---|---|
| Linear, AMR-vs-MISC | 672 regions, MAG-level held out | 0.949 (region max-pool) |
| Attention, AMR-vs-MISC | 672 regions | 0.977 |
| Linear, AMR class-specific (top 5) | within-AMR class | 0.989 - 0.998 |
| Linear, VFDB virulence | 336 regions, species-level held out | 0.833 (region mean-pool) |
| Linear AMR probe → multi-org short reads (301 bp) | 1340 reads | 0.898 (per-read mean) / 0.921 (per-CDS) |

See `JOURNAL.md` for the full story; `summaries/` for raw numbers.
"""
|
|
|
|
@app.local_entrypoint()
def main(repo_name: str = "mgnify-evo2-probes", private: bool = False):
    """Run ``push_share`` remotely and print the summary it returns.

    Args:
        repo_name: forwarded to ``push_share``.
        private: forwarded to ``push_share``.
    """
    print(f"[local] launching push to HF repo: {repo_name} (private={private})")
    summary = push_share.remote(repo_name=repo_name, private=private)

    # Report the fields of the returned summary dict, one per line.
    print("\n=== PUSHED ===")
    print(f" URL: {summary['repo_url']}")
    print(f" files: {summary['n_files']}")
    print(f" size: {summary['total_mb']:.1f} MB")
    print(f" elapsed: {summary['elapsed_s']:.0f} s")
|
|