| """ |
| Linear per-token probe for AMR-vs-matched-CDS binary classification. |
| Runs on Modal CPU. Reads activations from `mgnify-embeddings-l26-lean`. |
| |
| Per-token labels: |
| - For AMR-positive records: token at position p is labelled 1 iff |
| ext_start + p ∈ [gene_start, gene_end]. |
| - For matched-MISC records: all tokens labelled 0. |
| (See HACKATHON_STATUS.md and probes/splits/build_amr_binary_v1.py for split design.) |
| |
| Usage: |
| modal run probes/train_amr_binary_v1.py::main --epochs 5 |
| modal run probes/train_amr_binary_v1.py::dryrun_local # synthetic-data sanity check |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| import time |
| from pathlib import Path |
| from typing import Iterator |
|
|
| import modal |
|
|
|
|
| |
| |
| |
# --- Paths -------------------------------------------------------------------
# The local manifest path is read by the `main` local entrypoint; Modal
# functions read the uploaded copy from the manifest volume mounted at /manifest.
MANIFEST_PATH_LOCAL = "/home/ror25cal/MGnify/probes/splits/amr_binary_v1.json"
MANIFEST_PATH_REMOTE = "/manifest/amr_binary_v1.json"
RESULTS_DIR_LOCAL = "/home/ror25cal/MGnify/probes/results/amr_binary_v1/linear"

# --- Model / reproducibility -------------------------------------------------
HIDDEN = 4096  # input width of the linear probe built by make_probe()
SEED = 42      # base seed for torch / numpy / random and split shuffling

# --- Modal app, container image and volumes ----------------------------------
app = modal.App("mgnify-amr-probe-v1")

image = modal.Image.debian_slim().pip_install(
    "numpy",
    "torch>=2.0",
    "scikit-learn>=1.3",
    "matplotlib",
)

# Pre-computed activations; this volume must already exist.
l26_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=False)
# Probe outputs and uploaded manifests; created on first use if absent.
results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=True)
manifest_vol = modal.Volume.from_name("mgnify-probe-manifests", create_if_missing=True)
|
|
|
|
| |
| |
| |
def per_token_labels_from_record(
    region_id: str,
    region_label: int,
    gene_coords: list,
    seq_len: int,
) -> "np.ndarray":
    """Build the per-token target vector for a single region.

    Args:
        region_id: Region identifier (accepted for interface symmetry; unused).
        region_label: 0 for a matched-MISC record, anything else is treated
            as AMR-positive.
        gene_coords: ``[gene_start, gene_end, ext_start, ext_end, strand]``.
            ``gene_end`` is inclusive. Only unpacked for positive records.
        seq_len: Number of tokens in the region.

    Returns:
        ``np.int8`` array of length ``seq_len``. For positive records, token
        ``p`` is 1 iff ``ext_start + p`` lies in ``[gene_start, gene_end]``,
        clipped to the window. For negative records every token is 0.
    """
    import numpy as np

    out = np.zeros(seq_len, dtype=np.int8)
    if region_label == 0:
        return out

    gene_start, gene_end, ext_start, _ext_end, _strand = gene_coords
    # Shift genomic coordinates into window-relative token offsets and clip to
    # [0, seq_len). The +1 turns the inclusive gene_end into a slice stop.
    lo = max(0, gene_start - ext_start)
    hi = min(seq_len, gene_end - ext_start + 1)
    # Guard matters: a negative `hi` would otherwise slice from the end.
    if lo < hi:
        out[lo:hi] = 1
    return out
|
|
|
|
| |
| |
| |
def load_region(npz_root: str, region_id: str, region_label: int, gene_coords: list):
    """Load one region's layer-26 activations and its per-token labels.

    The archive is read from ``{npz_root}/{AMR|MISC}/{mag_id}/{region_id}.npz``,
    where ``mag_id`` is ``region_id`` with its last two ``_``-separated fields
    stripped, and the folder is ``AMR`` iff ``region_label == 1``.

    Args:
        npz_root: Root directory of the activation archives.
        region_id: Region identifier; also the ``.npz`` file stem.
        region_label: Region-level label (1 = AMR positive).
        gene_coords: Coordinates forwarded to ``per_token_labels_from_record``.

    Returns:
        ``(acts, labels)``. ``acts`` is a float32 array converted from the
        stored ``layer26_activations_bf16`` entry; its first dimension is used
        as the token count. ``labels`` is the per-token int8 label vector.

    Raises:
        FileNotFoundError: if the archive does not exist (``iter_split``
            catches this and skips the region).
    """
    import numpy as np
    import torch

    label_folder = "AMR" if region_label == 1 else "MISC"
    mag_id = region_id.rsplit("_", 2)[0]
    npz_path = f"{npz_root}/{label_folder}/{mag_id}/{region_id}.npz"
    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # context manager closes it as soon as the array is read instead of
    # leaving it to garbage collection across thousands of regions.
    with np.load(npz_path, allow_pickle=False) as d:
        raw = d["layer26_activations_bf16"]
    # NOTE(review): assumes the stored array holds raw bfloat16 bit patterns in
    # a 16-bit dtype (numpy has no bfloat16); torch reinterprets and upcasts.
    acts = torch.from_numpy(raw).view(torch.bfloat16).float().numpy()
    seq_len = acts.shape[0]
    labels = per_token_labels_from_record(region_id, region_label, gene_coords, seq_len)
    return acts, labels
|
|
|
|
| |
| |
| |
| |
def iter_split(
    manifest: dict,
    split: str,
    npz_root: str,
    shuffle: bool = True,
    seed: int = SEED,
) -> Iterator[tuple]:
    """Lazily yield ``(region_id, acts, labels)`` for each region in ``split``.

    Regions are taken in ``manifest["region_split"]`` order, optionally
    shuffled with a ``random.Random(seed)`` instance. Regions whose archive is
    missing are reported and skipped. Any other error propagates.

    Args:
        manifest: Dict with ``region_split``, ``labels_per_region`` and
            ``gene_coords`` mappings keyed by region id.
        split: Split name to select (e.g. ``"train"``, ``"val"``, ``"test"``).
        npz_root: Root directory of the activation archives.
        shuffle: Whether to shuffle region order.
        seed: Seed for the shuffle.
    """
    split_ids = [rid for rid, which in manifest["region_split"].items() if which == split]
    if shuffle:
        random.Random(seed).shuffle(split_ids)

    region_labels = manifest["labels_per_region"]
    coords_by_region = manifest["gene_coords"]
    for rid in split_ids:
        try:
            acts, labels = load_region(
                npz_root,
                rid,
                region_label=int(region_labels[rid]),
                gene_coords=coords_by_region[rid],
            )
        except FileNotFoundError:
            print(f" WARN: missing {rid}; skipping")
            continue
        yield rid, acts, labels
|
|
|
|
| |
| |
| |
def make_probe(hidden: int = HIDDEN):
    """Return a freshly initialised single-logit linear probe.

    Args:
        hidden: Input feature width (defaults to ``HIDDEN``).

    Returns:
        ``nn.Linear(hidden, 1)``.
    """
    from torch import nn

    return nn.Linear(in_features=hidden, out_features=1)
|
|
|
|
| |
| |
| |
def train_one_epoch(probe, optimizer, manifest, npz_root, pos_weight, device, max_regions=None):
    """Run one pass over the ``train`` split, one optimiser step per region.

    All tokens of a region form a single batch scored with
    ``BCEWithLogitsLoss(pos_weight=pos_weight)``.

    Args:
        probe: Module mapping a region's activations to per-token logits.
        optimizer: Optimiser over ``probe``'s parameters.
        manifest: Split manifest forwarded to ``iter_split``.
        npz_root: Root directory of the activation archives.
        pos_weight: Positive-class weight for the loss.
        device: Torch device for the tensors.
        max_regions: If not None, train on at most this many regions.

    Returns:
        Token-weighted mean training loss (0.0 if no tokens were seen).
    """
    import itertools

    import torch
    import torch.nn as nn

    probe.train()
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weight), device=device))
    total_loss = 0.0
    total_tokens = 0
    n_pos_seen = 0
    n_regions = 0
    t0 = time.time()

    # NOTE: the shuffle seed is fixed, so every epoch visits regions in the
    # same order.
    regions = iter_split(manifest, "train", npz_root, shuffle=True, seed=SEED + 7)
    if max_regions is not None:
        # islice stops before pulling one more region, so no extra archive is
        # loaded and n_regions counts only regions actually trained on.
        regions = itertools.islice(regions, max(int(max_regions), 0))

    for _rid, acts, labels in regions:
        n_regions += 1
        x = torch.from_numpy(acts).to(device)
        y = torch.from_numpy(labels).float().to(device)
        logits = probe(x).squeeze(-1)
        loss = bce(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the per-region mean loss by token count for a token-level average.
        total_loss += loss.item() * x.shape[0]
        total_tokens += x.shape[0]
        n_pos_seen += int(y.sum().item())
        if n_regions % 200 == 0:
            denom = max(total_tokens, 1)
            print(f" epoch progress: {n_regions} regions, "
                  f"avg loss so far {total_loss/denom:.4f}, "
                  f"pos rate {n_pos_seen/denom:.4f}")

    elapsed = time.time() - t0
    # Guarded: zero tokens (empty split / max_regions=0) must not crash here.
    avg_loss = total_loss / max(total_tokens, 1)
    print(f" epoch done: {n_regions} regions, {total_tokens} tokens, "
          f"avg loss {avg_loss:.4f}, {elapsed:.0f}s")
    return avg_loss
|
|
|
|
| |
| |
| |
def evaluate(probe, manifest, npz_root, split, device):
    """Score every region of ``split`` and compute token- and region-level metrics.

    Token level: ROC-AUC, average precision, and the best F1 along the
    precision-recall curve together with its logit threshold.
    Region level: ROC-AUC of each region's max and mean per-token logit
    against ``manifest["labels_per_region"]``.

    Args:
        probe: Module mapping a region's activations to per-token logits.
        manifest: Split manifest, read through ``iter_split``.
        npz_root: Root directory of the activation archives.
        split: Split name to evaluate (e.g. ``"val"``, ``"test"``).
        device: Torch device to run the probe on.

    Returns:
        Dict of metrics with the same keys as before. A metric sklearn cannot
        compute, because the split is empty or has only one class at that
        level, is reported as NaN instead of raising. A raise here would lose
        an entire training run before its checkpoint is written.
    """
    import numpy as np
    import torch
    from sklearn.metrics import (
        roc_auc_score, average_precision_score, precision_recall_curve,
    )

    nan = float("nan")

    def _two_classes(y) -> bool:
        # ROC-AUC and PR curves are undefined unless both classes occur.
        return np.unique(np.asarray(y)).size == 2

    probe.eval()
    region_logits = []        # one 1-D array of per-token logits per region
    region_token_labels = []  # matching per-token label arrays
    region_ids = []
    region_labels = []

    with torch.no_grad():
        for rid, acts, labels in iter_split(manifest, split, npz_root, shuffle=False):
            x = torch.from_numpy(acts).to(device)
            region_logits.append(probe(x).squeeze(-1).cpu().numpy())
            region_token_labels.append(labels)
            region_ids.append(rid)
            region_labels.append(int(manifest["labels_per_region"][rid]))

    if region_ids:
        logits = np.concatenate(region_logits)
        labels = np.concatenate(region_token_labels)
    else:
        print(f" WARN: no regions found for split {split!r}; metrics are NaN")
        logits = np.zeros(0, dtype=np.float32)
        labels = np.zeros(0, dtype=np.int8)

    # ---- token-level metrics ----
    token_auc = token_pr_auc = token_f1_best = best_thresh = nan
    if _two_classes(labels):
        token_auc = float(roc_auc_score(labels, logits))
        token_pr_auc = float(average_precision_score(labels, logits))
        precision, recall, thresholds = precision_recall_curve(labels, logits)
        f1s = 2 * precision * recall / np.maximum(precision + recall, 1e-9)
        best_idx = int(np.argmax(f1s))
        # precision/recall have one trailing point (recall = 0) without a
        # threshold, so clamp the index into `thresholds`.
        best_thresh = float(thresholds[min(best_idx, len(thresholds) - 1)])
        token_f1_best = float(f1s[best_idx])
    elif region_ids:
        print(f" WARN: {split} tokens contain a single class; token metrics are NaN")

    # ---- region-level metrics: pool per-token logits within each region ----
    region_max_auc = region_mean_auc = nan
    if _two_classes(region_labels):
        region_max_logit = [float(np.max(arr)) for arr in region_logits]
        region_mean_logit = [float(np.mean(arr)) for arr in region_logits]
        region_labels_np = np.array(region_labels)
        region_max_auc = float(roc_auc_score(region_labels_np, region_max_logit))
        region_mean_auc = float(roc_auc_score(region_labels_np, region_mean_logit))
    elif region_ids:
        print(f" WARN: {split} regions contain a single class; region AUCs are NaN")

    return {
        "split": split,
        "n_regions": len(region_ids),
        "n_tokens": int(labels.shape[0]),
        "n_positive_tokens": int(labels.sum()),
        "token_roc_auc": token_auc,
        "token_pr_auc": token_pr_auc,
        "token_best_f1": token_f1_best,
        "token_best_threshold": best_thresh,
        "region_max_pool_auc": region_max_auc,
        "region_mean_pool_auc": region_mean_auc,
    }
|
|
|
|
| |
| |
| |
@app.function(
    image=image,
    cpu=4,
    memory=32 * 1024,
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=7200,
)
def train_remote(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0, run_id: str = ""):
    """Train the linear probe, select the best epoch on val, then score test (on Modal).

    After each epoch the probe is evaluated on ``val``. The weights with the
    highest val token ROC-AUC are kept and used for the final ``test``
    evaluation. ``checkpoint.pt`` and ``metrics.json`` are written to
    ``/results/amr_binary_v1/linear/<run_id>/`` and the results volume is
    committed.

    Args:
        epochs: Number of passes over the train split.
        lr: Adam learning rate.
        pos_weight: Positive-class weight for ``BCEWithLogitsLoss``.
        run_id: Output directory name; defaults to a ``%Y%m%d_%H%M%S`` timestamp.

    Returns:
        ``{"run_id", "best_epoch", "test_metrics"}``. ``best_epoch`` is -1 if no
        epoch was selected (no epochs run, or val AUC was never > -1, e.g. NaN).
    """
    import torch
    import numpy as np

    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    print(f"[probe] reading manifest from {MANIFEST_PATH_REMOTE}")
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())

    device = torch.device("cpu")
    probe = make_probe().to(device)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    n_params = sum(p.numel() for p in probe.parameters())
    print(f"[probe] linear probe: in={HIDDEN}, out=1, params={n_params}")

    def _snapshot():
        # Detached CPU copy so later optimiser steps don't mutate it.
        return {k: v.cpu().clone() for k, v in probe.state_dict().items()}

    val_metrics_history = []
    best_val_auc = -1.0
    best_epoch = -1
    best_state = None

    for epoch in range(epochs):
        print(f"\n=== EPOCH {epoch + 1} / {epochs} ===")
        train_loss = train_one_epoch(probe, optimizer, manifest, "/data", pos_weight, device)
        print(" evaluating on val split...")
        val_metrics = evaluate(probe, manifest, "/data", "val", device)
        val_metrics["epoch"] = epoch + 1
        val_metrics["train_loss"] = train_loss
        val_metrics_history.append(val_metrics)
        print(f" val token AUC: {val_metrics['token_roc_auc']:.4f}, "
              f"region max-pool AUC: {val_metrics['region_max_pool_auc']:.4f}")
        if val_metrics["token_roc_auc"] > best_val_auc:
            best_val_auc = val_metrics["token_roc_auc"]
            best_epoch = epoch + 1
            best_state = _snapshot()

    print(f"\n[probe] best val token AUC = {best_val_auc:.4f} at epoch {best_epoch}")
    print("[probe] running final test eval with best epoch's weights")
    if best_state is None:
        # No epoch was selected. Keep the current weights so checkpoint.pt is
        # always a loadable state_dict, not a pickled None that
        # load_state_dict() would reject.
        print("[probe] WARN: no best epoch recorded; using current weights")
        best_state = _snapshot()
    probe.load_state_dict(best_state)
    test_metrics = evaluate(probe, manifest, "/data", "test", device)

    run_id = run_id or time.strftime("%Y%m%d_%H%M%S")
    out_dir = f"/results/amr_binary_v1/linear/{run_id}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(best_state, f"{out_dir}/checkpoint.pt")
    Path(f"{out_dir}/metrics.json").write_text(json.dumps({
        "manifest": "amr_binary_v1",
        "model": "linear",
        "hyperparameters": {"epochs": epochs, "lr": lr, "pos_weight": pos_weight, "seed": SEED},
        "best_epoch": best_epoch,
        "val_history": val_metrics_history,
        "test": test_metrics,
    }, indent=2))
    results_vol.commit()

    print(f"\n[probe] results written to {out_dir}/metrics.json")
    print("\n=== FINAL TEST METRICS ===")
    for k, v in test_metrics.items():
        print(f" {k}: {v}")
    return {"run_id": run_id, "best_epoch": best_epoch, "test_metrics": test_metrics}
|
|
|
|
| |
| |
| |
@app.function(
    image=image,
    cpu=1,
    volumes={"/manifest": manifest_vol},
    timeout=300,
)
def upload_manifest_to_volume(manifest_text: str, name: str = "amr_binary_v1.json"):
    """Write ``manifest_text`` to ``/manifest/<name>`` and commit the manifest volume.

    Returns:
        ``{"path": <written path as str>, "size_kb": <file size / 1024>}``.
    """
    target = Path("/manifest").joinpath(name)
    target.write_text(manifest_text)
    manifest_vol.commit()
    size_kb = target.stat().st_size / 1024
    return {"path": str(target), "size_kb": size_kb}
|
|
|
|
| |
| |
| |
@app.local_entrypoint()
def main(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0):
    """Upload the local manifest to the Modal volume, then launch remote training.

    The manifest is uploaded on every invocation.

        modal run probes/train_amr_binary_v1.py::main
        modal run probes/train_amr_binary_v1.py::main --epochs 10 --lr 5e-4
    """
    manifest_text = Path(MANIFEST_PATH_LOCAL).read_text()

    print("[local] uploading manifest to Modal volume...")
    upload_info = upload_manifest_to_volume.remote(manifest_text)
    print(upload_info)

    print("[local] starting training...")
    result = train_remote.remote(epochs=epochs, lr=lr, pos_weight=pos_weight)

    print("\n=== RESULT ===")
    print(json.dumps(result, indent=2))
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=900,
)
def plot_probe_logits(run_id: str = "", n_examples: int = 5) -> bytes:
    """Plot per-token probe logits for held-out positives and their matched negatives.

    Up to ``n_examples`` AMR-positive regions from the val/test splits are
    drawn (shuffled with ``SEED``). Each is paired with
    ``manifest["pair_partner"]``, giving one row per pair: positive on the
    left, partner on the right, with the region's ``gene_coords`` CDS span
    shaded.

    Args:
        run_id: Run directory under ``/results/amr_binary_v1/linear``. Empty
            selects the lexicographically last one.
        n_examples: Maximum number of positive/negative pairs to plot.

    Returns:
        The figure as PNG bytes.

    Raises:
        RuntimeError: if no runs exist or no held-out positive region is found.
    """
    import io
    import os
    import json
    import numpy as np
    import torch
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # ---- load probe ----
    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    ckpt_path = f"{base}/{run_id}/checkpoint.pt"
    print(f"[plot] loading checkpoint: {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()

    # read_text closes the file, unlike a bare open().read().
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())
    rng = np.random.default_rng(SEED)

    # ---- choose held-out positives and their matched partners ----
    # Labels are cast with int() as everywhere else in this file, so labels
    # stored as e.g. "1" are not silently skipped.
    pos_ids = [
        rid for rid, lbl in manifest["labels_per_region"].items()
        if int(lbl) == 1 and manifest["region_split"][rid] in ("val", "test")
    ]
    rng.shuffle(pos_ids)
    selected_pos = pos_ids[:n_examples]
    if not selected_pos:
        raise RuntimeError("no AMR-positive regions found in the val/test splits")
    selected_neg = [manifest["pair_partner"][p] for p in selected_pos]
    # Fewer positives than requested: size the grid to what we actually have.
    n_rows = len(selected_pos)

    # ---- score each selected region ----
    examples = {}
    for rid in selected_pos + selected_neg:
        if rid in examples:
            continue
        gc = manifest["gene_coords"][rid]
        rl = int(manifest["labels_per_region"][rid])
        acts, labels = load_region("/data", rid, rl, gc)
        with torch.no_grad():
            logits = probe(torch.from_numpy(acts)).squeeze(-1).numpy()
        examples[rid] = {"rid": rid, "rl": rl, "gc": gc, "logits": logits, "labels": labels}

    # ---- draw ----
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(14, 2.4 * n_rows), sharex=False, squeeze=False)
    for col, half in enumerate([selected_pos, selected_neg]):
        for row, rid in enumerate(half):
            ex = examples[rid]
            ax = axes[row, col]
            logits = ex["logits"]
            positions = np.arange(len(logits))
            ax.plot(positions, logits, lw=0.6, color="black")
            ax.axhline(0, color="grey", lw=0.6, ls="--")
            gs, ge, es, _ee, _strand = ex["gc"]
            # Same window-relative, inclusive-end mapping as the training labels.
            cds0 = max(0, gs - es)
            cds1 = min(len(logits), ge - es + 1)
            shade_color = "tab:green" if ex["rl"] == 1 else "tab:red"
            ax.axvspan(cds0, cds1, alpha=0.22, color=shade_color)
            ax.tick_params(labelsize=12)
            if row == n_rows - 1:
                ax.set_xlabel("token position", fontsize=14)
            if col == 0:
                ax.set_ylabel("logit", fontsize=14)

    axes[0, 0].set_title("AMR positive (CDS shaded green)", fontsize=14, fontweight="bold")
    axes[0, 1].set_title("matched negative (CDS shaded red)", fontsize=14, fontweight="bold")
    fig.suptitle("Linear probe on AMR — per-token logits", fontsize=18, fontweight="bold")
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    return buf.getvalue()
|
|
|
|
| |
| |
| |
| |
| |
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={"/data": l26_vol, "/results": results_vol, "/manifest": manifest_vol},
    timeout=900,
)
def plot_score_distributions(run_id: str = "", split: str = "test") -> dict:
    """Plot 1-D distributions of per-region max-pool and mean-pool probe logits.

    Args:
        run_id: Run directory under ``/results/amr_binary_v1/linear``. Empty
            selects the lexicographically last one.
        split: Split whose regions are scored.

    Returns:
        ``{"png": bytes, "scores": [per-region dicts], "summary": dict}``. The
        raw scores are returned so the figure can be re-styled without a re-run.

    Raises:
        RuntimeError: if no runs exist, or the split lacks either class (AUC,
            PR curves and the two-class histograms all need both).
    """
    import io
    import os
    import json
    import numpy as np
    import torch
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_auc_score, precision_recall_curve

    # ---- load probe (same run resolution as plot_probe_logits) ----
    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    state_dict = torch.load(f"{base}/{run_id}/checkpoint.pt", map_location="cpu", weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()
    print(f"[plot] loaded probe from run {run_id}")

    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())

    # ---- score every region: one max- and one mean-pooled logit each ----
    region_ids: list[str] = []
    region_max = []
    region_mean = []
    region_labels = []
    with torch.no_grad():
        for rid, acts, _labels in iter_split(manifest, split, "/data", shuffle=False):
            logits = probe(torch.from_numpy(acts)).squeeze(-1).numpy()
            region_ids.append(rid)
            region_max.append(float(np.max(logits)))
            region_mean.append(float(np.mean(logits)))
            region_labels.append(int(manifest["labels_per_region"][rid]))
    region_max = np.array(region_max)
    region_mean = np.array(region_mean)
    region_labels = np.array(region_labels)

    n_pos = int((region_labels == 1).sum())
    n_neg = int((region_labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise RuntimeError(
            f"split {split!r} has {n_pos} positive and {n_neg} negative regions; "
            "need at least one of each to plot score distributions"
        )

    def stats(scores):
        """Return (ROC-AUC, best-F1 threshold, best F1) for one pooled score."""
        auc = roc_auc_score(region_labels, scores)
        prec, rec, thr = precision_recall_curve(region_labels, scores)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9)
        idx = int(np.argmax(f1))
        # The trailing PR point has no threshold; clamp into `thr`.
        best_thr = float(thr[min(idx, len(thr) - 1)])
        return auc, best_thr, float(f1[idx])

    auc_max, t_max, f1_max = stats(region_max)
    auc_mean, t_mean, f1_mean = stats(region_mean)
    print(f"[plot] max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}")
    print(f"[plot] mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}")

    # ---- figure ----
    # Max of per-token logits is not w · (pooled h) + b; only the mean-pool
    # score factors through the pooled activation (the probe is linear).
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = [
        (axes[0], region_max, "max-pool", auc_max, t_max, f1_max,
         "per-region probe logit (= max over tokens of w · h_t + b)"),
        (axes[1], region_mean, "mean-pool", auc_mean, t_mean, f1_mean,
         "per-region probe logit (= w · h_pooled + b)"),
    ]
    for ax, scores, name, auc, best_thr, f1, xlabel in panels:
        pos = scores[region_labels == 1]
        neg = scores[region_labels == 0]

        bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60)
        ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True)
        ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True)

        # Rug marks just below the histogram, positives offset further down.
        rug_y = ax.get_ylim()[0] - 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0])
        ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=80, color="tab:red", alpha=0.6)
        ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=80, color="tab:green", alpha=0.6)

        ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
        ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
        ax.set_xlabel(xlabel, fontsize=10)
        ax.set_ylabel("density", fontsize=10)
        ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11)
        ax.legend(fontsize=8, loc="upper left")
        ax.grid(True, alpha=0.2)

    fig.suptitle(
        f"v1 AMR linear probe — per-region score distributions ({split} split, "
        f"{n_pos} pos + {n_neg} neg, "
        f"MAG-level held-out)\n"
        "1D projection onto the probe's learned direction. Each region is one number "
        "= the probe's logit summarised by max- or mean-pool over its tokens.",
        fontsize=11,
    )
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)

    # ---- raw scores + summary for offline re-plotting ----
    raw = [
        {"region_id": rid, "label": int(lab),
         "max_logit": float(mx), "mean_logit": float(mn)}
        for rid, lab, mx, mn in zip(region_ids, region_labels, region_max, region_mean)
    ]
    summary = {
        "split": split,
        "run_id": run_id,
        "n_pos": n_pos,
        "n_neg": n_neg,
        "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    }
    return {"png": buf.getvalue(), "scores": raw, "summary": summary}
|
|
|
|
@app.local_entrypoint()
def make_score_dist_plot(
    run_id: str = "",
    split: str = "test",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_score_distributions.png",
):
    """Fetch the per-region score-distribution figure and save it locally.

    Besides the PNG at ``out_path``, two sibling files are written:
    ``out_path`` with its suffix replaced by ``.scores.jsonl`` (one JSON object
    per region) and by ``.summary.json`` (aggregate metrics).
    """
    result = plot_score_distributions.remote(run_id=run_id, split=split)

    png_bytes = result["png"]
    png_path = Path(out_path)
    png_path.parent.mkdir(parents=True, exist_ok=True)
    png_path.write_bytes(png_bytes)
    print(f"saved {len(png_bytes)/1024:.1f} KB to {png_path}")

    scores = result["scores"]
    scores_path = png_path.with_suffix(".scores.jsonl")
    summary_path = png_path.with_suffix(".summary.json")
    jsonl_lines = [json.dumps(row) for row in scores]
    scores_path.write_text("\n".join(jsonl_lines) + "\n")
    summary_path.write_text(json.dumps(result["summary"], indent=2))
    print(f"saved {len(scores)} per-region scores to {scores_path}")
    print(f"saved aggregate summary to {summary_path}")
|
|
|
|
@app.local_entrypoint()
def make_sanity_plot(
    run_id: str = "",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_sanity_plot.png",
    n_examples: int = 5,
):
    """Render the per-token logit sanity plot remotely and write the PNG to ``out_path``.

        modal run probes/train_amr_binary_v1.py::make_sanity_plot
    """
    print(f"[local] generating sanity plot for run_id={run_id or 'LATEST'}")
    png = plot_probe_logits.remote(run_id=run_id, n_examples=n_examples)

    target = Path(out_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(png)
    print(f"[local] saved {len(png)/1024:.1f} KB to {out_path}")
|
|
|
|
@app.local_entrypoint()
def dryrun_local():
    """Sanity-check per-token labelling and one probe optimisation step.

    Uses synthetic inputs only: no ``.npz`` archive is read and no remote Modal
    function is called. Checks:

    * ``per_token_labels_from_record`` for a gene inside the window, a gene
      clipped at the window start, and a negative (MISC) record;
    * a forward pass, backward pass and Adam step of ``make_probe()`` on random
      activations, and that the step actually changes the weights.

    The metric code in ``evaluate`` is not exercised here because it reads
    regions through ``iter_split``.
    """
    import torch
    import torch.nn as nn

    print("=== DRY RUN (synthetic data) ===")

    # Gene fully inside the window: tokens 2000..2624 inclusive.
    fake_coords = [2000, 2624, 0, 4624, "+"]
    labels = per_token_labels_from_record("FAKE_AMR", region_label=1, gene_coords=fake_coords, seq_len=4624)
    n_pos = int(labels.sum())
    print(f"per-token labelling: {n_pos} positive tokens out of 4624 "
          f"(expect 625 for gene 2000..2624 inclusive)")
    assert n_pos == 625, f"expected 625, got {n_pos}"

    # Gene starting before the window: clipped to tokens 0..200 inclusive.
    edge_labels = per_token_labels_from_record("FAKE", 1, [-100, 200, 0, 1000, "+"], seq_len=1000)
    n_edge = int(edge_labels.sum())
    print(f"edge: gene_start before ext_start, got {n_edge} positives (expect 201)")
    assert edge_labels[0] == 1
    assert n_edge == 201, f"expected 201, got {n_edge}"

    # Negative record: all tokens 0 regardless of gene_coords.
    misc_labels = per_token_labels_from_record("FAKE_MISC", 0, [0, 100, 0, 1000, "+"], seq_len=1000)
    assert misc_labels.sum() == 0

    # One optimisation step on random activations of the probe's input width.
    torch.manual_seed(SEED)
    probe = make_probe().to("cpu")
    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))
    x = torch.randn(4624, HIDDEN)
    y = torch.from_numpy(labels).float()
    weight_before = probe.weight.detach().clone()
    logits = probe(x).squeeze(-1)
    loss = bce(logits, y)
    print(f"forward pass loss = {loss.item():.4f}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    assert not torch.equal(weight_before, probe.weight.detach()), "optimizer step did not update the probe"
    print("backward + optimizer step OK")

    print("\nDRY RUN OK — logic paths working. "
          "Run `modal run probes/train_amr_binary_v1.py::main` for the real thing.")
|
|