explcre commited on
Commit
d3316a5
·
verified ·
1 Parent(s): b55c1bb

Upload runs/exp_oracle_v3_binary7_separate_fast_h100/bundle_separate_oracle.py with huggingface_hub

Browse files
runs/exp_oracle_v3_binary7_separate_fast_h100/bundle_separate_oracle.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bundle 7 separate per-cell binary classifiers into one oracle.pt and
2
+ compute the joint top-1 accuracy on the held-out test set.
3
+
4
+ Joint inference: for a batch of sequences, run all 7 networks in
5
+ parallel; sigmoid each output; argmax over the 7 sigmoid scores gives
6
+ the predicted cell type. This is the LEONINE-FID-protocol joint
7
+ classifier.
8
+
9
+ Bundle format:
10
+ state["per_cell"][cell_name] = state_dict for that cell's network
11
+ state["config"] = shared DeepSTARR hyper-parameters plus the full
12
+ cell_types list (each per-cell network has a single-output head)
13
+ state["config"]["task"] = "classifier_binary7_separate"
14
+ state["oracle"] = "deepstarr_7cell_separate"
15
+
16
+ To use as a drop-in oracle: rebuild each per-cell network with
17
+ build_one_cell_model(), load its state dict from state["per_cell"], and wrap them in SeparateBinaryOracle, whose .forward(seqs) returns (B, 7) logits.
18
+ """
19
+ from __future__ import annotations
20
+ import json
21
+ import sys
22
+ import time
23
+ from pathlib import Path
24
+ from typing import Sequence
25
+
26
+ sys.path.insert(0, "/workspace/biomodel_reasoning_calling_study2/regureasoner_loop")
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ from regureasoner.benchmarks.oracles.deepstarr_7cell import (
34
+ DeepSTARR7Cell, DeepSTARR7CellConfig,
35
+ )
36
+ # vectorized one-hot
37
+ sys.path.insert(0, "/workspace/biomodel_reasoning_calling_study2/regureasoner_loop/scripts")
38
+ from train_oracle_binary7_fast import one_hot_batch_fast # noqa: E402
39
+
40
# Fixed cell-type order: the joint oracle stacks its per-cell logits in this
# order, so column i of its output belongs to CELL_TYPES[i].
CELL_TYPES = ("Ex", "In", "OPC", "Ast", "Oli", "Mic", "End")
# Reverse lookup from a cell-type name to its position in CELL_TYPES.
CELL_TO_IDX = dict(zip(CELL_TYPES, range(len(CELL_TYPES))))
42
+
43
+
44
def build_one_cell_model(target_cell: str, input_length: int = 600) -> DeepSTARR7Cell:
    """Build an untrained single-output DeepSTARR network for one cell type.

    Per the original author's note this mirrors the architecture used by
    scripts/train_oracle_binary_one_cell.py (the "xlarge" config with a
    1-output head), so a checkpoint from that script can be loaded into the
    returned module.

    Args:
        target_cell: Name of the cell type; stored as the only entry of the
            config's ``cell_types``.
        input_length: Sequence length handed to the model config.

    Returns:
        A ``DeepSTARR7Cell`` whose ``head`` has been replaced by a fresh
        ``nn.Linear(in_features, 1)``.
    """
    architecture = {
        "fc_dim": 1024,
        "dropout": 0.3,
        "conv_channels": (256, 256, 128, 120),
        "conv_kernels": (7, 5, 5, 3),
    }
    network = DeepSTARR7Cell(
        DeepSTARR7CellConfig(
            cell_types=(target_cell,),
            input_length=input_length,
            **architecture,
        )
    )
    # Keep the head's input width but emit a single binary logit.
    network.head = nn.Linear(network.head.in_features, 1)
    return network
58
+
59
+
60
class SeparateBinaryOracle(nn.Module):
    """Joint oracle assembled from one binary classifier per cell type.

    Every per-cell network is run on the same input and its single logit
    becomes one column of the joint output, in ``CELL_TYPES`` order, so an
    argmax over the columns maps straight back to a cell-type name.

    The wrapper does NOT expose joint ``encoder`` / ``dense`` / ``head``
    sub-modules.  Code written against the three-step
    ``head(dense(encoder(x)))`` pathway (the original author's note names
    oracle_aux_loss.py) has to call ``forward(soft_dna)`` directly instead.

    Attributes:
        cell_types: Cell-type names, in output-column order.
        per_cell: ``nn.ModuleDict`` mapping cell-type name to its network.
        num_cell_types: Number of output columns.
        native_hidden: ``fc_dim`` of the first per-cell network's config,
            i.e. the width the original author documents for the
            penultimate features returned by ``embed``.
        config: Namespace with ``input_length`` and ``cell_types`` for
            downstream code that reads ``oracle.config``.
    """

    def __init__(self, per_cell_models: dict, input_length: int = 600):
        """Register the per-cell networks.

        Args:
            per_cell_models: Mapping from every name in ``CELL_TYPES`` to a
                network exposing ``encoder``, ``dense``, ``head`` and
                ``config.fc_dim``.
            input_length: Sequence length used when string inputs are
                one-hot encoded.

        Raises:
            KeyError: If a cell type in ``CELL_TYPES`` has no network.
        """
        super().__init__()
        from types import SimpleNamespace

        missing = [c for c in CELL_TYPES if c not in per_cell_models]
        if missing:
            raise KeyError(f"per_cell_models is missing networks for: {missing}")

        # Order matters for argmax: keep CELL_TYPES order.
        self.cell_types = list(CELL_TYPES)
        # ModuleDict so PyTorch tracks parameters / device moves.
        self.per_cell = nn.ModuleDict(
            {c: per_cell_models[c] for c in self.cell_types}
        )
        self.num_cell_types = len(self.cell_types)
        first_net = next(iter(self.per_cell.values()))
        self.native_hidden = int(first_net.config.fc_dim)
        self.config = SimpleNamespace(
            input_length=input_length,
            cell_types=tuple(self.cell_types),
        )

    def _one_hot(self, seqs):
        """One-hot encode DNA strings and move the batch to this module's device.

        The result is documented by the original author as ``(B, 4, L)``.
        """
        # Imported lazily so the tensor path of forward() never needs it.
        from regureasoner.benchmarks.oracles.base import one_hot_dna  # noqa

        device = next(self.parameters()).device
        encoded = [one_hot_dna(s, self.config.input_length) for s in seqs]
        return torch.stack(encoded).to(device)

    @staticmethod
    def _penultimate(net, x):
        """Return one network's dense-layer output for the batch ``x``."""
        return net.dense(net.encoder(x).flatten(1))

    def forward(self, soft_dna):
        """Return raw per-cell logits, one column per cell type.

        Two input modes are accepted:
          (1) a ``torch.Tensor`` (documented as ``(B, 4, L)``), used as-is —
              this is the differentiable aux-loss path;
          (2) an iterable of DNA strings, one-hot encoded here to match the
              ``oracle.forward(List[str])`` convention mentioned in the
              original notes.

        An empty iterable of strings yields an empty ``(0, num_cell_types)``
        tensor instead of failing inside ``torch.stack``.
        """
        if isinstance(soft_dna, torch.Tensor):
            x = soft_dna
        else:
            seqs = list(soft_dna)
            if not seqs:
                device = next(self.parameters()).device
                return torch.empty((0, self.num_cell_types), device=device)
            x = self._one_hot(seqs)

        logits = []
        for cell in self.cell_types:
            net = self.per_cell[cell]
            # The single-output head gives (B, 1); drop the last axis -> (B,).
            logits.append(net.head(self._penultimate(net, x)).squeeze(-1))
        return torch.stack(logits, dim=-1)

    @torch.no_grad()
    def embed(self, seqs):
        """Return the mean of the per-cell penultimate features for ``seqs``.

        Side effect (unchanged from the original): the module is switched to
        eval mode and left there.  Runs under ``torch.no_grad()``.

        Args:
            seqs: Iterable of DNA strings.

        Returns:
            The dense-layer outputs averaged over all per-cell networks
            (documented by the original author as ``(B, fc_dim)``).  An
            empty input yields an empty ``(0, native_hidden)`` tensor.
        """
        self.eval()
        seqs = list(seqs)
        if not seqs:
            device = next(self.parameters()).device
            return torch.empty((0, self.native_hidden), device=device)
        x = self._one_hot(seqs)
        feats = [self._penultimate(self.per_cell[c], x) for c in self.cell_types]
        return torch.stack(feats, dim=0).mean(dim=0)
140
+
141
+
142
def evaluate_joint(oracle: "SeparateBinaryOracle", eval_jsonl: str, device) -> dict:
    """Score the joint oracle on a JSONL evaluation set.

    Each non-blank line of ``eval_jsonl`` must be a JSON object with a
    ``"sequence"`` string and a ``"cell_activities"`` list.  The target
    class is the argmax over the first ``len(CELL_TYPES)`` activities; the
    prediction is the argmax over the sigmoid of the oracle's logits.

    Args:
        oracle: Joint oracle; switched to eval mode here.
        eval_jsonl: Path to the evaluation JSONL file.
        device: Device the input batches are moved to.

    Returns:
        Dict with ``joint_top1``, ``mean_auroc`` (NaN-ignoring mean of the
        per-cell AUROCs), ``per_cell_recall`` and ``per_cell_auroc``; the
        per-cell dicts are keyed by cell-type name and hold NaN where a
        class is absent.

    Raises:
        ValueError: If the file contains no rows.
    """
    oracle.eval()
    with open(eval_jsonl, encoding="utf-8") as fh:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        rows = [json.loads(line) for line in fh if line.strip()]
    if not rows:
        raise ValueError(f"no evaluation rows found in {eval_jsonl}")

    n_cells = len(CELL_TYPES)
    # Only request bf16 autocast when actually running on CUDA; this avoids
    # the "CUDA is not available" warning on the CPU fallback.
    use_amp = torch.device(device).type == "cuda"
    batch_size = 256

    score_chunks, targets = [], []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        x = one_hot_batch_fast(
            [r["sequence"] for r in chunk], oracle.config.input_length
        ).to(device)
        with torch.no_grad(), torch.amp.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=use_amp
        ):
            logits = oracle(x)
        score_chunks.append(torch.sigmoid(logits.float()).cpu().numpy())
        # Extra trailing activities beyond the known cell types are ignored.
        targets.extend(
            int(np.argmax(r["cell_activities"][:n_cells])) for r in chunk
        )

    preds = np.concatenate(score_chunks)
    targets = np.asarray(targets)
    pred_idx = preds.argmax(axis=1)
    top1 = float((pred_idx == targets).mean())

    per_cell_recall, per_cell_auroc = {}, {}
    for idx, name in enumerate(CELL_TYPES):
        is_cell = targets == idx
        per_cell_recall[name] = (
            float((pred_idx[is_cell] == idx).mean()) if is_cell.any() else float("nan")
        )
        per_cell_auroc[name] = _binary_auroc(preds[:, idx], is_cell)

    return {
        "joint_top1": top1,
        "mean_auroc": float(np.nanmean(list(per_cell_auroc.values()))),
        "per_cell_recall": per_cell_recall,
        "per_cell_auroc": per_cell_auroc,
    }


def _binary_auroc(scores, labels) -> float:
    """Return the one-vs-rest AUROC, counting tied pairs as half a win.

    Computed as the Mann-Whitney statistic: the fraction of
    (positive, negative) pairs in which the positive scores higher.
    Returns NaN when either class is empty.
    """
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels).astype(bool)
    pos, neg = scores[labels], scores[~labels]
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    neg_sorted = np.sort(neg)
    # For each positive: negatives strictly below it, and negatives <= it.
    below = np.searchsorted(neg_sorted, pos, side="left")
    at_or_below = np.searchsorted(neg_sorted, pos, side="right")
    wins = below.sum() + 0.5 * (at_or_below - below).sum()
    return float(wins / (pos.size * neg.size))
182
+
183
+
184
def main():
    """Bundle the per-cell checkpoints into one oracle.pt and evaluate it.

    Steps:
      1. Load ``<out_dir>/<cell>/oracle.pt`` for every cell type and rebuild
         the matching single-output network.
      2. Exit with status 1 if any cell type's checkpoint is missing.
      3. Evaluate the joint oracle on the held-out test JSONL.
      4. Write ``oracle.pt`` (state dicts, config, metrics), ``metrics.json``
         and ``_arch_meta.json`` into ``out_dir``.
    """
    out_dir = Path("/workspace/dnathinker/runs/exp_oracle_v3_binary7_separate_fast_h100")
    input_length = 600
    print(f"[load] bundling {out_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    per_cell_state: dict = {}
    per_cell_metrics: dict = {}
    per_cell_models = {}
    for cell in CELL_TYPES:
        ck = out_dir / cell / "oracle.pt"
        if not ck.exists():
            print(f"  missing {cell} ckpt: {ck}")
            continue
        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is only safe for checkpoints produced by trusted training runs.
        s = torch.load(ck, map_location="cpu", weights_only=False)
        best_auroc = s.get("best_auroc")
        per_cell_state[cell] = s["state"]
        per_cell_metrics[cell] = best_auroc
        m = build_one_cell_model(cell, input_length=input_length)
        m.load_state_dict(s["state"], strict=True)
        per_cell_models[cell] = m.to(device).eval()
        # A checkpoint without "best_auroc" must not crash the format spec.
        auroc_txt = "n/a" if best_auroc is None else f"{best_auroc:.4f}"
        print(f"  {cell}: best_auroc={auroc_txt}")

    if len(per_cell_models) < len(CELL_TYPES):
        print(f"[error] only got {len(per_cell_models)} of {len(CELL_TYPES)} cells; aborting bundle")
        sys.exit(1)

    oracle = SeparateBinaryOracle(per_cell_models, input_length=input_length).to(device)
    print("\n[eval] computing joint top-1 on test set")
    eval_jsonl = "/workspace/dnathinker/data/oracle/oracle_test.7cell.500.jsonl"
    metrics = evaluate_joint(oracle, eval_jsonl, device)
    print(f"  joint_top1 = {metrics['joint_top1']:.4f}")
    print(f"  mean_auroc = {metrics['mean_auroc']:.4f}")
    print(f"  per_cell recall: {metrics['per_cell_recall']}")
    print(f"  per_cell AUROC : {metrics['per_cell_auroc']}")

    bundle = {
        "per_cell": per_cell_state,
        # Hyper-parameters are stored under both the "deepstarr_"-prefixed
        # and the bare key names, as in the original bundle layout.
        "config": {
            "task": "classifier_binary7_separate",
            "cell_types": list(CELL_TYPES),
            "input_length": input_length,
            "deepstarr_fc_dim": 1024,
            "deepstarr_dropout": 0.3,
            "deepstarr_conv_channels": [256, 256, 128, 120],
            "deepstarr_conv_kernels": [7, 5, 5, 3],
            "deepstarr_pool_kernels": [3, 3, 3, 3],
            "fc_dim": 1024, "dropout": 0.3,
            "conv_channels": [256, 256, 128, 120],
            "conv_kernels": [7, 5, 5, 3],
            "pool_kernels": [3, 3, 3, 3],
        },
        "metrics": metrics,
        "per_cell_train_metrics": per_cell_metrics,
        "oracle": "deepstarr_7cell_separate",
    }
    bundle_path = out_dir / "oracle.pt"
    torch.save(bundle, bundle_path)
    print(f"\n[done] saved bundle to {bundle_path}")
    print(f"  joint_top1 = {metrics['joint_top1']:.4f}")
    print(f"  mean_auroc = {metrics['mean_auroc']:.4f}")

    # Also save metrics.json for the meta indexer.
    with open(out_dir / "metrics.json", "w") as f:
        json.dump({"joint_top1": metrics["joint_top1"],
                   "mean_auroc": metrics["mean_auroc"],
                   "per_cell": metrics["per_cell_auroc"],
                   "per_cell_train": per_cell_metrics}, f, indent=2)
    with open(out_dir / "_arch_meta.json", "w") as f:
        json.dump({"task": "oracle_classifier_separate",
                   "kind": "v3_binary7_separate",
                   "label": f"LEONINE-strict 7 separate networks (joint top1 {metrics['joint_top1']:.3f})"},
                  f)
256
+
257
+
258
# Run the bundling pipeline only when this file is executed as a script,
# not when it is imported for SeparateBinaryOracle / build_one_cell_model.
if __name__ == "__main__":
    main()