# mgnify-evo2-probes / code / probes / train_amr_binary_v1.py
# JG1310's picture
# Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
# eb69de4 verified
"""
Linear per-token probe for AMR-vs-matched-CDS binary classification.
Runs on Modal CPU. Reads activations from `mgnify-embeddings-l26-lean`.
Per-token labels:
- For AMR-positive records: token at position p is labelled 1 iff
ext_start + p ∈ [gene_start, gene_end].
- For matched-MISC records: all tokens labelled 0.
(See HACKATHON_STATUS.md and probes/splits/build_amr_binary_v1.py for split design.)
Usage:
modal run probes/train_amr_binary_v1.py::main --epochs 5
modal run probes/train_amr_binary_v1.py::dryrun_local # synthetic-data sanity check
"""
from __future__ import annotations
import json
import os
import random
import time
from pathlib import Path
from typing import Iterator
import modal
# ---------------------------------------------------------------------------
# Constants / config
# ---------------------------------------------------------------------------
# Manifest JSON on the developer machine; `main` reads this file and uploads
# its text to the manifest volume.
MANIFEST_PATH_LOCAL = "/home/ror25cal/MGnify/probes/splits/amr_binary_v1.json"
# The same manifest as seen inside the Modal containers, which mount the
# manifest volume at "/manifest".
MANIFEST_PATH_REMOTE = "/manifest/amr_binary_v1.json"
# Local results directory.
# NOTE(review): not referenced by any function visible in this file — confirm it is still needed.
RESULTS_DIR_LOCAL = "/home/ror25cal/MGnify/probes/results/amr_binary_v1/linear"
# Input width of the linear probe (default argument of `make_probe`).
HIDDEN = 4096
# Base seed: used to seed torch / numpy / random in training and as the
# default shuffle seed for region iteration.
SEED = 42
# ---------------------------------------------------------------------------
# Modal app + image. Pure CPU — no GPU, no Evo 2 needed.
# ---------------------------------------------------------------------------
# Container image shared by every Modal function declared in this file.
image = (
modal.Image.debian_slim()
.pip_install("numpy", "torch>=2.0", "scikit-learn>=1.3", "matplotlib")
)
# Activation volume. It must already exist: create_if_missing=False means a
# missing volume is an error rather than being created empty.
l26_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=False)
# Output volumes (checkpoints/metrics and the uploaded manifest); created on
# first use if they do not exist yet.
results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=True)
manifest_vol = modal.Volume.from_name("mgnify-probe-manifests", create_if_missing=True)
# The Modal app that all functions and local entrypoints below attach to.
app = modal.App("mgnify-amr-probe-v1")
# ---------------------------------------------------------------------------
# Per-token label computation
# ---------------------------------------------------------------------------
def per_token_labels_from_record(
    region_id: str,
    region_label: int,  # 1 if AMR-positive, 0 if MISC
    gene_coords: list,  # [gene_start, gene_end, ext_start, ext_end, strand]
    seq_len: int,
) -> "np.ndarray":
    """Build the per-token 0/1 target vector for a single region.

    Args:
        region_id: Identifier of the region (not used in the computation).
        region_label: 0 marks a negative (MISC) region, whose tokens are all
            labelled 0; any other value is treated as AMR-positive.
        gene_coords: ``[gene_start, gene_end, ext_start, ext_end, strand]``.
            Only read for positive regions.
        seq_len: Number of tokens in the region.

    Returns:
        An int8 array of shape ``[seq_len]`` that is 1 exactly at the token
        positions ``p`` for which ``ext_start + p`` lies in the inclusive
        interval ``[gene_start, gene_end]``.
    """
    import numpy as np

    token_labels = np.zeros(seq_len, dtype=np.int8)
    if region_label != 0:
        gene_start, gene_end, ext_start, _ext_end, _strand = gene_coords
        # Token p sits at genomic coordinate ext_start + p. The strand is
        # ignored (no reverse-complementing), so the gene body is simply the
        # window [gene_start, gene_end] shifted by ext_start and clamped to
        # the token range.
        first = max(gene_start - ext_start, 0)
        stop = min(gene_end - ext_start + 1, seq_len)
        if first < stop:
            token_labels[first:stop] = 1
    return token_labels
# ---------------------------------------------------------------------------
# Load one .npz: returns (acts_fp32 [seq_len, 4096], labels [seq_len]).
# ---------------------------------------------------------------------------
def load_region(npz_root: str, region_id: str, region_label: int, gene_coords: list):
    """Load one region's layer-26 activations and build its per-token labels.

    The file is looked up at
    ``{npz_root}/{AMR|MISC}/{mag_id}/{region_id}.npz``, where the folder is
    ``AMR`` when ``region_label == 1`` and ``MISC`` otherwise, and ``mag_id``
    is ``region_id`` with its last two ``_``-separated fields removed.

    Args:
        npz_root: Root directory of the activation volume.
        region_id: Region identifier (also the ``.npz`` file stem).
        region_label: 1 for AMR-positive, otherwise treated as MISC.
        gene_coords: ``[gene_start, gene_end, ext_start, ext_end, strand]``,
            forwarded to ``per_token_labels_from_record``.

    Returns:
        ``(acts, labels)``: ``acts`` is a float32 array whose first axis is
        the token axis (``[seq_len, 4096]`` according to this file's
        comments), ``labels`` is the int8 array of per-token targets.

    Raises:
        FileNotFoundError: if the ``.npz`` file does not exist (callers such
            as ``iter_split`` rely on this to skip missing regions).
    """
    import numpy as np
    import torch

    label_folder = "AMR" if region_label == 1 else "MISC"
    mag_id = region_id.rsplit("_", 2)[0]  # MGYG..._00123_AMR -> MGYG...
    npz_path = f"{npz_root}/{label_folder}/{mag_id}/{region_id}.npz"
    # np.load on an .npz returns a lazily-read archive that keeps the file
    # open. Close it as soon as the one array we need is in memory, otherwise
    # a pass over thousands of regions leaks one file handle per region.
    with np.load(npz_path, allow_pickle=False) as archive:
        raw_bits = archive["layer26_activations_bf16"]
    # The stored array holds raw bfloat16 bit patterns (presumably as a 16-bit
    # integer dtype — confirm against the writer); reinterpret, then upcast.
    acts = torch.from_numpy(raw_bits).view(torch.bfloat16).float().numpy()
    seq_len = acts.shape[0]
    labels = per_token_labels_from_record(region_id, region_label, gene_coords, seq_len)
    return acts, labels
# ---------------------------------------------------------------------------
# Streaming iterator: yield (region_id, acts [B, 4096], labels [B]) for the
# given split. Each yielded item is one region's worth of tokens (variable B).
# ---------------------------------------------------------------------------
def iter_split(
    manifest: dict,
    split: str,
    npz_root: str,
    shuffle: bool = True,
    seed: int = SEED,
) -> Iterator[tuple]:
    """Stream the regions of one split, one region at a time.

    Args:
        manifest: Dict with the keys ``region_split`` (region id -> split
            name), ``labels_per_region`` and ``gene_coords``.
        split: Name of the split to iterate.
        npz_root: Root directory passed through to ``load_region``.
        shuffle: When true, the region order is shuffled deterministically.
        seed: Seed of the private ``random.Random`` used for the shuffle.

    Yields:
        ``(region_id, acts, labels)`` for every region whose file could be
        loaded. Regions whose file is missing are reported on stdout and
        skipped.
    """
    split_of = manifest["region_split"]
    wanted = [rid for rid in split_of if split_of[rid] == split]
    if shuffle:
        # A dedicated generator keeps the order independent of global RNG state.
        random.Random(seed).shuffle(wanted)
    label_of = manifest["labels_per_region"]
    coords_of = manifest["gene_coords"]
    for rid in wanted:
        try:
            acts, labels = load_region(
                npz_root,
                rid,
                region_label=int(label_of[rid]),
                gene_coords=coords_of[rid],
            )
        except FileNotFoundError:
            print(f" WARN: missing {rid}; skipping")
        else:
            yield rid, acts, labels
# ---------------------------------------------------------------------------
# Linear probe
# ---------------------------------------------------------------------------
def make_probe(hidden: int = HIDDEN):
    """Create the probe: one affine layer mapping ``hidden`` features to a single logit."""
    from torch import nn

    probe = nn.Linear(hidden, 1)
    return probe
# ---------------------------------------------------------------------------
# Train one epoch over the train split.
# ---------------------------------------------------------------------------
def train_one_epoch(probe, optimizer, manifest, npz_root, pos_weight, device, max_regions=None):
    """Run one pass over the train split, taking one optimizer step per region.

    Args:
        probe: The module being trained; called as ``probe(x)`` on
            ``[seq_len, hidden]`` activations.
        optimizer: Optimizer over ``probe``'s parameters.
        manifest: Split manifest, forwarded to ``iter_split``.
        npz_root: Root directory of the activation files.
        pos_weight: Weight of the positive class in the BCE-with-logits loss.
        device: Torch device the tensors are moved to.
        max_regions: Optional cap on the number of regions to train on;
            ``None`` means the whole split.

    Returns:
        The token-weighted mean training loss over the epoch, or ``0.0`` when
        no token was seen (empty split or every file missing).
    """
    import torch
    import torch.nn as nn

    probe.train()
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weight), device=device))
    total_loss = 0.0
    total_tokens = 0
    n_pos_seen = 0
    n_regions = 0
    t0 = time.time()
    # NOTE: the shuffle seed is a constant, so every epoch visits the regions
    # in the same order.
    for _rid, acts, labels in iter_split(manifest, "train", npz_root, shuffle=True, seed=SEED + 7):
        # Test the cap before counting, so n_regions is the number of regions
        # actually trained on (not one more after the break).
        if max_regions is not None and n_regions >= max_regions:
            break
        if acts.shape[0] == 0:
            # The mean loss over zero tokens is NaN and would poison the
            # running average; there is nothing to learn from such a region.
            continue
        n_regions += 1
        x = torch.from_numpy(acts).to(device)  # [seq_len, 4096]
        y = torch.from_numpy(labels).float().to(device)  # [seq_len]
        logits = probe(x).squeeze(-1)  # [seq_len]
        loss = bce(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # bce returns a per-token mean; re-weight by token count so the
        # epoch average is per token, not per region.
        total_loss += loss.item() * x.shape[0]
        total_tokens += x.shape[0]
        n_pos_seen += int(y.sum().item())
        if n_regions % 200 == 0:
            print(f" epoch progress: {n_regions} regions, "
                  f"avg loss so far {total_loss/total_tokens:.4f}, "
                  f"pos rate {n_pos_seen/total_tokens:.4f}")
    elapsed = time.time() - t0
    # Guarded division: the summary must not crash when no token was seen.
    avg_loss = total_loss / max(total_tokens, 1)
    print(f" epoch done: {n_regions} regions, {total_tokens} tokens, "
          f"avg loss {avg_loss:.4f}, {elapsed:.0f}s")
    return avg_loss
# ---------------------------------------------------------------------------
# Eval on val/test split. Returns metric dict.
# ---------------------------------------------------------------------------
def evaluate(probe, manifest, npz_root, split, device):
    """Score ``probe`` on one split and return per-token and per-region metrics.

    Args:
        probe: Trained module, called as ``probe(x)`` on ``[seq_len, hidden]``
            activations.
        manifest: Split manifest (``region_split``, ``labels_per_region``,
            ``gene_coords``).
        npz_root: Root directory of the activation files.
        split: Name of the split to evaluate (e.g. ``"val"`` or ``"test"``).
        device: Torch device the activations are moved to.

    Returns:
        A dict with the keys ``split``, ``n_regions``, ``n_tokens``,
        ``n_positive_tokens``, ``token_roc_auc``, ``token_pr_auc``,
        ``token_best_f1``, ``token_best_threshold``, ``region_max_pool_auc``
        and ``region_mean_pool_auc``.

    Raises:
        ValueError: if no region of the split could be loaded. The sklearn
            metrics also need both classes to be present to be meaningful.
    """
    import torch
    import numpy as np
    from sklearn.metrics import (
        roc_auc_score, average_precision_score, precision_recall_curve,
    )

    probe.eval()
    region_label_of = manifest["labels_per_region"]
    all_logits = []
    all_labels = []
    all_region_ids = []
    all_region_labels = []
    with torch.no_grad():
        for rid, acts, labels in iter_split(manifest, split, npz_root, shuffle=False):
            x = torch.from_numpy(acts).to(device)
            all_logits.append(probe(x).squeeze(-1).cpu().numpy())
            all_labels.append(labels)
            all_region_ids.append(rid)
            all_region_labels.append(int(region_label_of[rid]))
    if not all_logits:
        # Fail with a message that names the cause instead of letting
        # np.concatenate complain about an empty list.
        raise ValueError(f"no regions could be loaded for split {split!r}; cannot compute metrics")

    logits = np.concatenate(all_logits)
    labels = np.concatenate(all_labels)

    # Per-token metrics
    token_auc = roc_auc_score(labels, logits)
    token_pr_auc = average_precision_score(labels, logits)

    # Best-F1 operating point. precision/recall have one more entry than
    # thresholds, hence the clamp when the last point wins.
    precision, recall, thresholds = precision_recall_curve(labels, logits)
    f1s = 2 * precision * recall / np.maximum(precision + recall, 1e-9)
    best_idx = int(np.argmax(f1s))
    best_thresh = float(thresholds[min(best_idx, len(thresholds) - 1)])
    token_f1_best = float(f1s[best_idx])

    # Per-region metrics: summarise each region's token logits by their max
    # (positive if any token fires) and by their mean, then rank regions.
    region_max_logit = [float(np.max(arr)) for arr in all_logits]
    region_mean_logit = [float(np.mean(arr)) for arr in all_logits]
    region_labels_np = np.array(all_region_labels)
    region_max_auc = roc_auc_score(region_labels_np, region_max_logit)
    region_mean_auc = roc_auc_score(region_labels_np, region_mean_logit)

    return {
        "split": split,
        "n_regions": len(all_region_ids),
        "n_tokens": int(labels.shape[0]),
        "n_positive_tokens": int(labels.sum()),
        "token_roc_auc": float(token_auc),
        "token_pr_auc": float(token_pr_auc),
        "token_best_f1": token_f1_best,
        "token_best_threshold": best_thresh,
        "region_max_pool_auc": float(region_max_auc),
        "region_mean_pool_auc": float(region_mean_auc),
    }
# ---------------------------------------------------------------------------
# Modal training entrypoint
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=4,
    memory=32 * 1024,  # 32 GB
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=7200,
)
def train_remote(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0, run_id: str = ""):
    """Train the linear probe, keep the best epoch by val token AUC, evaluate on test.

    Runs on CPU inside Modal. Writes ``checkpoint.pt`` (a state dict) and
    ``metrics.json`` to ``/results/amr_binary_v1/linear/{run_id}`` and commits
    the results volume.

    Args:
        epochs: Number of passes over the train split.
        lr: Adam learning rate.
        pos_weight: Positive-class weight of the BCE loss.
        run_id: Name of the output sub-directory; defaults to a
            ``YYYYmmdd_HHMMSS`` timestamp when empty.

    Returns:
        ``{"run_id", "best_epoch", "test_metrics"}``. ``best_epoch`` is ``-1``
        when no epoch improved on the initial sentinel (e.g. ``epochs=0``).
    """
    import torch
    import numpy as np

    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    print(f"[probe] reading manifest from {MANIFEST_PATH_REMOTE}")
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())
    device = torch.device("cpu")
    probe = make_probe().to(device)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    print(f"[probe] linear probe: in={HIDDEN}, out=1, params={sum(p.numel() for p in probe.parameters())}")

    val_metrics_history = []
    best_val_auc = -1.0
    best_epoch = -1
    best_state = None
    for epoch in range(epochs):
        print(f"\n=== EPOCH {epoch + 1} / {epochs} ===")
        train_loss = train_one_epoch(probe, optimizer, manifest, "/data", pos_weight, device)
        print(" evaluating on val split...")
        val_metrics = evaluate(probe, manifest, "/data", "val", device)
        val_metrics["epoch"] = epoch + 1
        val_metrics["train_loss"] = train_loss
        val_metrics_history.append(val_metrics)
        print(f" val token AUC: {val_metrics['token_roc_auc']:.4f}, "
              f"region max-pool AUC: {val_metrics['region_max_pool_auc']:.4f}")
        if val_metrics["token_roc_auc"] > best_val_auc:
            best_val_auc = val_metrics["token_roc_auc"]
            best_epoch = epoch + 1
            # Snapshot a copy: state_dict() tensors alias the live parameters.
            best_state = {k: v.cpu().clone() for k, v in probe.state_dict().items()}

    print(f"\n[probe] best val token AUC = {best_val_auc:.4f} at epoch {best_epoch}")
    print("[probe] running final test eval with best epoch's weights")
    if best_state is not None:
        probe.load_state_dict(best_state)
    else:
        # No epoch was selected (epochs == 0, or the val AUC never beat the
        # sentinel). Save the weights that are actually evaluated below, so
        # checkpoint.pt is always a loadable state dict rather than None.
        best_state = {k: v.cpu().clone() for k, v in probe.state_dict().items()}
    test_metrics = evaluate(probe, manifest, "/data", "test", device)

    # Save results
    run_id = run_id or time.strftime("%Y%m%d_%H%M%S")
    out_dir = f"/results/amr_binary_v1/linear/{run_id}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(best_state, f"{out_dir}/checkpoint.pt")
    Path(f"{out_dir}/metrics.json").write_text(json.dumps({
        "manifest": "amr_binary_v1",
        "model": "linear",
        "hyperparameters": {"epochs": epochs, "lr": lr, "pos_weight": pos_weight, "seed": SEED},
        "best_epoch": best_epoch,
        "val_history": val_metrics_history,
        "test": test_metrics,
    }, indent=2))
    # Writes to a Modal volume are only persisted after an explicit commit.
    results_vol.commit()
    print(f"\n[probe] results written to {out_dir}/metrics.json")
    print("\n=== FINAL TEST METRICS ===")
    for k, v in test_metrics.items():
        print(f" {k}: {v}")
    return {"run_id": run_id, "best_epoch": best_epoch, "test_metrics": test_metrics}
# ---------------------------------------------------------------------------
# Helper: upload manifest to its own volume (one-time per manifest version).
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=1,
    volumes={"/manifest": manifest_vol},
    timeout=300,
)
def upload_manifest_to_volume(manifest_text: str, name: str = "amr_binary_v1.json"):
    """Write the manifest text to ``/manifest/{name}`` and commit the volume.

    Returns a dict with the written ``path`` and its ``size_kb``.
    """
    target = Path("/manifest").joinpath(name)
    target.write_text(manifest_text)
    # Persist the write so the training function sees the file.
    manifest_vol.commit()
    size_bytes = target.stat().st_size
    return {"path": str(target), "size_kb": size_bytes / 1024}
# ---------------------------------------------------------------------------
# Local entrypoint: upload manifest + kick off training.
# ---------------------------------------------------------------------------
@app.local_entrypoint()
def main(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0):
    """Push the local manifest to its Modal volume, then launch remote training.

    modal run probes/train_amr_binary_v1.py::main
    modal run probes/train_amr_binary_v1.py::main --epochs 10 --lr 5e-4
    """
    manifest_text = Path(MANIFEST_PATH_LOCAL).read_text()
    print("[local] uploading manifest to Modal volume...")
    upload_info = upload_manifest_to_volume.remote(manifest_text)
    print(upload_info)
    print("[local] starting training...")
    outcome = train_remote.remote(epochs=epochs, lr=lr, pos_weight=pos_weight)
    print("\n=== RESULT ===")
    print(json.dumps(outcome, indent=2))
# ---------------------------------------------------------------------------
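# Synthetic-data dry run on local CPU (implemented by `dryrun_local`, defined
# at the end of this file). Verifies the per-token labelling and the probe's
# forward/backward pass without reading any data volume. ~10 seconds.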
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Sanity-check plot: per-token probe logits for N example regions.
# Loads the latest probe checkpoint, picks 5 AMR positives + their 5 matched
# negatives from val/test (so we plot truly held-out examples), runs probe
# forward, plots per-token logit vs position with the CDS interval shaded.
# Returns PNG bytes.
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=900,
)
def plot_probe_logits(run_id: str = "", n_examples: int = 5) -> bytes:
    """Plot per-token probe logits for held-out example regions; return PNG bytes.

    Picks up to ``n_examples`` AMR-positive regions from the val/test splits
    (seeded shuffle) together with their ``pair_partner`` negatives, runs the
    saved probe over each, and draws logit vs. token position with the CDS
    interval shaded. Left column: positives; right column: matched negatives.

    Args:
        run_id: Run directory under ``/results/amr_binary_v1/linear``; when
            empty, the lexicographically last one is used.
        n_examples: Number of positive/negative pairs to plot.

    Raises:
        RuntimeError: if no run exists, or no held-out positive is available.
    """
    import io
    import os
    import json
    import numpy as np
    import torch
    import matplotlib
    matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
    import matplotlib.pyplot as plt

    # Resolve checkpoint: latest run if none specified
    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    ckpt_path = f"{base}/{run_id}/checkpoint.pt"
    print(f"[plot] loading checkpoint: {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()

    # read_text() closes the file; a bare open(...).read() leaves that to the GC.
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())
    rng = np.random.default_rng(SEED)

    # AMR positives from val/test, with their paired negatives
    held_out = ("val", "test")
    pos_ids = [
        rid for rid, lbl in manifest["labels_per_region"].items()
        if int(lbl) == 1 and manifest["region_split"][rid] in held_out
    ]
    rng.shuffle(pos_ids)
    selected_pos = pos_ids[:n_examples]
    if not selected_pos:
        raise RuntimeError("no AMR-positive val/test regions available to plot")
    selected_neg = [manifest["pair_partner"][p] for p in selected_pos]
    # Size the grid by what was actually selected, so a short split does not
    # leave blank rows or drop the x-axis labels.
    n_rows = len(selected_pos)

    # Compute logits for each region, keyed by id for direct lookup when plotting
    examples = {}
    for rid in selected_pos + selected_neg:
        gc = manifest["gene_coords"][rid]
        rl = int(manifest["labels_per_region"][rid])
        acts, labels = load_region("/data", rid, rl, gc)
        with torch.no_grad():
            logits = probe(torch.from_numpy(acts)).squeeze(-1).numpy()
        examples[rid] = {"rid": rid, "rl": rl, "gc": gc, "logits": logits, "labels": labels}

    # plot — 2 cols (positive | negative), one row per pair. squeeze=False
    # keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(14, 2.4 * n_rows), sharex=False, squeeze=False)
    for col, half in enumerate([selected_pos, selected_neg]):
        for row, rid in enumerate(half):
            ex = examples[rid]
            ax = axes[row, col]
            logits = ex["logits"]
            positions = np.arange(len(logits))
            ax.plot(positions, logits, lw=0.6, color="black")
            ax.axhline(0, color="grey", lw=0.6, ls="--")
            # Same genomic -> token window as the labelling code, clamped to the region.
            gs, ge, es, _ee, _strand = ex["gc"]
            cds0 = max(0, gs - es)
            cds1 = min(len(logits), ge - es + 1)
            shade_color = "tab:green" if ex["rl"] == 1 else "tab:red"
            ax.axvspan(cds0, cds1, alpha=0.22, color=shade_color)
            ax.tick_params(labelsize=12)
            # Only show axis labels on the outer cells to reduce clutter
            if row == n_rows - 1:
                ax.set_xlabel("token position", fontsize=14)
            if col == 0:
                ax.set_ylabel("logit", fontsize=14)

    # Column headers (top of each column)
    axes[0, 0].set_title("AMR positive (CDS shaded green)", fontsize=14, fontweight="bold")
    axes[0, 1].set_title("matched negative (CDS shaded red)", fontsize=14, fontweight="bold")
    fig.suptitle("Linear probe on AMR — per-token logits", fontsize=18, fontweight="bold")
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    return buf.getvalue()
# ---------------------------------------------------------------------------
# Hyperplane-projection 1D plots: mean-pool and max-pool per-region logits
# for the v1 AMR linear probe on the test split. Shows pos/neg distributions
# on the same axis the classifier uses, with decision boundaries marked.
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={"/data": l26_vol, "/results": results_vol, "/manifest": manifest_vol},
    timeout=900,
)
def plot_score_distributions(run_id: str = "", split: str = "test") -> dict:
    """Plot 1D distributions of per-region max-pool and mean-pool probe logits.

    Args:
        run_id: Run directory under ``/results/amr_binary_v1/linear``; when
            empty, the lexicographically last one is used.
        split: Split whose regions are scored.

    Returns:
        ``{"png": bytes, "scores": [...], "summary": {...}}``. ``scores``
        holds one record per region (``region_id``, ``label``, ``max_logit``,
        ``mean_logit``) so the plot can be restyled without a re-run;
        ``summary`` holds counts plus AUC / best-F1 numbers per pooling.

    Raises:
        RuntimeError: if no run exists under the results directory.
        ValueError: if no region of ``split`` could be loaded.
    """
    import io
    import os
    import json
    import numpy as np
    import torch
    import matplotlib
    matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_auc_score, precision_recall_curve

    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            # Same explicit failure as plot_probe_logits, instead of an
            # IndexError from indexing an empty list.
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    state_dict = torch.load(f"{base}/{run_id}/checkpoint.pt", map_location="cpu", weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()
    print(f"[plot] loaded probe from run {run_id}")

    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())
    region_ids: list[str] = []
    region_max = []
    region_mean = []
    region_labels = []
    with torch.no_grad():
        for rid, acts, _labels in iter_split(manifest, split, "/data", shuffle=False):
            x = torch.from_numpy(acts)
            logits = probe(x).squeeze(-1).numpy()
            region_ids.append(rid)
            region_max.append(float(np.max(logits)))
            region_mean.append(float(np.mean(logits)))
            region_labels.append(int(manifest["labels_per_region"][rid]))
    if not region_ids:
        raise ValueError(f"no regions could be loaded for split {split!r}; nothing to plot")
    region_max = np.array(region_max)
    region_mean = np.array(region_mean)
    region_labels = np.array(region_labels)

    # AUCs + best-F1 thresholds for each aggregation
    def stats(scores):
        auc = roc_auc_score(region_labels, scores)
        prec, rec, thr = precision_recall_curve(region_labels, scores)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9)
        idx = int(np.argmax(f1))
        # prec/rec have one more entry than thr, hence the clamp.
        best_thr = float(thr[min(idx, len(thr) - 1)])
        return auc, best_thr, float(f1[idx])

    auc_max, t_max, f1_max = stats(region_max)
    auc_mean, t_mean, f1_mean = stats(region_mean)
    print(f"[plot] max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}")
    print(f"[plot] mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}")

    # Render — 2 panels side-by-side
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, scores, name, auc, best_thr, f1 in [
        (axes[0], region_max, "max-pool", auc_max, t_max, f1_max),
        (axes[1], region_mean, "mean-pool", auc_mean, t_mean, f1_mean),
    ]:
        pos = scores[region_labels == 1]
        neg = scores[region_labels == 0]
        # Shared bins so the two class histograms are directly comparable
        bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60)
        ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True)
        ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True)
        # Rug plots just below the histogram baseline (positives a little lower)
        y_lo, y_hi = ax.get_ylim()
        rug_y = y_lo - 0.02 * (y_hi - y_lo)
        ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=80, color="tab:red", alpha=0.6)
        ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=80, color="tab:green", alpha=0.6)
        # Decision boundaries
        ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
        ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
        ax.set_xlabel("per-region probe logit (= w · h_pooled + b)", fontsize=10)
        ax.set_ylabel("density", fontsize=10)
        ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11)
        ax.legend(fontsize=8, loc="upper left")
        ax.grid(True, alpha=0.2)
    fig.suptitle(
        f"v1 AMR linear probe — per-region score distributions ({split} split, "
        f"{int((region_labels==1).sum())} pos + {int((region_labels==0).sum())} neg, "
        f"MAG-level held-out)\n"
        "1D projection onto the probe's learned direction. Each region is one number "
        "= the probe's logit summarised by max- or mean-pool over its tokens.",
        fontsize=11,
    )
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)

    # Raw per-region scores so the plot can be reformatted without re-running.
    raw = [
        {"region_id": rid, "label": int(lab),
         "max_logit": float(mx), "mean_logit": float(mn)}
        for rid, lab, mx, mn in zip(region_ids, region_labels, region_max, region_mean)
    ]
    summary = {
        "split": split,
        "run_id": run_id,
        "n_pos": int((region_labels == 1).sum()),
        "n_neg": int((region_labels == 0).sum()),
        "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    }
    return {"png": buf.getvalue(), "scores": raw, "summary": summary}
@app.local_entrypoint()
def make_score_dist_plot(
    run_id: str = "",
    split: str = "test",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_score_distributions.png",
):
    """Fetch the per-region score-distribution plot and save it locally.

    Besides the PNG at ``out_path``, two sibling files are written: the raw
    per-region scores (``.scores.jsonl``, one JSON object per line) and the
    aggregate summary (``.summary.json``).
    """
    payload = plot_score_distributions.remote(run_id=run_id, split=split)

    png_path = Path(out_path)
    png_path.parent.mkdir(parents=True, exist_ok=True)
    png_bytes = payload["png"]
    png_path.write_bytes(png_bytes)
    print(f"saved {len(png_bytes)/1024:.1f} KB to {png_path}")

    # Sibling outputs share the PNG's stem.
    jsonl_path = png_path.with_suffix(".scores.jsonl")
    summary_path = png_path.with_suffix(".summary.json")
    score_lines = [json.dumps(row) for row in payload["scores"]]
    jsonl_path.write_text("\n".join(score_lines) + "\n")
    summary_path.write_text(json.dumps(payload["summary"], indent=2))
    print(f"saved {len(payload['scores'])} per-region scores to {jsonl_path}")
    print(f"saved aggregate summary to {summary_path}")
@app.local_entrypoint()
def make_sanity_plot(
    run_id: str = "",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_sanity_plot.png",
    n_examples: int = 5,
):
    """Render the per-token logit sanity plot remotely and write the PNG to ``out_path``.

    modal run probes/train_amr_binary_v1.py::make_sanity_plot
    """
    shown_run = run_id or 'LATEST'
    print(f"[local] generating sanity plot for run_id={shown_run}")
    img_bytes = plot_probe_logits.remote(run_id=run_id, n_examples=n_examples)
    destination = Path(out_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(img_bytes)
    size_kb = len(img_bytes) / 1024
    print(f"[local] saved {size_kb:.1f} KB to {out_path}")
@app.local_entrypoint()
def dryrun_local():
    """Sanity-check the per-token labelling and the probe's forward/backward pass.

    Uses synthetic data only; no activation file is read. ``evaluate()`` is
    NOT exercised here because it needs real ``.npz`` files — run it on Modal.
    Failures surface as ``AssertionError``.
    """
    import numpy as np
    import torch
    import torch.nn as nn

    print("=== DRY RUN (synthetic data) ===")
    # Synthetic per-token labels from a fake AMR record
    fake_coords = [2000, 2624, 0, 4624, "+"]  # gene_start, gene_end, ext_start, ext_end, strand
    labels = per_token_labels_from_record("FAKE_AMR", region_label=1, gene_coords=fake_coords, seq_len=4624)
    n_pos = int(labels.sum())
    print(f"per-token labelling: {n_pos} positive tokens out of 4624 "
          f"(expect 625 for gene 2000..2624 inclusive)")
    assert n_pos == 625, f"expected 625, got {n_pos}"

    # Edge case: clamping. Gene -100..200 with ext_start=0 covers tokens
    # 0..200, i.e. exactly 201 positives.
    edge_labels = per_token_labels_from_record("FAKE", 1, [-100, 200, 0, 1000, "+"], seq_len=1000)
    n_edge = int(edge_labels.sum())
    print(f"edge: gene_start before ext_start, got {n_edge} positives (expect ~201)")
    assert edge_labels[0] == 1
    assert n_edge == 201, f"expected 201, got {n_edge}"

    # MISC always all-zero
    misc_labels = per_token_labels_from_record("FAKE_MISC", 0, [0, 100, 0, 1000, "+"], seq_len=1000)
    assert misc_labels.sum() == 0

    # Linear probe forward/backward
    probe = make_probe().to("cpu")
    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))
    x = torch.randn(4624, 4096)
    y = torch.from_numpy(labels).float()
    logits = probe(x).squeeze(-1)
    loss = bce(logits, y)
    print(f"forward pass loss = {loss.item():.4f}")
    optimizer.zero_grad()  # same step sequence as train_one_epoch
    loss.backward()
    optimizer.step()
    print("backward + optimizer step OK")

    # Manifest-driven labelling: a two-region "split" in the manifest schema.
    fake_manifest = {
        "region_split": {"R1": "val", "R2": "val"},
        "labels_per_region": {"R1": 1, "R2": 0},
        "gene_coords": {
            "R1": [2000, 2624, 0, 4624, "+"],
            "R2": [2000, 2624, 0, 4624, "+"],  # not used since R2 is MISC
        },
    }
    for rid in fake_manifest["region_split"]:
        region_label = int(fake_manifest["labels_per_region"][rid])
        region_token_labels = per_token_labels_from_record(
            rid, region_label, fake_manifest["gene_coords"][rid], seq_len=4624,
        )
        expected_pos = 625 if region_label == 1 else 0
        assert int(region_token_labels.sum()) == expected_pos, (
            f"{rid}: expected {expected_pos} positive tokens"
        )
    # evaluate() itself cannot run here: it loads .npz files through
    # load_region, which only exist on the Modal data volume.
    print("\nDRY RUN OK — logic paths working. Run `modal run main` for the real thing.")