Upload runs/exp_oracle_v3_binary7_separate_fast_h100/README.md with huggingface_hub
runs/exp_oracle_v3_binary7_separate_fast_h100/README.md
ADDED
@@ -0,0 +1,44 @@
# Oracle v3: LEONINE-strict 7-binary-classifier (joint top-1 = 0.973)

LEONINE-faithful enhancer cell-type classifier. 7 separate DeepSTARR-XL
networks (one per cell type), each trained as a binary classifier with
1:1 pos:neg sampling. Joint inference: argmax over the 7 sigmoid outputs.

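The joint inference rule can be sketched in a few lines. This is an illustration only, not the code in `bundle_separate_oracle.py`: the cell-type order is assumed from the file listing in this README, and the logits are made-up numbers. Because the sigmoid is monotonic, taking the argmax over the sigmoid outputs selects the same cell type as taking it over the raw logits.

```python
import math

# Cell-type order is an assumption taken from this README's file listing.
CELL_TYPES = ["Ex", "In", "OPC", "Ast", "Oli", "Mic", "End"]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def joint_predict(logits):
    """Return the cell type whose binary classifier scores highest."""
    probs = [sigmoid(z) for z in logits]
    best = max(range(len(probs)), key=probs.__getitem__)
    return CELL_TYPES[best], probs


# Hypothetical logits from the 7 per-cell networks for one sequence.
label, probs = joint_predict([-2.1, -3.0, 0.4, 2.7, -1.2, -4.5, -0.3])
print(label)  # Ast
```
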
**Held-out test set (3,500 rows, balanced at 500 per cell type):**
- Joint top-1: **0.973** (chance = 1/7 ≈ 0.143)
- Mean AUROC: **0.992**
- Per-cell recall: Ex 0.97, In 0.98, OPC 0.98, Ast 0.98, Oli 0.95, Mic 0.97, End 0.98
- Per-cell AUROC: all ≥ 0.986

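As a quick consistency check on these numbers: on a balanced test set, top-1 accuracy equals the mean of the per-class recalls. The sketch below assumes the per-cell recall is measured on the joint argmax prediction, and it uses the two-decimal values listed above, so agreement is only expected to within rounding.

```python
# Per-cell recall as reported in this README (rounded to two decimals).
recall = {"Ex": 0.97, "In": 0.98, "OPC": 0.98, "Ast": 0.98,
          "Oli": 0.95, "Mic": 0.97, "End": 0.98}
n_per_cell = 500  # balanced test set: 500 rows per cell type

# Correct predictions per class, summed, divided by the total row count.
correct = sum(r * n_per_cell for r in recall.values())
top1 = correct / (n_per_cell * len(recall))
chance = 1 / len(recall)

print(round(top1, 3), round(chance, 3))  # 0.973 0.143
```
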
**Files:**
- `oracle.pt`: bundled checkpoint with `state["per_cell"] = {cell: state_dict}`
- `{Ex,In,OPC,Ast,Oli,Mic,End}/oracle.pt`: individual per-cell checkpoints
- `bundle_separate_oracle.py`: loader (`SeparateBinaryOracle` wrapper)
- `metrics.json`: per-cell training metrics and joint eval metrics

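Before loading weights, it can help to confirm that a bundle has the layout described above. The helper below is a hypothetical sketch (`check_bundle` is not part of `bundle_separate_oracle.py`), and it uses a plain dict with empty placeholders in place of the real `state_dict` objects so that it runs without PyTorch.

```python
# Cell-type order is an assumption taken from this README's file listing.
CELL_TYPES = ["Ex", "In", "OPC", "Ast", "Oli", "Mic", "End"]


def check_bundle(state: dict) -> None:
    """Raise KeyError if a cell type is missing or an unexpected key is present."""
    per_cell = state["per_cell"]
    missing = [c for c in CELL_TYPES if c not in per_cell]
    extra = [c for c in per_cell if c not in CELL_TYPES]
    if missing or extra:
        raise KeyError(f"missing={missing}, unexpected={extra}")


# Stand-in for the loaded checkpoint: empty dicts replace real state_dicts.
fake_bundle = {"per_cell": {c: {} for c in CELL_TYPES}}
check_bundle(fake_bundle)  # passes silently
```
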
**Loading:**
```python
import torch
from bundle_separate_oracle import SeparateBinaryOracle, build_one_cell_model, CELL_TYPES

# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
ckpt = torch.load("oracle.pt", map_location="cpu", weights_only=False)
nets = {}
for c in CELL_TYPES:
    # Build one network per cell type and load its weights from the bundle.
    m = build_one_cell_model(c, input_length=600)
    m.load_state_dict(ckpt["per_cell"][c], strict=True)
    nets[c] = m
oracle = SeparateBinaryOracle(nets, input_length=600).cuda().eval()

# Two forward modes:
#   (1) Differentiable: oracle(soft_dna_tensor)   → (B, 7) logits
#   (2) Standard:       oracle(["ACGT...", ...])  → (B, 7) logits
# .embed(seqs) returns (B, fc_dim=1024) penultimate features for FID.
```

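The differentiable mode takes a "soft" DNA tensor, presumably a per-position distribution over the four bases, of which a one-hot encoding is the hard special case. The sketch below shows such an encoding in plain Python. The channel order `ACGT`, the (length, 4) layout, and the uniform row for unknown bases are all assumptions; the layout the wrapper actually expects should be checked in `bundle_separate_oracle.py`.

```python
ALPHABET = "ACGT"  # assumed channel order, not confirmed by this README


def one_hot(seq: str) -> list[list[float]]:
    """Encode a DNA string as a list of 4-element rows, one row per base.

    Bases outside ACGT (for example N) become a uniform 0.25 row,
    which is an illustrative choice rather than the oracle's behaviour.
    """
    rows = []
    for base in seq.upper():
        if base in ALPHABET:
            rows.append([1.0 if base == b else 0.0 for b in ALPHABET])
        else:
            rows.append([0.25] * 4)
    return rows


x = one_hot("ACGTN")
print(x[0])  # [1.0, 0.0, 0.0, 0.0]
print(x[4])  # [0.25, 0.25, 0.25, 0.25]
```

Every row sums to 1, so the same structure can hold relaxed (non-one-hot) distributions when gradients with respect to the sequence are needed.
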
**Architecture per cell:** DeepSTARR-XL backbone (4 conv blocks 256/256/128/120,
fc=1024, dropout=0.3) + 1-output binary head. Trained on
`oracle_train.7cell.fdr_both` with a `WeightedRandomSampler` that draws
positives and negatives at a 1:1 ratio.

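One common way to obtain 1:1 pos:neg sampling is to weight each example by the inverse of its class count, so that both classes carry equal total sampling mass. The sketch below illustrates that idea with hypothetical labels; whether the training script computes its weights this way is an assumption.

```python
from collections import Counter


def balanced_weights(labels):
    """Per-sample weights giving each class equal total sampling mass."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


# Hypothetical labels for one cell type: 2 positives, 6 negatives.
labels = [1, 1, 0, 0, 0, 0, 0, 0]
weights = balanced_weights(labels)

# Total mass per class; equal masses mean a 1:1 expected draw ratio.
pos_mass = sum(w for w, y in zip(weights, labels) if y == 1)
neg_mass = sum(w for w, y in zip(weights, labels) if y == 0)
print(pos_mass, neg_mass)
```

Weights of this form can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`, which then draws positives and negatives with equal probability in expectation.
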
**Training data:** the same `oracle_train.7cell.fdr_both.jsonl` used for the
v2 regression oracle, kept unchanged so the two can be compared directly.
The difference from v2 is the loss formulation and training schedule, not the data.