Commit d1d87a6 by explcre · verified · 1 Parent(s): d3316a5

Upload runs/exp_oracle_v3_binary7_separate_fast_h100/README.md with huggingface_hub

runs/exp_oracle_v3_binary7_separate_fast_h100/README.md ADDED
@@ -0,0 +1,44 @@
+ # Oracle v3: LEONINE-strict 7-binary-classifier (joint top-1 = 0.973)
+
+ LEONINE-faithful enhancer cell-type classifier. Seven separate DeepSTARR-XL
+ networks (one per cell type) are each trained as a binary classifier with
+ 1:1 pos:neg sampling. Joint inference takes the argmax over the seven sigmoid outputs.
+
+ **Held-out test set (3,500 rows, balanced at 500 per cell type):**
+ - Joint top-1: **0.973** (chance = 1/7 ≈ 0.143)
+ - Mean AUROC: **0.992**
+ - Per-cell recall: Ex 0.97, In 0.98, OPC 0.98, Ast 0.98, Oli 0.95, Mic 0.97, End 0.98
+ - Per-cell AUROC: all ≥ 0.986
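The joint metrics above follow from the argmax rule. As a minimal sketch (not the evaluation script of this run; the logits and labels below are made up for illustration), joint top-1 and per-cell recall could be computed like this:

```python
import math

CELL_TYPES = ["Ex", "In", "OPC", "Ast", "Oli", "Mic", "End"]


def sigmoid(x):
    """Logistic function applied to a single logit."""
    return 1.0 / (1.0 + math.exp(-x))


def joint_predict(logits_row):
    """Pick the cell type whose binary classifier gives the highest sigmoid score."""
    probs = [sigmoid(z) for z in logits_row]
    best = max(range(len(probs)), key=probs.__getitem__)
    return CELL_TYPES[best]


def joint_metrics(logits, labels):
    """Return (joint top-1 accuracy, per-cell recall) for a labelled batch."""
    preds = [joint_predict(row) for row in logits]
    top1 = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    recall = {}
    for cell in CELL_TYPES:
        idx = [i for i, y in enumerate(labels) if y == cell]
        if idx:  # skip cell types absent from this batch
            recall[cell] = sum(preds[i] == cell for i in idx) / len(idx)
    return top1, recall


# Toy batch: one row of 7 logits per sequence (illustrative values only).
logits = [
    [2.0, -1.0, -3.0, -2.0, -1.5, -4.0, -2.5],   # highest score: Ex
    [-2.0, -1.0, 0.5, -2.0, -1.5, -4.0, -2.5],   # highest score: OPC
    [-2.0, 1.0, -3.0, -2.0, 0.2, -4.0, -2.5],    # highest score: In (true label Oli)
]
labels = ["Ex", "OPC", "Oli"]
top1, recall = joint_metrics(logits, labels)
```

Because the sigmoid is monotonic, taking the argmax over raw logits gives the same prediction as taking it over the sigmoid outputs. With seven balanced classes, a uniform random guess is right 1/7 of the time, which is where the 0.143 chance level comes from.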
+
+ **Files:**
+ - `oracle.pt`: bundled checkpoint with `state["per_cell"] = {cell: state_dict}`
+ - `{Ex,In,OPC,Ast,Oli,Mic,End}/oracle.pt`: individual cell checkpoints
+ - `bundle_separate_oracle.py`: loader (`SeparateBinaryOracle` wrapper)
+ - `metrics.json`: per-cell training + joint eval metrics
+
+ **Loading:**
+ ```python
+ import torch
+ from bundle_separate_oracle import SeparateBinaryOracle, build_one_cell_model, CELL_TYPES
+
+ # weights_only=False unpickles arbitrary Python objects; only load checkpoints you trust.
+ ckpt = torch.load("oracle.pt", map_location="cpu", weights_only=False)
+ nets = {}
+ for c in CELL_TYPES:
+     # Build one binary network per cell type and restore its weights.
+     m = build_one_cell_model(c, input_length=600)
+     m.load_state_dict(ckpt["per_cell"][c], strict=True)
+     nets[c] = m
+ oracle = SeparateBinaryOracle(nets, input_length=600).cuda().eval()
+
+ # Two forward modes:
+ #   (1) Differentiable: oracle(soft_dna_tensor)   → (B, 7) logits
+ #   (2) Standard:       oracle(["ACGT...", ...])  → (B, 7) logits
+ # .embed(seqs) returns (B, fc_dim=1024) penultimate features for FID.
+ ```
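The two forward modes differ only in how the sequence is represented. The sketch below illustrates the distinction between a hard one-hot encoding and a "soft" DNA encoding in plain Python. It is an assumption-laden illustration, not the encoder in `bundle_separate_oracle.py`: the A/C/G/T channel order, the all-zero treatment of `N`, and the helper names `one_hot` and `soft_dna` are hypothetical, and the tensor layout the oracle actually expects should be checked in the loader.

```python
import math

BASES = "ACGT"  # assumed channel order; verify against bundle_separate_oracle.py


def one_hot(seq):
    """Hard encoding: one row per position, one column per base.

    Unknown bases such as N become all-zero rows (an assumption).
    """
    return [[1.0 if ch == b else 0.0 for b in BASES] for ch in seq.upper()]


def soft_dna(position_logits, temperature=1.0):
    """Row-wise softmax: each position becomes a distribution over A/C/G/T.

    Lower temperatures push each row towards a one-hot vector.
    """
    out = []
    for row in position_logits:
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp((z - m) / temperature) for z in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


hard = one_hot("ACGN")
soft = soft_dna([[4.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], temperature=0.5)
```

A soft encoding of this kind keeps the input continuous, which is what allows gradients to flow from the oracle's logits back to a sequence generator.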
+
+ **Architecture per cell:** DeepSTARR-XL backbone (4 conv blocks 256/256/128/120,
+ fc=1024, dropout=0.3) + 1-output binary head. Trained on
+ `oracle_train.7cell.fdr_both` with `WeightedRandomSampler` 1:1 pos:neg.
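A 1:1 pos:neg draw can be obtained by weighting each example by the inverse of its class frequency. The sketch below shows the idea with the standard library on a made-up 10:90 label set; it is not the training code of this run, and `balanced_weights` is a hypothetical helper. With PyTorch, the same per-example weights would be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`.

```python
import random
from collections import Counter


def balanced_weights(labels):
    """Weight each example by 1 / (size of its class) so classes are drawn equally often."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# Toy imbalanced binary dataset: 10 positives, 90 negatives (illustrative only).
labels = [1] * 10 + [0] * 90
weights = balanced_weights(labels)

# Sample indices with replacement, proportional to the weights.
rng = random.Random(0)
draws = rng.choices(range(len(labels)), weights=weights, k=10_000)
pos_frac = sum(labels[i] for i in draws) / len(draws)
```

Each class ends up with the same total weight, so roughly half of the sampled indices are positives even though positives are only 10% of the toy data.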
+
+ **Training data:** the same `oracle_train.7cell.fdr_both.jsonl` used for the
+ v2 regression oracle (kept for comparison). The difference from v2 is in the
+ loss formulation and training schedule, not the data.