""" Push the probe artifacts (code + plots + manifests + scores + summaries + docs + tiny checkpoint weights) to a HuggingFace model repo for teammates. Excludes raw embeddings (huge, on Modal volumes) and FASTQ zips. Usage: modal run modal/push_probe_share.py::main modal run modal/push_probe_share.py::main --repo-name foo --private """ from __future__ import annotations import json import os import shutil import time from pathlib import Path import modal # Local source directories — mounted into the container LOCAL_PROBES = "/home/ror25cal/MGnify/probes" LOCAL_SCRIPTS = "/home/ror25cal/MGnify/scripts" LOCAL_SHARE = "/home/ror25cal/MGnify/share" LOCAL_MODAL = "/home/ror25cal/MGnify/modal" LOCAL_TARGETED_JSONL = "/home/ror25cal/MGnify/data/targeted_jsonl" LOCAL_JOURNAL = "/home/ror25cal/MGnify/JOURNAL.md" LOCAL_STATUS = "/home/ror25cal/MGnify/HACKATHON_STATUS.md" image = ( modal.Image.debian_slim() .pip_install("huggingface_hub>=0.25") .add_local_dir(LOCAL_PROBES, remote_path="/local/probes") .add_local_dir(LOCAL_SCRIPTS, remote_path="/local/scripts") .add_local_dir(LOCAL_SHARE, remote_path="/local/share") .add_local_dir(LOCAL_MODAL, remote_path="/local/modal") .add_local_dir(LOCAL_TARGETED_JSONL, remote_path="/local/data/targeted_jsonl") .add_local_file(LOCAL_JOURNAL, remote_path="/local/JOURNAL.md") .add_local_file(LOCAL_STATUS, remote_path="/local/HACKATHON_STATUS.md") ) results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=False) app = modal.App("mgnify-probe-share-push") @app.function( image=image, cpu=2, memory=4 * 1024, volumes={"/results": results_vol}, secrets=[modal.Secret.from_name("huggingface")], timeout=1800, ) def push_share(repo_name: str = "mgnify-evo2-probes", private: bool = False) -> dict: import glob import io from huggingface_hub import HfApi # ---- HF auth (workaround for swapped-key Modal secret) ---- token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") if not token: for k, v in os.environ.items(): if k.startswith("hf_") and v == "HF_TOKEN": token = k break if not token: raise RuntimeError("HF_TOKEN missing") api = HfApi(token=token) me = api.whoami() user = me["name"] repo_id = f"{user}/{repo_name}" print(f"[push] target: {repo_id} (private={private})") api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True) # ---- Stage everything into /tmp/repo_stage ---- stage = "/tmp/repo_stage" if os.path.exists(stage): shutil.rmtree(stage) os.makedirs(stage) def copy_glob(src_glob, dst_dir): os.makedirs(dst_dir, exist_ok=True) for src in glob.glob(src_glob, recursive=True): if os.path.isfile(src): rel = os.path.basename(src) shutil.copy(src, f"{dst_dir}/{rel}") def copy_tree_filtered(src_root: str, dst_root: str, include_globs: list[str], exclude_substrings: list[str] = None): """Copy files matching include_globs, preserving relative path.""" os.makedirs(dst_root, exist_ok=True) exclude_substrings = exclude_substrings or [] for inc in include_globs: for src in glob.glob(f"{src_root}/{inc}", recursive=True): if not os.path.isfile(src): continue rel = os.path.relpath(src, src_root) if any(es in rel for es in exclude_substrings): continue dst = f"{dst_root}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(src, dst) print("[stage] code …") copy_tree_filtered("/local/probes", f"{stage}/code/probes", ["*.py", "*.md"], exclude_substrings=["__pycache__"]) copy_tree_filtered("/local/scripts", f"{stage}/code/scripts", ["*.py", "*.md", "*.sh"], exclude_substrings=["__pycache__"]) copy_tree_filtered("/local/share", f"{stage}/code/share", ["*.py", "*.md"]) copy_tree_filtered("/local/modal", f"{stage}/code/modal", ["*.py", "*.md"], exclude_substrings=["__pycache__"]) print("[stage] manifests …") copy_glob("/local/probes/splits/*.json", f"{stage}/manifests") print("[stage] plots …") copy_glob("/local/probes/results/*.png", f"{stage}/plots") print("[stage] scores + summaries + training metrics …") copy_glob("/local/probes/results/*.scores.jsonl", f"{stage}/scores") copy_glob("/local/probes/results/*.summary.json", f"{stage}/summaries") copy_glob("/local/probes/results/*.json", f"{stage}/summaries") # also pull v1_*/metrics.json for f in glob.glob("/local/probes/results/v1_*/*.json"): rel = os.path.relpath(f, "/local/probes/results") dst = f"{stage}/training_metrics/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(f, dst) print("[stage] extraction JSONLs …") copy_tree_filtered("/local/data/targeted_jsonl", f"{stage}/data/targeted_jsonl", ["**/*.jsonl"]) print("[stage] docs …") shutil.copy("/local/JOURNAL.md", f"{stage}/JOURNAL.md") shutil.copy("/local/HACKATHON_STATUS.md", f"{stage}/HACKATHON_STATUS.md") print("[stage] probe checkpoints from Modal volume …") ckpt_dir = f"{stage}/checkpoints" os.makedirs(ckpt_dir, exist_ok=True) for ckpt_path in glob.glob("/results/**/checkpoint.pt", recursive=True): rel = os.path.relpath(ckpt_path, "/results") dst = f"{ckpt_dir}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(ckpt_path, dst) # also include the per-epoch checkpoints (for VFDB) and the diff_of_means.npy for f in glob.glob("/results/**/checkpoint_epoch*.pt", recursive=True): rel = os.path.relpath(f, "/results") dst = f"{ckpt_dir}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(f, dst) for f in glob.glob("/results/**/diff_of_means.npy", recursive=True): rel = os.path.relpath(f, "/results") dst = f"{ckpt_dir}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(f, dst) for f in glob.glob("/results/**/metrics.json", recursive=True): rel = os.path.relpath(f, "/results") dst = f"{ckpt_dir}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(f, dst) for f in glob.glob("/results/**/val_history.json", recursive=True): rel = os.path.relpath(f, "/results") dst = f"{ckpt_dir}/{rel}" os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(f, dst) # ---- README ---- print("[stage] README …") readme = README_TEMPLATE.format(repo_id=repo_id) Path(f"{stage}/README.md").write_text(readme) # ---- Inventory + size ---- total = 0 n_files = 0 for root, _, files in os.walk(stage): for f in files: total += os.path.getsize(os.path.join(root, f)) n_files += 1 print(f"[stage] {n_files} files, {total/1e6:.1f} MB") # ---- Push ---- print(f"[push] uploading {n_files} files to {repo_id} …") t0 = time.time() api.upload_folder( repo_id=repo_id, repo_type="model", folder_path=stage, commit_message="Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs", ) elapsed = time.time() - t0 print(f"[push] done in {elapsed:.0f}s") return { "repo_url": f"https://huggingface.co/{repo_id}", "n_files": n_files, "total_mb": total / 1e6, "elapsed_s": elapsed, } README_TEMPLATE = """\ --- license: mit tags: - evo2 - bioinformatics - mgnify - amr - probe - interpretability --- # MGnify × Evo 2 layer-26 probes Linear and attention probes on the residual-stream of the Evo 2 7B-262k DNA foundation model (layer 26), trained to detect AMR genes, virulence factors, and AMR subclasses in MGnify metagenomes. Repo at {repo_id}. ## What's here - **`code/`** — all the Python / Modal pipeline code (extraction, embedding, probe training, plot scripts). - **`manifests/`** — train / val / test split JSONs (MAG-level for AMR, species-level for VFDB virulence). - **`checkpoints/`** — trained probe weights (each is tiny — `Linear(4096, 1)` is 4097 params). Saved as PyTorch `.pt` state-dicts. - `amr_binary_v1/linear/{{run_id}}/checkpoint.pt` — primary AMR probe - `amr_binary_v1/attention/{{run_id}}/checkpoint.pt` — attention variant - `amr_class5_v1/linear/{{CLASS}}/{{run_id}}/checkpoint.pt` — 5 per-class probes - `vfdb_virulence_v1/linear/{{run_id}}/checkpoint.pt` — virulence probe - **`plots/`** — score-distribution histograms, sanity plots, attention visualisations (PNG). - **`scores/`** — per-region / per-read raw probe logits (JSONL) — useful for reformatting plots without retraining. - **`summaries/`** — AUC + best-F1 + thresholds per probe (JSON). - **`training_metrics/`** — full per-epoch val history (JSON). - **`data/targeted_jsonl/`** — extraction outputs (gene + 2 kb flank sequences, paired-with negatives, etc.) — the *inputs* to the embedding pipeline. Big text files (~70 MB total). - **`JOURNAL.md`** — chronological project log. - **`HACKATHON_STATUS.md`** — high-level status doc. ## What's NOT here - Raw Evo 2 embeddings (the per-token layer-26 activations, several hundred GB total). These live on Modal Volumes: - `mgnify-embeddings-l26-lean` (5483 records × 37 MB = ~208 GB) - `mgnify-embeddings-l26-vfdb` (~150 GB) - `mgnify-embeddings-l26-human-viral` (~60 GB) - A subset is mirrored as a public HF dataset: `JG1310/mgnify-evo2-l26-full` — layer-26 npz for all MGnify-extracted regions. - FASTQ files, raw input xlsx, etc. ## Loading a probe checkpoint ```python import torch import torch.nn as nn # AMR linear probe sd = torch.load("checkpoints/amr_binary_v1/linear//checkpoint.pt", weights_only=True) probe = nn.Linear(4096, 1) probe.load_state_dict(sd) probe.eval() # Apply to a per-token activation tensor [seq_len, 4096] # (e.g. from a layer-26-extracted .npz file) import numpy as np d = np.load("some_region.npz", allow_pickle=False) acts = torch.from_numpy(d["layer26_activations_bf16"]).view(torch.bfloat16).float() with torch.no_grad(): logits = probe(acts).squeeze(-1) # [seq_len] per-token logits print(f"max-pool logit: {{logits.max():.2f}}") print(f"mean-pool logit: {{logits.mean():.2f}}") ``` ## Reading the manifests ```python import json m = json.load(open("manifests/amr_binary_v1.json")) # m['region_split'][region_id] -> "train" / "val" / "test" # m['labels_per_region'][region_id] -> 0 or 1 # m['gene_coords'][region_id] -> [gene_start, gene_end, ext_start, ext_end, strand] # m['pair_partner'][region_id] -> matched-pair region_id ``` ## Headline results (test sets) | Probe | Eval | AUC | |---|---|---| | Linear, AMR-vs-MISC | 672 regions, MAG-level held out | 0.949 (region max-pool) | | Attention, AMR-vs-MISC | 672 regions | 0.977 | | Linear, AMR class-specific (top 5) | within-AMR class | 0.989 - 0.998 | | Linear, VFDB virulence | 336 regions, species-level held out | 0.833 (region mean-pool) | | Linear AMR probe → multi-org short reads (301 bp) | 1340 reads | 0.898 (per-read mean) / 0.921 (per-CDS) | See `JOURNAL.md` for the full story; `summaries/` for raw numbers. """ @app.local_entrypoint() def main(repo_name: str = "mgnify-evo2-probes", private: bool = False): print(f"[local] launching push to HF repo: {repo_name} (private={private})") r = push_share.remote(repo_name=repo_name, private=private) print("\n=== PUSHED ===") print(f" URL: {r['repo_url']}") print(f" files: {r['n_files']}") print(f" size: {r['total_mb']:.1f} MB") print(f" elapsed: {r['elapsed_s']:.0f} s")