# T3: post-v5 follow-ups (deep-think + work plan)
The first-pass v5 work (oracle eval, multi-turn RFT, reasoning expansion,
sanitiser) is shipped. This doc captures the deeper questions the T3
rethink raises, what's already been done about them, and what's
deliberately deferred.
## 1. RFT timing: when, relative to T1/T2 + joint training?
**TL;DR**: Current order (per-task SFT → T3 RFT → joint multitask) is
the right *minimum*. A second RFT pass after the joint adapter is a
worthwhile **ablation** but adds a serial dependency.
### The current order
```
Stage 1   T1 fusion-SFT (heuristic gold)
Stage 2   T2 fusion-SFT
Stage 3   T3 fusion-SFT (heuristic synthetic-edit gold)
Stage 3b  T3 reasoning-only SFT (paper ablation)
Stage 3c  RFT: generate K candidates from Stage-3 adapter,
          oracle-filter, retrain on filtered set
Stage 4   Joint multitask fusion-SFT (T1+T2+T3, 35k each, balanced)
```
### Why RFT can't move earlier
RFT needs a candidate generator that's already on-task, i.e., capable
of producing T3-shaped enhancer edits. The base Qwen3.5-2B can't do
that without prior T3 SFT (it doesn't know the
`<enhancer_dna_start>` schema, the cell-type grammar, or the edit
budget convention). So RFT must come **after at least one T3 SFT pass**.
### Why current order beats "RFT after joint"
The Stage-3 adapter is trained on 35k T3-only rows. The Stage-4 joint
adapter sees only 35k T3 rows out of a 105k mix (33% T3 share). For
candidate-generation specifically:
* **Stage-3 adapter**: more T3-faithful, fewer cross-task style
artefacts. Higher per-row keep-rate when scoring against the T3
oracle.
* **Stage-4 adapter**: better generalist, slightly weaker T3 grammar.
Empirical question, and the **ablation** to run:
> RFT-from-Stage3 (current default) vs RFT-from-Stage4 (joint adapter):
> does the joint adapter's regularisation produce candidates with
> higher mean objective margins, or do the format artefacts dominate?
This is a one-flag change: `--adapter-state-dict
runs/exp_joint_multitask_*/final/pytorch_model.bin` instead of the
T3-only path. Costs one RFT pass + one re-train. Worth a Table 3 row.
### Why a second RFT pass after joint isn't free
Re-RFT-ing the joint adapter would chain:
```
Stage 4 → Stage 4-RFT → retrain joint → ...
```
Each step is multi-hour on H100. Diminishing returns are likely (the
oracle is the same, the candidate distribution doesn't gain new
biology after the second pass). Defer to extended-paper revision.
## 2. T3-specific oracle?
**TL;DR**: We use the DeepSTARR-7cell oracle for T3 (and T1). A
T3-specific oracle is a paper-extension worth doing if reviewers push;
the minimum-publishable suite ships with the shared oracle.
### What a T3-specific oracle would predict
Two candidates:
* **"Edit-quality" oracle.** Input: `(reference, edited, edit_distance,
cell_type, edit_type)`. Output: `(satisfies_objective, margin)`.
Trained on RFT-filtered (positive) + RFT-rejected (negative)
candidates. Self-bootstrap risk: the oracle becomes circular if it
only learns what the previous oracle's filter did.
* **"Pairing" oracle (transferred from T2).** Input:
`(promoter, edited_enhancer, cell_type)`. Output: `pairing_score`.
Trained on observed pairs (T2 dataset). For T3 we'd score `(promoter,
edited)` and reward the model when this score increases over
`(promoter, reference)`. This is biology-grounded and non-circular.
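The pairing-oracle reward compares the edited enhancer against the reference for the same promoter. A minimal sketch, assuming a hypothetical `score(promoter, enhancer, cell_type)` callable stands in for the trained T2 pairing model; `pairing_gain` and `pairing_reward` are illustrative names, not functions from the codebase:

```python
from typing import Callable

# ASSUMED scorer signature: (promoter, enhancer, cell_type) -> pairing score.
PairingScorer = Callable[[str, str, str], float]


def pairing_gain(score: PairingScorer, promoter: str, reference: str,
                 edited: str, cell_type: str) -> float:
    """Change in pairing score caused by the edit (positive means improved)."""
    return score(promoter, edited, cell_type) - score(promoter, reference, cell_type)


def pairing_reward(gain: float) -> float:
    """Binary reward: 1.0 only when the edit strictly raises the pairing score."""
    return 1.0 if gain > 0 else 0.0
```

Because the reference and the edit are scored by the same model against the same promoter, any constant bias in the scorer cancels in the difference.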
### Why we ship with DeepSTARR-7cell
The DeepSTARR-7cell oracle gives per-cell-type activity scores. For T3:
* `activity_boost` reduces to "did the activity in the source cell
go up?", i.e. `pred_activity_src > ref_activity_src`. Direct read from
oracle.
* `cell_type_transfer` reduces to "did the activity shift toward target?",
i.e. `(pred_tgt - pred_src) - (ref_tgt - ref_src) > 0`. Direct read.
* `promoter_retarget` reduces to "is the new motif present?", i.e. an IUPAC
scan, no oracle needed.
So **all three T3 objectives are computable from the existing oracle**
without a T3-specific one. The DeepSTARR-7cell oracle is weak
(`val_pearson_mean=0.136`) in absolute terms, but the metrics use
**deltas and shifts**, i.e. relative ranking, where weak oracles still
carry meaningful signal.
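The two oracle-backed checks can be sketched as follows. This assumes the oracle's output is available as a mapping from cell-type name to predicted activity; the function signatures and the `cell_a` / `cell_b` labels are placeholders, not identifiers taken from `eval_t3_oracle.py`.

```python
from typing import Mapping

# ASSUMED shape: one predicted activity per cell type, keyed by cell-type name.
Activities = Mapping[str, float]


def activity_boost(pred: Activities, ref: Activities, src: str) -> bool:
    """Did predicted activity in the source cell type go up after the edit?"""
    return pred[src] > ref[src]


def cell_type_transfer(pred: Activities, ref: Activities, src: str, tgt: str) -> bool:
    """Did the target-minus-source activity gap widen relative to the reference?"""
    pred_gap = pred[tgt] - pred[src]
    ref_gap = ref[tgt] - ref[src]
    return (pred_gap - ref_gap) > 0


# Toy numbers only: the edit lowers cell_a activity and raises cell_b activity.
ref_scores = {"cell_a": 0.5, "cell_b": 0.2}
pred_scores = {"cell_a": 0.4, "cell_b": 0.6}
boosted = activity_boost(pred_scores, ref_scores, "cell_a")
transferred = cell_type_transfer(pred_scores, ref_scores, "cell_a", "cell_b")
```

Both checks compare the edited sequence with its own reference under the same oracle, which is why they depend on relative ranking and not on absolute calibration.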
If the paper review pushes back on oracle weakness, the
"T2-pairing-as-T3-oracle" path is the right extension: concrete and publishable.
## 3. Loop-SFT: does it need T3-aware changes?
**TL;DR**: No code change. The data source for T3 trajectories should
swap to the post-RFT JSONL, but Loop-SFT itself is task-agnostic.
### What Loop-SFT does
`scripts/train_loop_sft.py` consumes JSONL records of the shape:
```
{"id": ..., "task_type": ...,
"messages": [system, user],
"trajectory": {"steps": [{kind, state, text, tool_name, tool_args, tool_result}, ...]}}
```
The collator renders the trajectory into a single assistant turn the
LLM is trained to emit. Task type doesn't change anything: the
collator looks at trajectory steps, not task semantics.
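To illustrate why the collator is task-agnostic, here is a minimal sketch of rendering a trajectory into one assistant turn. The bracketed output format and the `think` / `tool` step kinds are assumptions made for illustration; only the step fields and the `final` kind come from the record shape above, and the real collator in `scripts/train_loop_sft.py` may render differently.

```python
import json


def render_trajectory(record: dict) -> str:
    """Flatten trajectory steps into one assistant-turn string.

    Note that record["task_type"] is never read: only the steps matter.
    """
    lines = []
    for step in record["trajectory"]["steps"]:
        kind = step.get("kind", "")
        if step.get("tool_name"):
            # Tool step: show the call and, if present, its result.
            args = json.dumps(step.get("tool_args") or {}, sort_keys=True)
            lines.append(f"[{kind}] call {step['tool_name']}({args})")
            if step.get("tool_result") is not None:
                lines.append(f"[{kind}] result {step['tool_result']}")
        else:
            # Plain step (including the final answer): emit its text.
            lines.append(f"[{kind}] {step.get('text', '')}")
    return "\n".join(lines)
```

Two records that differ only in `task_type` render to the same string, which is the property the section relies on.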
### What changes for T3
The **trajectory dataset** for T3 is currently expanded from the
heuristic-gold T3 JSONL via `scripts/expand_loop_trajectories.py`. The
trajectory's final `kind="final"` step contains the gold enhancer
sequence, currently the heuristic synthetic edit, **not** the
oracle-validated RFT candidate.
To align with v5:
```bash
# old:
python scripts/expand_loop_trajectories.py \
    --source data/prod_samples/train.enhancer_editing.strat7c.n35k.jsonl \
    --out data/trajectories/train.enhancer_editing.jsonl
# new (post-RFT-aware):
python scripts/expand_loop_trajectories.py \
    --source runs/exp_t3_fusion_sft_${STAMP}/rft_filtered_train.jsonl \
    --out data/trajectories/train.enhancer_editing.rft.jsonl
```
Same script, different `--source`. The expander reads the assistant's
gold answer from `messages[-1]["content"]`, which RFT replaced with
the candidate. So the trajectory's `final` step inherits the
candidate. No code change.
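A minimal sketch of that hand-off, assuming each source row ends with the assistant message and that the expander copies its content into the `final` step. The helper name `expand_record` is hypothetical, and the real script presumably emits intermediate steps that are omitted here.

```python
def expand_record(row: dict) -> dict:
    """Turn one SFT row into a trajectory record with a single final step."""
    # The last message is the assistant's gold answer; the rest form the prompt.
    *prompt, answer = row["messages"]
    return {
        "id": row.get("id"),
        "task_type": row.get("task_type"),
        "messages": prompt,
        "trajectory": {"steps": [{"kind": "final", "text": answer["content"]}]},
    }
```

Whatever RFT wrote into `messages[-1]["content"]` becomes the final step unchanged, so switching `--source` is enough to switch the supervision target.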
For the paper, we report Loop-SFT trained on:
* Heuristic-gold trajectories (matches the current pipeline default)
* **Post-RFT trajectories** (new; aligned with v5 evaluation)
These give two T3 Loop-SFT rows in Table 1, a clean ablation.
## 4. SV-GSPO: does it need T3-aware changes?
**TL;DR**: **YES, and it's now done.** Before this commit, the T3
outcome-reward function in `regureasoner/rl/reward_shaper.py` scored the
**wrong** objective (a sequence-distance window). Fixed.
### What was wrong
The previous `outcome_enhancer_editing(final, gold, min_edit=1,
max_edit=60)` returned 1.0 for any edit distance in `[1, 60]` and 0
for identity. That tells the agent "make a moderately-sized edit",
which is exactly the wrong signal under the v5 framework, where the
metric of record is *objective satisfaction*, not an edit-distance window.
### What's changed
`outcome_enhancer_editing(final, gold)` now returns the average of three
binary checks aligned with `scripts/eval_t3_oracle.py`:
* `within_budget`: Hamming ≤ `gold["edit_budget"]` (or 5%-of-len(ref)
fallback when no budget).
* `length_preserved`: `len(pred) == len(ref)`.
* `target_motif_present`: IUPAC regex (fwd + revcomp) for
`gold["target_motif"]`.
Score ∈ {0, 1/3, 2/3, 1}. SV-GSPO group-normalisation handles the
discreteness; rollouts are expected to land between 0 and 1 with most
mass at 2/3 (budget + length usually pass; motif is the bottleneck).
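The three checks can be sketched in a few lines. This is an illustrative reimplementation, not the code in `reward_shaper.py`: the `gold["reference"]` key, the floor of 1 on the fallback budget, and the treatment of unequal lengths (mismatches plus length difference) are all assumptions.

```python
import re

# IUPAC nucleotide codes expanded to regex character classes.
IUPAC = {
    "A": "A", "C": "C", "G": "G", "T": "T",
    "R": "[AG]", "Y": "[CT]", "S": "[CG]", "W": "[AT]",
    "K": "[GT]", "M": "[AC]", "B": "[CGT]", "D": "[AGT]",
    "H": "[ACT]", "V": "[ACG]", "N": "[ACGT]",
}
_COMPLEMENT = str.maketrans("ACGT", "TGCA")


def revcomp(seq: str) -> str:
    """Reverse complement of an uppercase A/C/G/T sequence."""
    return seq.translate(_COMPLEMENT)[::-1]


def motif_present(seq: str, motif: str) -> bool:
    """True if the IUPAC motif matches seq on either strand."""
    if not motif:
        return False
    pattern = re.compile("".join(IUPAC[base] for base in motif.upper()))
    seq = seq.upper()
    return bool(pattern.search(seq) or pattern.search(revcomp(seq)))


def outcome_enhancer_editing_sketch(pred: str, gold: dict) -> float:
    """Average of the three binary checks; one of 0, 1/3, 2/3, 1."""
    ref = gold["reference"]  # ASSUMED key name for the reference enhancer
    budget = gold.get("edit_budget")
    if budget is None:
        budget = max(1, len(ref) // 20)  # 5% of len(ref); floor of 1 is an assumption
    # Mismatches over the overlap plus the length difference (assumption:
    # plain Hamming distance is undefined for unequal lengths).
    distance = sum(a != b for a, b in zip(pred, ref)) + abs(len(pred) - len(ref))
    checks = (
        distance <= budget,                                  # within_budget
        len(pred) == len(ref),                               # length_preserved
        motif_present(pred, gold.get("target_motif", "")),   # target_motif_present
    )
    return sum(checks) / len(checks)
```

An unedited reference that lacks the motif passes the budget and length checks only, which corresponds to the 2/3 mode described above.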
### What's NOT done (deliberate)
The activity-based objectives (`activity_delta_src`,
`activity_relative_shift`) require an oracle forward pass per rollout,
which is too slow for the hot RL loop. The evaluator-backed scorer
infrastructure in `reward_shaper.py` (the `OUTCOME_SCORERS_EVAL`
hook) is the right place to wire that in for **offline rescoring**
of completed rollouts; we'd then either include it as a `--reward-mode
oracle` flag or report both as a sanity check.
If reviewers push for "RL with the actual headline metric" we can:
1. Add an `OutcomeFnOracle` that takes the oracle as a closure and
computes `activity_delta_src` per rollout.
2. Cache oracle outputs by `(seq, cell_type)` hash to amortise repeat
eval across rollouts.
Estimated cost: ~50 LoC + a slow-but-correct ablation row.
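A sketch of what those two steps might look like, assuming the oracle can be wrapped as a `(seq, cell_type) -> activity` callable and that `gold` carries `reference` and `source_cell_type` keys. Both key names and the `calls` counter are hypothetical; nothing below exists in `reward_shaper.py` yet.

```python
from typing import Callable, Dict, Tuple


class OutcomeFnOracle:
    """Outcome function that closes over an oracle and memoises its outputs."""

    def __init__(self, oracle: Callable[[str, str], float]) -> None:
        self._oracle = oracle
        self._cache: Dict[Tuple[str, str], float] = {}
        self.calls = 0  # number of real oracle forward passes (for inspection)

    def _activity(self, seq: str, cell_type: str) -> float:
        key = (seq, cell_type)
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self._oracle(seq, cell_type)
        return self._cache[key]

    def __call__(self, final: str, gold: dict) -> float:
        """Return activity_delta_src: edited minus reference in the source cell."""
        src = gold["source_cell_type"]  # ASSUMED key name
        ref = gold["reference"]         # ASSUMED key name
        return self._activity(final, src) - self._activity(ref, src)
```

Every rollout in a group shares the same reference, so the reference is scored once per group and only distinct candidates trigger new forward passes.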
## 5. External baselines: how do we compare against prior work?
**TL;DR**: Currently weak. Adding two strong external baselines (TACO
+ HyenaDNA / NT-v2) would harden the paper.
### What we have (internal)
* Zero-shot Qwen3.5-2B (raw + tool-enriched prompts)
* Fusion-SFT (NTv3-650M + LLM + cell context, our architecture)
* Loop-SFT (trajectory-augmented Fusion-SFT)
* SV-GSPO (RL on top of Loop-SFT)
* NTv3-MDLM T1, NTv3-direct T2 (no-LLM baselines, our architecture
but no language model)
### What's missing (external SOTA)
| Model | Task fit | Comparable for | Effort to add | Priority |
|---|---|---|---|---|
| **TACO** (Lin et al. NeurIPS 2024) | T3 native | T3 (paper-precedent) | Medium: repo public; needs DeepSTARR oracle re-fit | **HIGH** |
| **HyenaDNA** (Nguyen et al. NeurIPS 2023) | T1 / T2 | T1 generation, T2 binary classification | Low: already wired as encoder; needs head training only | **HIGH** |
| **DNABERT-2 / NT-v2** | T1 / T2 / T3 | All three (small encoder baseline) | Low: `regureasoner_loop` has NT-v2 wired | MEDIUM |
| **CtrlDNA** | T1 conditional generation | T1 only | Medium: repo public, training data alignment needed | MEDIUM |
| **Evo / Evo2** | Generation, fluency | T1 (but they're 7B+, hard to run on H100) | High: vortex install on lab cluster | LOW |
| **Caduceus** | DNA encoder | Same as NT-v2; redundant | Low | LOW |
| **DeepSTARR** (predictor) | Activity prediction | Used as our oracle, NOT a baseline for our tasks | N/A | N/A |
**Recommended for the minimum-publishable submission**: add TACO (T3
paper precedent) + HyenaDNA (T2 pairing-classification baseline). DNABERT-2 is
nice-to-have. The rest go into the extended-paper version.
The exact recipe:
* TACO: clone their repo, drop in our DeepSTARR-7cell oracle, run
their trainer on our T3 train split. Eval with our
`eval_t3_oracle.py`. Apples-to-apples.
* HyenaDNA: their HF model card already has T2-style heads. Wire as
a `--encoder hyenadna` option in `run_genomefm_benchmark.py` and
retrain the pair head.
Each is ~1 day of work.
## 6. Pipeline state β does any in-flight job need modification?
**No.** Audit:
* **Bench grid (in flight)**: vLLM zero-shot inference. T3 zs eval
uses the heuristic gold's metadata only (target_motif, edit_budget),
so there is no leakage from the v5 framework changes. Safe to let finish.
* **post_bench_pipeline.sh**: already updated with multi-turn RFT
(commit `25504fd`) and Stage 3d post-RFT reasoning (commit `3e65c96`).
Will pick up the changes when it auto-fires.
* **No fusion-SFT job is currently running.** Stages 1-4 fire only
after the bench grid finishes.
The `outcome_enhancer_editing` reward fix lands in
`regureasoner/rl/reward_shaper.py`, which is used only by SV-GSPO (Stage C,
not yet running on H100). Lab-side SV-GSPO runs would need to pull this commit.
## 7. Concrete addition to the experiment suite
These lines belong in [`docs/minimal_publishable_suite.md`](minimal_publishable_suite.md)
once the H100 clears its current backlog:
```bash
# Phase B-T3-RFT-from-joint (new ablation; 6h)
STAGE_4_FINAL=runs/exp_joint_multitask_${STAMP}/final/pytorch_model.bin
pixi run python scripts/rft_t3.py \
    --adapter-state-dict $STAGE_4_FINAL \
    --train-jsonl data/prod_samples/train.enhancer_editing.strat7c.n35k.jsonl \
    --oracle-path runs/exp_oracle_ds_7cell_min/oracle.pt \
    --output-jsonl runs/exp_t3_rft_from_joint_${STAMP}/rft_filtered_train.jsonl \
    --candidates 4 --rounds 4 --temp-ramp 0.15
# Phase B-T3-Loop-SFT-on-RFT (new ablation; 4h)
pixi run python scripts/expand_loop_trajectories.py \
    --source runs/exp_t3_fusion_sft_${STAMP}/rft_filtered_train.jsonl \
    --out data/trajectories/train.enhancer_editing.rft.jsonl
TASK=enhancer_editing \
TRAIN_JSONL=data/trajectories/train.enhancer_editing.rft.jsonl \
... \
bash slurm/run_train_loop_sft.sh
```
## TL;DR for the paper story
* **T3 dataset regen**: train labels, yes via RFT (in the pipeline).
Test labels, no: eval ignores the heuristic gold.
* **T3 benchmark code**: clean. `eval_t3_oracle.py` is the new headline
scorer; old `genqual.json` argmax_acc is informative-only.
* **T3 reward shaper**: fixed in this commit. SV-GSPO will now
optimise the right objective.
* **Loop-SFT**: no code change; just point at the post-RFT JSONL.
* **External baselines**: TACO + HyenaDNA are the two we should add
before submission, at ~1 day each.