"""
Linear per-token probe for AMR-vs-matched-CDS binary classification.

Runs on Modal CPU. Reads layer-26 activations from the
`mgnify-embeddings-l26-lean` Modal volume.

Per-token labels:
  - For AMR-positive records: token at position p is labelled 1 iff
    ext_start + p ∈ [gene_start, gene_end] (both ends inclusive).
  - For matched-MISC records: all tokens labelled 0.

(See HACKATHON_STATUS.md and probes/splits/build_amr_binary_v1.py for split design.)

Usage:
    modal run probes/train_amr_binary_v1.py::main --epochs 5
    modal run probes/train_amr_binary_v1.py::dryrun_local          # synthetic-data sanity check
    modal run probes/train_amr_binary_v1.py::make_sanity_plot      # per-token logit plot
    modal run probes/train_amr_binary_v1.py::make_score_dist_plot  # per-region score distributions
"""
from __future__ import annotations

import json
import os
import random
import time
from pathlib import Path
from typing import Iterator

import modal

# ---------------------------------------------------------------------------
# Constants / config
# ---------------------------------------------------------------------------
# Split manifest on the local machine; `main` reads it and uploads it to the
# manifest volume.
MANIFEST_PATH_LOCAL = "/home/ror25cal/MGnify/probes/splits/amr_binary_v1.json"
# Same manifest as seen inside Modal containers (manifest volume mounted at /manifest).
MANIFEST_PATH_REMOTE = "/manifest/amr_binary_v1.json"
# Local results directory for this probe. NOTE(review): not referenced by the
# code visible in this file — confirm whether it is still needed.
RESULTS_DIR_LOCAL = "/home/ror25cal/MGnify/probes/results/amr_binary_v1/linear"
# Activation width; used as the input dimension of the linear probe (make_probe).
HIDDEN = 4096
# Base seed for torch / numpy / random and for deterministic split shuffling.
SEED = 42

# ---------------------------------------------------------------------------
# Modal app + image. Pure CPU — no GPU, no Evo 2 needed.
# ---------------------------------------------------------------------------
image = (
    modal.Image.debian_slim()
    .pip_install("numpy", "torch>=2.0", "scikit-learn>=1.3", "matplotlib")
)

# Activation store (read-only use), results store and manifest store.
l26_vol = modal.Volume.from_name("mgnify-embeddings-l26-lean", create_if_missing=False)
results_vol = modal.Volume.from_name("mgnify-probe-results", create_if_missing=True)
manifest_vol = modal.Volume.from_name("mgnify-probe-manifests", create_if_missing=True)

app = modal.App("mgnify-amr-probe-v1")


# ---------------------------------------------------------------------------
# Per-token label computation
# ---------------------------------------------------------------------------
def per_token_labels_from_record(
    region_id: str,
    region_label: int,  # 1 if AMR-positive, 0 if MISC
    gene_coords: list,  # [gene_start, gene_end, ext_start, ext_end, strand]
    seq_len: int,
) -> "np.ndarray":
    """Build the per-token target vector for one region.

    Token ``p`` of the stored activations is treated as genomic coordinate
    ``ext_start + p``. For a positive region (any ``region_label`` other than
    0), tokens whose coordinate falls in the inclusive interval
    ``[gene_start, gene_end]`` are set to 1, with the interval clipped to
    ``[0, seq_len)``. A region with ``region_label == 0`` gets all zeros, and
    ``gene_coords`` is not inspected at all.

    ``ext_end`` and ``strand`` are unpacked but unused, so no
    reverse-complement handling is done. ``region_id`` is not used by the
    computation.

    Returns:
        int8 NumPy array of length ``seq_len``.
    """
    import numpy as np

    # Negative regions: nothing to mark.
    if region_label == 0:
        return np.zeros(seq_len, dtype=np.int8)

    gene_start, gene_end, ext_start, _ext_end, _strand = gene_coords
    target = np.zeros(seq_len, dtype=np.int8)
    first = max(0, gene_start - ext_start)
    stop = min(seq_len, gene_end - ext_start + 1)  # +1 because gene_end is inclusive
    if stop > first:
        target[first:stop] = 1
    return target


# ---------------------------------------------------------------------------
# Load one .npz: returns (acts_fp32 [seq_len, 4096], labels [seq_len]).
# ---------------------------------------------------------------------------
def load_region(npz_root: str, region_id: str, region_label: int, gene_coords: list):
    """Load one region's activations and its per-token labels.

    The file path is ``{npz_root}/{AMR|MISC}/{mag_id}/{region_id}.npz``. The
    folder is ``AMR`` when ``region_label == 1`` and ``MISC`` otherwise.
    ``mag_id`` is ``region_id`` with its last two ``_``-separated fields
    removed.

    Returns:
        ``(acts, labels)``.

        - ``acts``: a float32 NumPy array. It is built from the stored
          ``layer26_activations_bf16`` array, reinterpreted bit-for-bit as
          bfloat16 and upcast (assumed shape ``[seq_len, 4096]``).
        - ``labels``: the int8 array from ``per_token_labels_from_record``,
          with one entry per row of ``acts``.

    Raises:
        FileNotFoundError: if the ``.npz`` file is missing. ``iter_split``
            catches this and skips the region.
    """
    import numpy as np
    import torch

    label_folder = "AMR" if region_label == 1 else "MISC"
    mag_id = region_id.rsplit("_", 2)[0]  # MGYG..._00123_AMR -> MGYG...
    npz_path = f"{npz_root}/{label_folder}/{mag_id}/{region_id}.npz"
    # An .npz load returns an NpzFile holding an open file handle. Using it as a
    # context manager closes that handle deterministically, instead of relying
    # on garbage collection once per region. Indexing reads the array fully into
    # memory, so it stays valid after the block exits.
    with np.load(npz_path, allow_pickle=False) as d:
        raw = d["layer26_activations_bf16"]
    # The stored array presumably holds raw bfloat16 bits in a 16-bit dtype;
    # reinterpret them as bf16, then upcast to float32 for the probe.
    acts = torch.from_numpy(raw).view(torch.bfloat16).float().numpy()
    seq_len = acts.shape[0]
    labels = per_token_labels_from_record(region_id, region_label, gene_coords, seq_len)
    return acts, labels


# ---------------------------------------------------------------------------
# Streaming iterator: yield (region_id, acts [n_tokens, 4096], labels [n_tokens])
# for the given split. Each item is one region's worth of tokens (variable length).
# ---------------------------------------------------------------------------
def iter_split(
    manifest: dict,
    split: str,
    npz_root: str,
    shuffle: bool = True,
    seed: int = SEED,
) -> Iterator[tuple]:
    """Yield ``(region_id, acts, labels)`` for every region assigned to ``split``.

    Regions come from ``manifest["region_split"]``. When ``shuffle`` is true
    they are visited in an order shuffled by ``random.Random(seed)``; otherwise
    they follow the manifest's order. A region whose ``.npz`` is missing is
    reported with a warning and skipped.
    """
    region_ids = [r for r, s in manifest["region_split"].items() if s == split]
    if shuffle:
        rng = random.Random(seed)
        rng.shuffle(region_ids)

    labels_per_region = manifest["labels_per_region"]
    gene_coords = manifest["gene_coords"]
    for rid in region_ids:
        try:
            acts, labels = load_region(
                npz_root,
                rid,
                region_label=int(labels_per_region[rid]),
                gene_coords=gene_coords[rid],
            )
        except FileNotFoundError:
            print(f" WARN: missing {rid}; skipping")
            continue
        yield rid, acts, labels


# ---------------------------------------------------------------------------
# Linear probe
# ---------------------------------------------------------------------------
def make_probe(hidden: int = HIDDEN):
    """Return a fresh ``nn.Linear(hidden, 1)``, which emits one logit per token vector."""
    import torch.nn as nn
    return nn.Linear(hidden, 1)


# ---------------------------------------------------------------------------
# Train one epoch over the train split.
# ---------------------------------------------------------------------------
def train_one_epoch(probe, optimizer, manifest, npz_root, pos_weight, device,
                    max_regions=None, *, seed: int | None = None):
    """Run one pass of per-region SGD steps over the train split.

    Each region is one optimisation step. The loss is ``BCEWithLogitsLoss``
    with ``pos_weight``, averaged over that region's tokens.

    Args:
        probe: module mapping ``[n_tokens, HIDDEN]`` activations to ``[n_tokens, 1]`` logits.
        optimizer: optimizer over ``probe``'s parameters.
        manifest: split manifest (see ``iter_split``).
        npz_root: root directory of the activation ``.npz`` files.
        pos_weight: weight applied to positive tokens in the BCE loss.
        device: torch device to run on.
        max_regions: if given, stop after this many regions (``<= 0`` trains on none).
        seed: shuffle seed for the region order. If ``None``, a fresh seed is
            drawn from the module-level ``random`` generator, so successive
            epochs see different orders. This stays reproducible when the
            caller seeds ``random``, as ``train_remote`` does.

    Returns:
        The token-weighted mean loss over the epoch, or 0.0 if no tokens were seen.
    """
    import itertools

    import torch
    import torch.nn as nn

    if seed is None:
        # Previously the seed was a constant (SEED + 7), which made every
        # epoch iterate in the identical order despite shuffle=True.
        seed = random.randrange(2**32)

    probe.train()
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weight), device=device))

    regions = iter_split(manifest, "train", npz_root, shuffle=True, seed=seed)
    if max_regions is not None:
        # islice stops *before* pulling region N+1, so no extra .npz is loaded
        # and n_regions below reports exactly the number trained on.
        regions = itertools.islice(regions, max(0, max_regions))

    total_loss = 0.0
    total_tokens = 0
    n_pos_seen = 0
    n_regions = 0
    t0 = time.time()
    for _rid, acts, labels in regions:
        n_regions += 1
        x = torch.from_numpy(acts).to(device)               # [seq_len, 4096]
        y = torch.from_numpy(labels).float().to(device)     # [seq_len]
        logits = probe(x).squeeze(-1)                       # [seq_len]
        loss = bce(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-weight by token count so the epoch average is per-token, not per-region.
        total_loss += loss.item() * x.shape[0]
        total_tokens += x.shape[0]
        n_pos_seen += int(y.sum().item())

        if n_regions % 200 == 0:
            denom = max(total_tokens, 1)
            print(f" epoch progress: {n_regions} regions, "
                  f"avg loss so far {total_loss/denom:.4f}, "
                  f"pos rate {n_pos_seen/denom:.4f}")

    # Guard the division: an empty split used to raise ZeroDivisionError here.
    avg_loss = total_loss / max(total_tokens, 1)
    elapsed = time.time() - t0
    print(f" epoch done: {n_regions} regions, {total_tokens} tokens, "
          f"avg loss {avg_loss:.4f}, {elapsed:.0f}s")
    return avg_loss


# ---------------------------------------------------------------------------
# Eval on val/test split. Returns metric dict.
# ---------------------------------------------------------------------------
def evaluate(probe, manifest, npz_root, split, device):
    """Score ``probe`` on one manifest split and return a metric dict.

    Token-level metrics pool every token of every region in the split:
      - ROC AUC and average precision (``token_pr_auc``).
      - The best F1 over the precision-recall curve, and the logit threshold
        that attains it.

    Region-level metrics reduce each region's token logits to one score (max
    or mean) and compute ROC AUC against the region labels.

    Raises:
        ValueError: if no region of ``split`` could be loaded. scikit-learn
            also raises ``ValueError`` when only one class is present.
    """
    import numpy as np
    import torch
    from sklearn.metrics import (
        average_precision_score,
        precision_recall_curve,
        roc_auc_score,
    )

    probe.eval()
    all_logits = []
    all_labels = []
    all_region_labels = []
    with torch.no_grad():
        for rid, acts, labels in iter_split(manifest, split, npz_root, shuffle=False):
            x = torch.from_numpy(acts).to(device)
            all_logits.append(probe(x).squeeze(-1).cpu().numpy())
            all_labels.append(labels)
            all_region_labels.append(int(manifest["labels_per_region"][rid]))

    if not all_logits:
        raise ValueError(f"no regions could be loaded for split {split!r}")

    logits = np.concatenate(all_logits)
    labels = np.concatenate(all_labels)

    # Per-token metrics
    token_auc = roc_auc_score(labels, logits)
    token_pr_auc = average_precision_score(labels, logits)

    # Best-F1 threshold. precision/recall have one more entry than thresholds;
    # the final (P=1, R=0) point has F1=0, so the clamp below only guards indexing.
    precision, recall, thresholds = precision_recall_curve(labels, logits)
    f1s = 2 * precision * recall / np.maximum(precision + recall, 1e-9)
    best_idx = int(np.argmax(f1s))
    best_thresh = float(thresholds[min(best_idx, len(thresholds) - 1)])
    token_f1_best = float(f1s[best_idx])

    # Per-region scores: the max over a region's token logits (high if any token
    # looks AMR-like) and the mean over its token logits. Both are ranked
    # against the region labels via ROC AUC.
    region_max_logit = [float(np.max(arr)) for arr in all_logits]
    region_mean_logit = [float(np.mean(arr)) for arr in all_logits]

    region_labels_np = np.array(all_region_labels)
    region_max_auc = roc_auc_score(region_labels_np, region_max_logit)
    region_mean_auc = roc_auc_score(region_labels_np, region_mean_logit)

    return {
        "split": split,
        "n_regions": len(all_logits),
        "n_tokens": int(labels.shape[0]),
        "n_positive_tokens": int(labels.sum()),
        "token_roc_auc": float(token_auc),
        "token_pr_auc": float(token_pr_auc),
        "token_best_f1": token_f1_best,
        "token_best_threshold": best_thresh,
        "region_max_pool_auc": float(region_max_auc),
        "region_mean_pool_auc": float(region_mean_auc),
    }


# ---------------------------------------------------------------------------
# Modal training entrypoint
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=4,
    memory=32 * 1024,  # 32 GB
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=7200,
)
def train_remote(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0, run_id: str = ""):
    """Train the linear probe, select the best epoch by val token AUC, then test it.

    Writes the following to ``/results/amr_binary_v1/linear/<run_id>/`` and
    commits the results volume:
      - ``checkpoint.pt``: the selected state dict.
      - ``metrics.json``: hyperparameters, val history and test metrics.

    ``run_id`` defaults to a ``%Y%m%d_%H%M%S`` timestamp.

    Returns:
        ``{"run_id", "best_epoch", "test_metrics"}``.
    """
    import numpy as np
    import torch

    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    print(f"[probe] reading manifest from {MANIFEST_PATH_REMOTE}")
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())

    device = torch.device("cpu")
    probe = make_probe().to(device)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    n_params = sum(p.numel() for p in probe.parameters())
    print(f"[probe] linear probe: in={HIDDEN}, out=1, params={n_params}")

    val_metrics_history = []
    best_val_auc = -1.0
    best_epoch = -1
    best_state = None

    for epoch in range(epochs):
        print(f"\n=== EPOCH {epoch + 1} / {epochs} ===")
        train_loss = train_one_epoch(probe, optimizer, manifest, "/data", pos_weight, device)
        print(" evaluating on val split...")
        val_metrics = evaluate(probe, manifest, "/data", "val", device)
        val_metrics["epoch"] = epoch + 1
        val_metrics["train_loss"] = train_loss
        val_metrics_history.append(val_metrics)
        print(f" val token AUC: {val_metrics['token_roc_auc']:.4f}, "
              f"region max-pool AUC: {val_metrics['region_max_pool_auc']:.4f}")
        if val_metrics["token_roc_auc"] > best_val_auc:
            best_val_auc = val_metrics["token_roc_auc"]
            best_epoch = epoch + 1
            # Detached CPU copy so later optimizer steps don't mutate the snapshot.
            best_state = {k: v.cpu().clone() for k, v in probe.state_dict().items()}

    print(f"\n[probe] best val token AUC = {best_val_auc:.4f} at epoch {best_epoch}")
    print("[probe] running final test eval with best epoch's weights")
    if best_state is None:
        # No epoch was selected (e.g. epochs == 0). Fall back to the current
        # weights so the saved checkpoint is a loadable state dict, not None.
        best_state = {k: v.cpu().clone() for k, v in probe.state_dict().items()}
    probe.load_state_dict(best_state)
    test_metrics = evaluate(probe, manifest, "/data", "test", device)

    # Save results
    run_id = run_id or time.strftime("%Y%m%d_%H%M%S")
    out_dir = f"/results/amr_binary_v1/linear/{run_id}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(best_state, f"{out_dir}/checkpoint.pt")
    Path(f"{out_dir}/metrics.json").write_text(json.dumps({
        "manifest": "amr_binary_v1",
        "model": "linear",
        "hyperparameters": {"epochs": epochs, "lr": lr, "pos_weight": pos_weight, "seed": SEED},
        "best_epoch": best_epoch,
        "val_history": val_metrics_history,
        "test": test_metrics,
    }, indent=2))
    results_vol.commit()
    print(f"\n[probe] results written to {out_dir}/metrics.json")

    print("\n=== FINAL TEST METRICS ===")
    for k, v in test_metrics.items():
        print(f" {k}: {v}")

    return {"run_id": run_id, "best_epoch": best_epoch, "test_metrics": test_metrics}


# ---------------------------------------------------------------------------
# Helper: upload manifest to its own volume (one-time per manifest version).
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=1,
    volumes={"/manifest": manifest_vol},
    timeout=300,
)
def upload_manifest_to_volume(manifest_text: str, name: str = "amr_binary_v1.json"):
    """Write ``manifest_text`` to ``/manifest/<name>`` on the manifest volume and commit it.

    Returns:
        ``{"path": <written path>, "size_kb": <file size in KiB>}``.
    """
    target = Path("/manifest").joinpath(name)
    target.write_text(manifest_text)
    manifest_vol.commit()
    size_kb = target.stat().st_size / 1024
    return {"path": str(target), "size_kb": size_kb}


# ---------------------------------------------------------------------------
# Local entrypoint: upload manifest + kick off training.
# ---------------------------------------------------------------------------
@app.local_entrypoint()
def main(epochs: int = 5, lr: float = 1e-3, pos_weight: float = 20.0):
    """Push the local manifest to the Modal volume, then run ``train_remote``.

    The manifest is re-uploaded on every invocation.

        modal run probes/train_amr_binary_v1.py::main
        modal run probes/train_amr_binary_v1.py::main --epochs 10 --lr 5e-4
    """
    manifest_text = Path(MANIFEST_PATH_LOCAL).read_text()

    print("[local] uploading manifest to Modal volume...")
    upload_info = upload_manifest_to_volume.remote(manifest_text)
    print(upload_info)

    print("[local] starting training...")
    result = train_remote.remote(epochs=epochs, lr=lr, pos_weight=pos_weight)

    print("\n=== RESULT ===")
    print(json.dumps(result, indent=2))


# ---------------------------------------------------------------------------
# Synthetic-data dry run on local CPU (see `dryrun_local` at the end of this
# file). Verifies the labelling and probe logic paths without any remote
# Modal function calls. ~10 seconds.
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Sanity-check plot: per-token probe logits for N example regions.
# Loads the latest probe checkpoint, picks 5 AMR positives + their 5 matched
# negatives from val/test (so we plot truly held-out examples), runs probe
# forward, plots per-token logit vs position with the CDS interval shaded.
# Returns PNG bytes.
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={
        "/data": l26_vol,
        "/results": results_vol,
        "/manifest": manifest_vol,
    },
    timeout=900,
)
def plot_probe_logits(run_id: str = "", n_examples: int = 5) -> bytes:
    """Generate sanity-check plot of per-token logits. Returns PNG as bytes.

    Loads ``checkpoint.pt`` from the given run, or from the lexicographically
    last run directory when ``run_id`` is empty. It then picks up to
    ``n_examples`` AMR-positive val/test regions (seeded shuffle) plus their
    ``pair_partner`` negatives. The grid has one row per selected pair:
    positives on the left, negatives on the right, and the CDS interval
    shaded in each cell.

    Raises:
        RuntimeError: if no run exists, or no held-out AMR positive is in the manifest.
    """
    import io
    import json
    import os

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    # Resolve checkpoint: latest run if none specified
    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    ckpt_path = f"{base}/{run_id}/checkpoint.pt"
    print(f"[plot] loading checkpoint: {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()

    # Path.read_text closes the file; the previous open(...).read() leaked the handle.
    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())
    labels_per_region = manifest["labels_per_region"]
    rng = np.random.default_rng(SEED)

    # Held-out AMR positives with their paired negatives. int() matches how the
    # rest of this file reads labels_per_region values.
    pos_ids = [
        rid for rid, lbl in labels_per_region.items()
        if int(lbl) == 1 and manifest["region_split"][rid] in ("val", "test")
    ]
    rng.shuffle(pos_ids)
    selected_pos = pos_ids[:n_examples]
    if not selected_pos:
        raise RuntimeError("no held-out (val/test) AMR-positive regions found in the manifest")
    selected_neg = [manifest["pair_partner"][p] for p in selected_pos]
    # Size the grid to what was actually selected, so no blank rows appear and
    # the x-label lands on the real bottom row.
    n_rows = len(selected_pos)

    # Compute logits for each selected region
    examples = {}
    for rid in selected_pos + selected_neg:
        gc = manifest["gene_coords"][rid]
        rl = int(labels_per_region[rid])
        acts, _labels = load_region("/data", rid, rl, gc)
        with torch.no_grad():
            logits = probe(torch.from_numpy(acts)).squeeze(-1).numpy()
        examples[rid] = {"rl": rl, "gc": gc, "logits": logits}

    # plot — 2 cols (positive | negative), one row per pair. Pitch styling.
    fig, axes = plt.subplots(
        n_rows, 2, figsize=(14, 2.4 * n_rows), sharex=False, squeeze=False
    )
    for col, half in enumerate((selected_pos, selected_neg)):
        for row, rid in enumerate(half):
            ex = examples[rid]
            ax = axes[row, col]
            logits = ex["logits"]
            ax.plot(np.arange(len(logits)), logits, lw=0.6, color="black")
            ax.axhline(0, color="grey", lw=0.6, ls="--")
            gs, ge, es, _ee, _strand = ex["gc"]
            # Same inclusive-coordinate clipping as per_token_labels_from_record.
            cds0 = max(0, gs - es)
            cds1 = min(len(logits), ge - es + 1)
            shade_color = "tab:green" if ex["rl"] == 1 else "tab:red"
            ax.axvspan(cds0, cds1, alpha=0.22, color=shade_color)
            ax.tick_params(labelsize=12)
            # Only show axis labels on the outer cells to reduce clutter
            if row == n_rows - 1:
                ax.set_xlabel("token position", fontsize=14)
            if col == 0:
                ax.set_ylabel("logit", fontsize=14)

    # Column headers (top of each column)
    axes[0, 0].set_title("AMR positive (CDS shaded green)", fontsize=14, fontweight="bold")
    axes[0, 1].set_title("matched negative (CDS shaded red)", fontsize=14, fontweight="bold")
    fig.suptitle("Linear probe on AMR — per-token logits", fontsize=18, fontweight="bold")
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)
    return buf.getvalue()


# ---------------------------------------------------------------------------
# Hyperplane-projection 1D plots: mean-pool and max-pool per-region logits
# for the v1 AMR linear probe on the test split. Shows pos/neg distributions
# on the same axis the classifier uses, with decision boundaries marked.
# ---------------------------------------------------------------------------
@app.function(
    image=image,
    cpu=2,
    memory=8 * 1024,
    volumes={"/data": l26_vol, "/results": results_vol, "/manifest": manifest_vol},
    timeout=900,
)
def plot_score_distributions(run_id: str = "", split: str = "test") -> dict:
    """1D distribution plots of per-region max-pool and mean-pool logits.

    Returns dict with PNG bytes AND raw scores so reformatting doesn't require
    a re-run. The keys are:
      - ``png``: the rendered figure.
      - ``scores``: a list of per-region dicts.
      - ``summary``: AUC and best-F1 threshold/F1 per pooling.

    Raises:
        RuntimeError: if no run exists under the results directory.
        ValueError: if no region of ``split`` could be loaded.
    """
    import io
    import json
    import os

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from sklearn.metrics import precision_recall_curve, roc_auc_score

    base = "/results/amr_binary_v1/linear"
    if not run_id:
        run_ids = sorted(os.listdir(base))
        if not run_ids:
            raise RuntimeError(f"no runs found under {base}")
        run_id = run_ids[-1]
    state_dict = torch.load(f"{base}/{run_id}/checkpoint.pt", map_location="cpu",
                            weights_only=True)
    probe = make_probe()
    probe.load_state_dict(state_dict)
    probe.eval()
    print(f"[plot] loaded probe from run {run_id}")

    manifest = json.loads(Path(MANIFEST_PATH_REMOTE).read_text())

    region_ids: list[str] = []
    region_max = []
    region_mean = []
    region_labels = []
    with torch.no_grad():
        for rid, acts, _labels in iter_split(manifest, split, "/data", shuffle=False):
            logits = probe(torch.from_numpy(acts)).squeeze(-1).numpy()
            region_ids.append(rid)
            region_max.append(float(np.max(logits)))
            region_mean.append(float(np.mean(logits)))
            region_labels.append(int(manifest["labels_per_region"][rid]))
    if not region_ids:
        raise ValueError(f"no regions could be loaded for split {split!r}")

    region_max = np.array(region_max)
    region_mean = np.array(region_mean)
    region_labels = np.array(region_labels)

    # AUCs + best-F1 thresholds for each aggregation
    def stats(scores):
        """Return (ROC AUC, best-F1 threshold, best F1) of ``scores`` against ``region_labels``."""
        auc = roc_auc_score(region_labels, scores)
        prec, rec, thr = precision_recall_curve(region_labels, scores)
        f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9)
        idx = int(np.argmax(f1))
        best_thr = float(thr[min(idx, len(thr) - 1)])
        return auc, best_thr, float(f1[idx])

    auc_max, t_max, f1_max = stats(region_max)
    auc_mean, t_mean, f1_mean = stats(region_mean)
    print(f"[plot] max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}")
    print(f"[plot] mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}")

    # Render — 2 panels side-by-side
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, scores, name, auc, best_thr, f1 in [
        (axes[0], region_max, "max-pool", auc_max, t_max, f1_max),
        (axes[1], region_mean, "mean-pool", auc_mean, t_mean, f1_mean),
    ]:
        pos = scores[region_labels == 1]
        neg = scores[region_labels == 0]

        # Density-normalised histogram per class on shared bins
        bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 60)
        ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})",
                color="tab:red", density=True)
        ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})",
                color="tab:green", density=True)

        # Rug plots just below the histogram baseline
        y_lo, y_hi = ax.get_ylim()
        rug_y = y_lo - 0.02 * (y_hi - y_lo)
        ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=80, color="tab:red", alpha=0.6)
        ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=80, color="tab:green", alpha=0.6)

        # Decision boundaries
        ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
        ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
        # The score is the pooled per-token logit. For max-pool this is NOT
        # w · (pooled h) + b, so the label states what is actually plotted.
        ax.set_xlabel(f"per-region probe logit ({name} over tokens of w · h_t + b)", fontsize=10)
        ax.set_ylabel("density", fontsize=10)
        ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11)
        ax.legend(fontsize=8, loc="upper left")
        ax.grid(True, alpha=0.2)

    fig.suptitle(
        f"v1 AMR linear probe — per-region score distributions ({split} split, "
        f"{int((region_labels==1).sum())} pos + {int((region_labels==0).sum())} neg, "
        f"MAG-level held-out)\n"
        "1D projection onto the probe's learned direction. Each region is one number "
        "= the probe's logit summarised by max- or mean-pool over its tokens.",
        fontsize=11,
    )
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
    plt.close(fig)

    # Raw per-region scores so the plot can be reformatted without re-running.
    raw = [
        {"region_id": rid, "label": int(lab), "max_logit": float(mx), "mean_logit": float(mn)}
        for rid, lab, mx, mn in zip(region_ids, region_labels, region_max, region_mean)
    ]
    summary = {
        "split": split,
        "run_id": run_id,
        "n_pos": int((region_labels == 1).sum()),
        "n_neg": int((region_labels == 0).sum()),
        "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    }
    return {"png": buf.getvalue(), "scores": raw, "summary": summary}


@app.local_entrypoint()
def make_score_dist_plot(
    run_id: str = "",
    split: str = "test",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_score_distributions.png",
):
    """Generate the per-region max/mean-pool logit distribution plot.

    Also persists the raw per-region scores next to the PNG, as
    ``<stem>.scores.jsonl`` and ``<stem>.summary.json``.
    """
    result = plot_score_distributions.remote(run_id=run_id, split=split)
    out_png = Path(out_path)
    out_png.parent.mkdir(parents=True, exist_ok=True)
    out_png.write_bytes(result["png"])
    print(f"saved {len(result['png'])/1024:.1f} KB to {out_png}")

    out_jsonl = out_png.with_suffix(".scores.jsonl")
    out_summary = out_png.with_suffix(".summary.json")
    out_jsonl.write_text("\n".join(json.dumps(r) for r in result["scores"]) + "\n")
    out_summary.write_text(json.dumps(result["summary"], indent=2))
    print(f"saved {len(result['scores'])} per-region scores to {out_jsonl}")
    print(f"saved aggregate summary to {out_summary}")


@app.local_entrypoint()
def make_sanity_plot(
    run_id: str = "",
    out_path: str = "/home/ror25cal/MGnify/probes/results/amr_binary_v1_sanity_plot.png",
    n_examples: int = 5,
):
    """Produce the sanity-check PNG of per-token logits and save locally.

        modal run probes/train_amr_binary_v1.py::make_sanity_plot
    """
    print(f"[local] generating sanity plot for run_id={run_id or 'LATEST'}")
    img_bytes = plot_probe_logits.remote(run_id=run_id, n_examples=n_examples)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    Path(out_path).write_bytes(img_bytes)
    print(f"[local] saved {len(img_bytes)/1024:.1f} KB to {out_path}")


@app.local_entrypoint()
def dryrun_local():
    """Check the per-token labelling and the linear probe forward/backward on synthetic data.

    The body makes no ``.remote()`` calls and reads no activation files.
    ``evaluate`` is not exercised, because it needs real ``.npz`` files.
    """
    import torch
    import torch.nn as nn

    print("=== DRY RUN (synthetic data) ===")

    # Synthetic per-token labels from a fake AMR record
    fake_coords = [2000, 2624, 0, 4624, "+"]  # gene_start, gene_end, ext_start, ext_end, strand
    labels = per_token_labels_from_record("FAKE_AMR", region_label=1, gene_coords=fake_coords,
                                          seq_len=4624)
    n_pos = int(labels.sum())
    print(f"per-token labelling: {n_pos} positive tokens out of 4624 "
          f"(expect 625 for gene 2000..2624 inclusive)")
    assert n_pos == 625, f"expected 625, got {n_pos}"

    # Edge case: gene starts before the extracted window -> clipped to token 0,
    # leaving tokens 0..200 inclusive = exactly 201 positives.
    edge_labels = per_token_labels_from_record("FAKE", 1, [-100, 200, 0, 1000, "+"], seq_len=1000)
    n_edge = int(edge_labels.sum())
    print(f"edge: gene_start before ext_start, got {n_edge} positives (expect 201)")
    assert edge_labels[0] == 1
    assert n_edge == 201, f"expected 201, got {n_edge}"

    # Edge case: gene runs past the end of the window -> clipped to seq_len.
    tail_labels = per_token_labels_from_record("FAKE", 1, [900, 1500, 0, 1000, "+"], seq_len=1000)
    assert int(tail_labels.sum()) == 100 and tail_labels[-1] == 1

    # MISC always all-zero
    misc_labels = per_token_labels_from_record("FAKE_MISC", 0, [0, 100, 0, 1000, "+"], seq_len=1000)
    assert misc_labels.sum() == 0

    # Linear probe forward/backward
    probe = make_probe().to("cpu")
    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))
    x = torch.randn(4624, 4096)
    y = torch.from_numpy(labels).float()
    logits = probe(x).squeeze(-1)
    loss = bce(logits, y)
    print(f"forward pass loss = {loss.item():.4f}")
    loss.backward()
    optimizer.step()
    print("backward + optimizer step OK")

    # evaluate() reads real .npz files through iter_split/load_region, so it is
    # not exercised here — run it on Modal, where the data is present.
    print("\nDRY RUN OK — logic paths working. "
          "Run `modal run probes/train_amr_binary_v1.py::main` for the real thing.")