
T3: post-v5 follow-ups (deep-think + work plan)

The first-pass v5 work (oracle eval, multi-turn RFT, reasoning expansion, sanitiser) is shipped. This doc captures the deeper questions the T3 rethink raises, what's already been done about them, and what's deliberately deferred.

1. RFT timing: when, relative to T1/T2 + joint training?

TL;DR: Current order (per-task SFT → T3 RFT → joint multitask) is the right minimum. A second RFT pass after the joint adapter is a worthwhile ablation but adds a serial dependency.

The current order

Stage 1   T1 fusion-SFT (heuristic gold)
Stage 2   T2 fusion-SFT
Stage 3   T3 fusion-SFT (heuristic synthetic-edit gold)
Stage 3b  T3 reasoning-only SFT (paper ablation)
Stage 3c  RFT: generate K candidates from Stage-3 adapter,
          oracle-filter, retrain on filtered set
Stage 4   Joint multitask fusion-SFT (T1+T2+T3, 35k each, balanced)
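
A minimal sketch of the Stage 3c filter step (generate K candidates, oracle-filter, keep survivors for retraining). The names `rft_filter`, `generate` and `oracle_margin` are hypothetical, and keeping only the single best positive-margin candidate per row is an assumption; the real `scripts/rft_t3.py` may keep more or use a different threshold.

```python
from typing import Callable, Dict, List


def rft_filter(
    rows: List[Dict],
    generate: Callable[[Dict, int], List[str]],   # hypothetical: (row, K) -> K candidates
    oracle_margin: Callable[[Dict, str], float],  # hypothetical: (row, candidate) -> margin
    k: int = 4,
) -> List[Dict]:
    """Keep, per row, the best candidate whose oracle margin is positive."""
    kept = []
    for row in rows:
        scored = [(oracle_margin(row, c), c) for c in generate(row, k)]
        if not scored:
            continue
        best_margin, best = max(scored, key=lambda mc: mc[0])
        if best_margin > 0:  # assumption: strictly positive margin means "objective met"
            kept.append({**row, "candidate": best, "margin": best_margin})
    return kept


# Toy stand-ins so the sketch runs end to end (not the real generator or oracle).
def toy_generate(row: Dict, k: int) -> List[str]:
    return [row["ref"][:-1] + base for base in "ACG"[:k]]


def toy_margin(row: Dict, candidate: str) -> float:
    return float(candidate.count("G") - row["ref"].count("G"))


rows = [{"id": "a", "ref": "ACGT"}, {"id": "b", "ref": "TTTT"}, {"id": "c", "ref": "GGGG"}]
kept = rft_filter(rows, toy_generate, toy_margin, k=3)
print([(r["id"], r["candidate"]) for r in kept])
```

Row "c" is dropped because no candidate improves on its reference, which is the behaviour that shrinks the retraining set to oracle-validated rows.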

Why RFT can't move earlier

RFT needs a candidate generator that's already on-task, i.e., capable of producing T3-shaped enhancer edits. The base Qwen3.5-2B can't do that without prior T3 SFT (it doesn't know the <enhancer_dna_start> schema, the cell-type grammar, or the edit budget convention). So RFT must come after at least one T3 SFT pass.

Why current order beats "RFT after joint"

The Stage-3 adapter is trained on 35k T3-only rows. The Stage-4 joint adapter sees only 35k T3 rows out of a 105k mix (33% T3 share). For candidate-generation specifically:

  • Stage-3 adapter: more T3-faithful, fewer cross-task style artefacts. Higher per-row keep-rate when scoring against the T3 oracle.
  • Stage-4 adapter: better generalist, slightly weaker T3 grammar.

Empirical question (the ablation to run):

RFT-from-Stage3 (current default) vs RFT-from-Stage4 (joint adapter): does the joint adapter's regularisation produce candidates with higher mean objective margins, or do the format artefacts dominate?

This is a one-flag change: --adapter-state-dict runs/exp_joint_multitask_*/final/pytorch_model.bin instead of the T3-only path. Costs one RFT pass + one re-train. Worth a Table 3 row.

Why a second RFT pass after joint isn't free

Re-RFT-ing the joint adapter would chain:

Stage 4 → Stage 4-RFT → retrain joint → ...

Each step is multi-hour on H100. Diminishing returns are likely (the oracle is the same, the candidate distribution doesn't gain new biology after the second pass). Defer to extended-paper revision.

2. T3-specific oracle?

TL;DR: We use the DeepSTARR-7cell oracle for T3 (and T1). A T3-specific oracle is a paper extension worth doing if reviewers push; the minimum-publishable suite ships with the shared oracle.

What a T3-specific oracle would predict

Two candidates:

  • "Edit-quality" oracle. Input: (reference, edited, edit_distance, cell_type, edit_type). Output: (satisfies_objective, margin). Trained on RFT-filtered (positive) + RFT-rejected (negative) candidates. Self-bootstrap risk: the oracle becomes circular if it only learns what the previous oracle's filter did.

  • "Pairing" oracle (transferred from T2). Input: (promoter, edited_enhancer, cell_type). Output: pairing_score. Trained on observed pairs (T2 dataset). For T3 we'd score (promoter, edited) and reward the model when this score increases over (promoter, reference). This is biology-grounded and non-circular.
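
As a rough illustration of the pairing-oracle reward described above, the signal is the change in pairing score between the edited and the reference enhancer for the same promoter. `pairing_gain` and the `pair_score` callable are hypothetical names; the toy scorer below only stands in for a trained T2 pairing model.

```python
from typing import Callable


def pairing_gain(
    pair_score: Callable[[str, str, str], float],  # hypothetical T2 pairing model
    promoter: str,
    reference: str,
    edited: str,
    cell_type: str,
) -> float:
    """Pairing-score change caused by the edit; positive means the edit helped."""
    return pair_score(promoter, edited, cell_type) - pair_score(promoter, reference, cell_type)


# Toy scorer (GC fraction of the enhancer), used only to make the sketch runnable.
def toy_pair_score(promoter: str, enhancer: str, cell_type: str) -> float:
    return sum(base in "GC" for base in enhancer) / len(enhancer)


gain = pairing_gain(toy_pair_score, "TATA", "AAAA", "AAGC", "K562")
print(gain)
```

Because the reward is a difference against the reference, a constant bias in the pairing model cancels out, which is the same reason the delta-based metrics tolerate a weak oracle.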

Why we ship with DeepSTARR-7cell

The DeepSTARR-7cell oracle gives per-cell-type activity scores. For T3:

  • activity_boost reduces to "did the activity in the source cell go up?": pred_activity_src > ref_activity_src. Direct read from oracle.
  • cell_type_transfer reduces to "did the activity shift toward target?": (pred_tgt - pred_src) - (ref_tgt - ref_src) > 0. Direct read.
  • promoter_retarget reduces to "is the new motif present?": IUPAC scan, no oracle needed.

So all three T3 objectives are computable without a T3-specific oracle: two read directly from the existing oracle and one is a motif scan. The DeepSTARR-7cell oracle is weak in absolute terms (val_pearson_mean=0.136), but the metrics use deltas and shifts, i.e. relative ranking, where weak oracles still carry meaningful signal.
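
The two oracle-backed checks above can be sketched as plain functions over per-cell-type activity dictionaries. The function signatures and the dict layout are assumptions for illustration; only the two inequalities come from the text.

```python
from typing import Mapping, Tuple


def activity_boost(
    pred: Mapping[str, float], ref: Mapping[str, float], src: str
) -> Tuple[bool, float]:
    """Did predicted activity in the source cell go up? Returns (satisfied, margin)."""
    margin = pred[src] - ref[src]
    return margin > 0, margin


def cell_type_transfer(
    pred: Mapping[str, float], ref: Mapping[str, float], src: str, tgt: str
) -> Tuple[bool, float]:
    """Did activity shift toward the target cell relative to the reference?"""
    margin = (pred[tgt] - pred[src]) - (ref[tgt] - ref[src])
    return margin > 0, margin


# Made-up oracle outputs for a reference and an edited sequence.
ref_act = {"K562": 1.0, "HepG2": 0.5}
pred_act = {"K562": 1.25, "HepG2": 1.5}

print(activity_boost(pred_act, ref_act, "K562"))
print(cell_type_transfer(pred_act, ref_act, "K562", "HepG2"))
```

Both checks depend only on differences between oracle outputs, so a systematic offset in the oracle's absolute scale does not change the verdict.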

If the paper review pushes back on oracle weakness, the "T2-pairing-as-T3-oracle" path is the right extension: concrete and publishable.

3. Loop-SFT: does it need T3-aware changes?

TL;DR: No code change. The data source for T3 trajectories should swap to the post-RFT JSONL, but Loop-SFT itself is task-agnostic.

What Loop-SFT does

scripts/train_loop_sft.py consumes JSONL records of the shape:

{"id": ..., "task_type": ...,
 "messages": [system, user],
 "trajectory": {"steps": [{kind, state, text, tool_name, tool_args, tool_result}, ...]}}

The collator renders the trajectory into a single assistant turn the LLM is trained to emit. Task-type doesn't change anything: the collator looks at trajectory steps, not task semantics.
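
To make the task-agnostic point concrete, here is a hedged sketch of rendering trajectory steps into one assistant turn. The tag strings and the "think" and "tool" step kinds are hypothetical (the source only names kind="final"); the actual collator in scripts/train_loop_sft.py may use a different template.

```python
import json
from typing import Dict, List


def render_trajectory(steps: List[Dict]) -> str:
    """Flatten trajectory steps into a single assistant turn (illustrative format)."""
    parts = []
    for step in steps:
        kind = step.get("kind")
        if kind == "tool":  # hypothetical kind name
            args = json.dumps(step.get("tool_args", {}), sort_keys=True)
            parts.append(f"<tool_call>{step['tool_name']}({args})</tool_call>")
            parts.append(f"<tool_result>{step.get('tool_result', '')}</tool_result>")
        elif kind == "final":
            parts.append(f"<final>{step.get('text', '')}</final>")
        else:  # any other step is treated as free-form reasoning text
            parts.append(step.get("text", ""))
    return "\n".join(parts)


demo_steps = [
    {"kind": "think", "text": "Scan for the target motif first."},
    {"kind": "tool", "tool_name": "scan", "tool_args": {"motif": "GATA"}, "tool_result": "hit"},
    {"kind": "final", "text": "ACGT"},
]
print(render_trajectory(demo_steps))
```

The renderer never reads task_type, so swapping the T3 data source changes what the final step contains without touching this code path.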

What changes for T3

The trajectory dataset for T3 is currently expanded from the heuristic-gold T3 JSONL via scripts/expand_loop_trajectories.py. The trajectory's final kind="final" step contains the gold enhancer sequence, which is currently the heuristic synthetic edit, not the oracle-validated RFT candidate.

To align with v5:

# old:
python scripts/expand_loop_trajectories.py \
    --source data/prod_samples/train.enhancer_editing.strat7c.n35k.jsonl \
    --out    data/trajectories/train.enhancer_editing.jsonl

# new (post-RFT-aware):
python scripts/expand_loop_trajectories.py \
    --source runs/exp_t3_fusion_sft_${STAMP}/rft_filtered_train.jsonl \
    --out    data/trajectories/train.enhancer_editing.rft.jsonl

Same script, different --source. The expander reads the assistant's gold answer from messages[-1]["content"], which RFT replaced with the candidate. So the trajectory's final step inherits the candidate. No code change.
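
A stripped-down stand-in for the inheritance described above: the final step takes its text from messages[-1]["content"], so whatever RFT wrote there flows through. `expand_record` is a hypothetical name, and the real expander also emits intermediate steps that are omitted here.

```python
from typing import Dict


def expand_record(record: Dict) -> Dict:
    """Move the assistant's gold answer into a trajectory 'final' step."""
    *prompt, last = record["messages"]
    if last.get("role") != "assistant":
        raise ValueError("expected the last message to be the assistant's gold answer")
    return {
        "id": record.get("id"),
        "task_type": record.get("task_type"),
        "messages": prompt,  # system + user only
        "trajectory": {"steps": [{"kind": "final", "text": last["content"]}]},
    }


rft_row = {
    "id": "t3-0001",
    "task_type": "enhancer_editing",
    "messages": [
        {"role": "system", "content": "You edit enhancers."},
        {"role": "user", "content": "Boost activity in K562."},
        {"role": "assistant", "content": "ACGTGATA"},  # RFT candidate, not heuristic gold
    ],
}
expanded = expand_record(rft_row)
print(expanded["trajectory"]["steps"][-1]["text"])
```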

For the paper, we report Loop-SFT trained on:

  • Heuristic-gold trajectories (matches the current pipeline default)
  • Post-RFT trajectories (new; aligned with v5 evaluation)

These give two T3 Loop-SFT rows in Table 1, a clean ablation.

4. SV-GSPO: does it need T3-aware changes?

TL;DR: Yes, and it's now done. Before this commit, the T3 outcome-reward function in regureasoner/rl/reward_shaper.py rewarded the wrong objective (a sequence-distance window). Fixed.

What was wrong

The previous outcome_enhancer_editing(final, gold, min_edit=1, max_edit=60) returned 1.0 for any edit distance in [1, 60] and 0 for identity. That tells the agent "make a moderately-sized edit", which is exactly the wrong signal under the v5 framework, where the metric of record is objective satisfaction, not an edit-distance window.

What's changed

outcome_enhancer_editing(final, gold) now returns the average of three binary checks aligned with scripts/eval_t3_oracle.py:

  • within_budget: Hamming ≤ gold["edit_budget"] (or a 5%-of-len(ref) fallback when no budget is given).
  • length_preserved: len(pred) == len(ref).
  • target_motif_present: IUPAC regex (fwd + revcomp) for gold["target_motif"].
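
A minimal sketch of the three checks, assuming the reference sequence lives under a hypothetical gold["reference"] key. The handling of unequal lengths in the Hamming count, the rounding of the 5% fallback, and the behaviour when no target motif is given are all assumptions; the function in reward_shaper.py may differ.

```python
import re
from typing import Dict

IUPAC = {
    "A": "A", "C": "C", "G": "G", "T": "T",
    "R": "[AG]", "Y": "[CT]", "S": "[CG]", "W": "[AT]", "K": "[GT]", "M": "[AC]",
    "B": "[CGT]", "D": "[AGT]", "H": "[ACT]", "V": "[ACG]", "N": "[ACGT]",
}
_COMPLEMENT = str.maketrans("ACGTRYSWKMBDHVN", "TGCAYRSWMKVHDBN")


def revcomp(seq: str) -> str:
    return seq.translate(_COMPLEMENT)[::-1]


def motif_present(seq: str, motif: str) -> bool:
    """IUPAC motif scan on the forward strand and the reverse complement."""
    pattern = "".join(IUPAC[base] for base in motif.upper())
    seq = seq.upper()
    return bool(re.search(pattern, seq) or re.search(pattern, revcomp(seq)))


def hamming(a: str, b: str) -> int:
    # Assumption: any length overhang counts as mismatches.
    return sum(x != y for x, y in zip(a, b)) + abs(len(a) - len(b))


def outcome_enhancer_editing_sketch(final: str, gold: Dict) -> float:
    ref = gold["reference"]  # hypothetical key name
    budget = gold.get("edit_budget")
    if budget is None:
        budget = max(1, len(ref) // 20)  # 5% of len(ref); rounding is an assumption
    within_budget = hamming(final, ref) <= budget
    length_preserved = len(final) == len(ref)
    motif = gold.get("target_motif")
    target_motif_present = bool(motif) and motif_present(final, motif)
    return (within_budget + length_preserved + target_motif_present) / 3


gold = {"reference": "A" * 40, "edit_budget": 4, "target_motif": "GATA"}
print(outcome_enhancer_editing_sketch("GATA" + "A" * 36, gold))  # all three checks pass
print(outcome_enhancer_editing_sketch("A" * 40, gold))           # motif missing
```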

Score ∈ {0, 1/3, 2/3, 1}. SV-GSPO group-normalisation handles the discreteness; rollouts are expected to land between 0 and 1 with most mass at 2/3 (budget + length usually pass; motif is the bottleneck).
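
For intuition on why a four-valued score is workable, here is a generic group-relative normalisation of the kind used by GRPO-style methods. This is a sketch of the general technique, not the SV-GSPO implementation; `group_normalise` and the `eps` value are assumptions.

```python
from statistics import fmean, pstdev
from typing import List, Sequence


def group_normalise(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Centre and scale rewards within one rollout group."""
    mu = fmean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


mixed = group_normalise([0.0, 1 / 3, 2 / 3, 1.0])
flat = group_normalise([2 / 3] * 4)
print(mixed)
print(flat)
```

A group where every rollout scores 2/3 yields zero advantage for all members, so only groups in which the motif check splits the rollouts contribute a learning signal.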

What's NOT done (deliberate)

The activity-based objectives (activity_delta_src, activity_relative_shift) require an oracle forward-pass per rollout, which is too slow for the hot RL loop. The evaluator-backed scorer infrastructure in reward_shaper.py (the OUTCOME_SCORERS_EVAL hook) is the right place to wire that in for offline rescoring of completed rollouts; we'd then either include it as a --reward-mode oracle flag or report both as a sanity check.

If reviewers push for "RL with the actual headline metric" we can:

  1. Add an OutcomeFnOracle that takes the oracle as a closure and computes activity_delta_src per rollout.
  2. Cache oracle outputs by (seq, cell_type) hash to amortise repeat eval across rollouts.

Estimated cost: ~50 LoC + a slow-but-correct ablation row.
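
The two steps above might look roughly like this: a cache keyed by a hash of (seq, cell_type) wrapped around the oracle, and an outcome function that closes over it. `CachedOracle`, `make_outcome_fn_oracle`, and the gold keys "reference" and "source_cell_type" are hypothetical names; the toy predictor only stands in for a real oracle forward pass.

```python
import hashlib
from typing import Callable, Dict


class CachedOracle:
    """Memoise oracle predictions by a hash of (seq, cell_type)."""

    def __init__(self, predict: Callable[[str, str], float]):
        self._predict = predict
        self._cache: Dict[str, float] = {}
        self.misses = 0  # number of real forward passes

    @staticmethod
    def _key(seq: str, cell_type: str) -> str:
        return hashlib.sha1(f"{cell_type}\x00{seq}".encode()).hexdigest()

    def __call__(self, seq: str, cell_type: str) -> float:
        key = self._key(seq, cell_type)
        if key not in self._cache:
            self._cache[key] = self._predict(seq, cell_type)
            self.misses += 1
        return self._cache[key]


def make_outcome_fn_oracle(oracle: CachedOracle) -> Callable[[str, Dict], float]:
    """Build an outcome function that computes activity_delta_src per rollout."""

    def outcome(final: str, gold: Dict) -> float:
        src = gold["source_cell_type"]  # hypothetical key
        return oracle(final, src) - oracle(gold["reference"], src)  # hypothetical key

    return outcome


# Toy predictor (GC fraction) so the sketch runs without a model.
def toy_predict(seq: str, cell_type: str) -> float:
    return sum(base in "GC" for base in seq) / len(seq)


oracle = CachedOracle(toy_predict)
outcome = make_outcome_fn_oracle(oracle)
gold = {"reference": "AAAA", "source_cell_type": "K562"}
deltas = [outcome(seq, gold) for seq in ["GGAA", "GGGG", "GGAA"]]
print(deltas, oracle.misses)
```

The reference is scored once per cell type and repeated rollouts are free, which is where most of the saving comes from when many rollouts share a prompt.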

5. External baselines: how do we compare against prior work?

TL;DR: Currently weak. Adding two strong external baselines (TACO + HyenaDNA / NT-v2) would harden the paper.

What we have (internal)

  • Zero-shot Qwen3.5-2B (raw + tool-enriched prompts)
  • Fusion-SFT (NTv3-650M + LLM + cell context, our architecture)
  • Loop-SFT (trajectory-augmented Fusion-SFT)
  • SV-GSPO (RL on top of Loop-SFT)
  • NTv3-MDLM T1, NTv3-direct T2 (no-LLM baselines, our architecture but no language model)

What's missing (external SOTA)

| Model | Task fit | Comparable for | Effort to add | Priority |
|---|---|---|---|---|
| TACO (Lin et al. NeurIPS 2024) | T3 native | T3 (paper-precedent) | Medium: repo public; needs DeepSTARR oracle re-fit | HIGH |
| HyenaDNA (Nguyen et al. NeurIPS 2023) | T1 / T2 | T1 generation, T2 binary classification | Low: already wired as encoder; needs head training only | HIGH |
| DNABERT-2 / NT-v2 | T1 / T2 / T3 | All three (small encoder baseline) | Low: regureasoner_loop has NT-v2 wired | MEDIUM |
| CtrlDNA | T1 conditional generation | T1 only | Medium: repo public, training data alignment needed | MEDIUM |
| Evo / Evo2 | Generation, fluency | T1 (but they're 7B+, hard to run on H100) | High: vortex install on lab cluster | LOW |
| Caduceus | DNA encoder | Same as NT-v2; redundant | Low | LOW |
| DeepSTARR (predictor) | Activity prediction | Used as our oracle, NOT a baseline for our tasks | N/A | N/A |

Recommended for the minimum-publishable submission: add TACO (T3 paper precedent) + HyenaDNA (T2 fluency baseline). DNABERT-2 is nice-to-have. The rest go into the extended-paper version.

The exact recipe:

  • TACO: clone their repo, drop in our DeepSTARR-7cell oracle, run their trainer on our T3 train split. Eval with our eval_t3_oracle.py. Apples-to-apples.
  • HyenaDNA: their HF model card already has T2-style heads. Wire as a --encoder hyenadna option in run_genomefm_benchmark.py and retrain the pair head.

Both ~1 day of work each.

6. Pipeline state: does any in-flight job need modification?

No. Audit:

  • Bench grid (in flight): vLLM zero-shot inference. T3 zs eval uses the heuristic gold's metadata only (target_motif, edit_budget), so there is no leakage from the v5 framework changes. Safe to let finish.
  • post_bench_pipeline.sh β€” already updated with multi-turn RFT (commit 25504fd) and Stage 3d post-RFT reasoning (commit 3e65c96). Will pick up the changes when it auto-fires.
  • No fusion-SFT job is currently running. Stages 1–4 fire only after the bench grid finishes.

The outcome_enhancer_editing reward fix lands in regureasoner/rl/reward_shaper.py, which is used only by SV-GSPO (Stage C, not yet running on H100). Lab-side SV-GSPO runs would need to pull this commit.

7. Concrete addition to the experiment suite

These lines belong in docs/minimal_publishable_suite.md once the H100 clears its current backlog:

# Phase B-T3-RFT-from-joint  (new ablation; 6h)
STAGE_4_FINAL=runs/exp_joint_multitask_${STAMP}/final/pytorch_model.bin
pixi run python scripts/rft_t3.py \
    --adapter-state-dict $STAGE_4_FINAL \
    --train-jsonl data/prod_samples/train.enhancer_editing.strat7c.n35k.jsonl \
    --oracle-path runs/exp_oracle_ds_7cell_min/oracle.pt \
    --output-jsonl runs/exp_t3_rft_from_joint_${STAMP}/rft_filtered_train.jsonl \
    --candidates 4 --rounds 4 --temp-ramp 0.15

# Phase B-T3-Loop-SFT-on-RFT  (new ablation; 4h)
pixi run python scripts/expand_loop_trajectories.py \
    --source runs/exp_t3_fusion_sft_${STAMP}/rft_filtered_train.jsonl \
    --out    data/trajectories/train.enhancer_editing.rft.jsonl
TASK=enhancer_editing \
TRAIN_JSONL=data/trajectories/train.enhancer_editing.rft.jsonl \
... \
bash slurm/run_train_loop_sft.sh

TL;DR for the paper story

  • T3 dataset regen: train labels, yes via RFT (in the pipeline). Test labels, no: eval ignores the heuristic gold.
  • T3 benchmark code: clean. eval_t3_oracle.py is the new headline scorer; old genqual.json argmax_acc is informative-only.
  • T3 reward shaper: fixed in this commit. SV-GSPO will now optimise the right objective.
  • Loop-SFT: no code change; just point at the post-RFT JSONL.
  • External baselines: TACO + HyenaDNA are the two we should add before submission. Both ~1 day each.