
Experiment chain: unified-MM LLM (paper-grade, v5)

Single document that ties together every .py and .sh we run from zero-shot bench to final SV-GSPO checkpoint. v4 (in experiment_chain.md) covered the per-task LLM progression. v5 adds the unified-multimodal stack and the post-bench training pipeline that auto-fires after the bench grid finishes.

Run order is the same as the order of stages in /dev/shm/dnathinker/post_bench_pipeline.sh: the H100 just reads that script top-to-bottom, no SLURM dependencies.

0. Bench grid (zero-shot baselines)

| Stage | Script | Output | Purpose |
|---|---|---|---|
| ZS-T1 raw | scripts/run_llm_benchmark_vllm.py --task enhancer_generation --prompt raw | runs/exp_t1_grid_*/zs_raw/{predictions,metrics}.json{,l} | Paper Table 1 row 1 |
| ZS-T1 enriched | same w/ --prompt enriched | runs/exp_t1_grid_*/zs_enriched/... | Table 1 row 2 |
| ZS-T2 raw | --task pair_prediction --prompt raw | runs/exp_t2_grid_*/zs_raw/... | Table 1 row 1 (T2) |
| ZS-T2 enriched | same enriched | runs/exp_t2_grid_*/zs_enriched/... | Table 1 row 2 (T2) |
| ZS-T3 raw / enriched | --task enhancer_editing × {raw, enriched} | runs/exp_t3_grid_*/... | Table 1 rows 1–2 (T3) |

Driver: /dev/shm/dnathinker/launch_bench_vllm.sh runs the 6 vLLM benches sequentially. When the orchestrator PID exits, an attached watcher fires post_bench_pipeline.sh.

1. Post-bench pipeline (auto-triggered)

The pipeline lives at /dev/shm/dnathinker/post_bench_pipeline.sh. Each stage checks for its own output file and skips itself when that file already exists, so re-runs are idempotent.
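The skip-check pattern can be sketched in a few lines. This is a minimal illustration, not the contents of post_bench_pipeline.sh: the helper name run_stage and its signature are hypothetical, and the only behaviour taken from the text is "skip the stage if its output file already exists".

```python
from pathlib import Path
from typing import Callable


def run_stage(name: str, output: Path, action: Callable[[], None]) -> str:
    """Run `action` unless `output` already exists (hypothetical helper)."""
    if output.exists():
        # A previous run already produced this stage's output: do nothing.
        return f"{name}: skipped"
    action()
    if not output.exists():
        # Guard against a stage that exits cleanly without writing its output,
        # which would otherwise be re-run forever or silently missing.
        raise RuntimeError(f"{name}: finished without writing {output}")
    return f"{name}: ran"
```

Because the check is keyed on the output file rather than on a separate "done" marker, deleting one output is enough to force only that stage to re-run.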

Stage 0: ZS scoring (early HF push)

  • scripts/run_generation_eval.py → genqual.json (FBD / spec / argmax-acc / per-cell-type) for T1+T3 zs_raw / zs_enriched.
  • scripts/eval_t3_oracle.py → genqual_t3_oracle.json (within-budget, length-preserved, objective-success per edit_type, per-cell-type) on T3 zs predictions.
  • HF push of the partial bench results so the lab can see numbers before the training stages finish.

Stages 1–4: Fusion-SFT family (the headline)

Each run_fusion call invokes scripts/train_fusion_sft.py with --architecture-mode llava, then Stage Nb invokes scripts/predict_fusion.py on the trained adapter to get predictions on the full test set, followed by run_generation_eval.py (T1/T3) and eval_t3_oracle.py (T3 only). These produce the lora_raw / lora_enriched rows in Table 1.

| Stage | Train script call | Inference + scoring | Paper row |
|---|---|---|---|
| 1 | T1 fusion-SFT (n35k T1) | score_adapter T1 ... raw / enriched | T1 row 4 |
| 2 | T2 fusion-SFT (n35k T2 balanced) | score_adapter T2 ... raw / enriched | T2 row 4 |
| 3 | T3 fusion-SFT (n35k T3, heuristic gold) | score_adapter T3 ... raw / enriched | T3 row 4a |
| 3b | T3 reasoning-only SFT (--mask-assistant-dna-span) | same | T3 row 4b, paper ablation |
| 3c | T3 RFT (Stage A → K candidates → oracle-filter → re-SFT) | same | T3 row 4c, paper ablation |
| 4 | Joint multitask fusion-SFT (105k = 35k×3 balanced) | score_adapter × {T1,T2,T3} × {raw,enriched} | headline row: one model, three tasks |

score_adapter is defined inside post_bench_pipeline.sh. It exists because run_llm_benchmark.py --adapter-dir expects PEFT format (adapter_model.bin + adapter_config.json), whereas our FusionSFTTrainer saves a full OneShotFusionLM state_dict (LLM + LoRA + NTv3 projector + cell context encoder) via torch.save. predict_fusion.py rebuilds the model, loads that state_dict with load_state_dict, then runs model.llm.generate with the same prompt builder and parser that ZeroShotLLM.predict uses, so predictions.jsonl is shape-compatible with the genqual and T3-oracle scorers. This is the single bridge between training output and the eval pipeline.
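The format mismatch described above can be made concrete with a small routing check. This is a sketch under assumptions: the PEFT file names come from the text, pytorch_model.bin is assumed to be the file FusionSFTTrainer writes (it appears in the exp_t3_fusion_sft_*/final/ path later in this document), and the function names are hypothetical rather than part of the pipeline.

```python
from pathlib import Path

# Files that run_llm_benchmark.py --adapter-dir expects, per the text above.
PEFT_FILES = ("adapter_model.bin", "adapter_config.json")
# Assumed file name for the full state_dict saved via torch.save.
FULL_STATE_FILE = "pytorch_model.bin"


def checkpoint_format(ckpt_dir: Path) -> str:
    """Classify a checkpoint directory as 'peft', 'full_state_dict' or 'unknown'."""
    if all((ckpt_dir / name).is_file() for name in PEFT_FILES):
        return "peft"
    if (ckpt_dir / FULL_STATE_FILE).is_file():
        return "full_state_dict"
    return "unknown"


def scorer_for(ckpt_dir: Path) -> str:
    """Pick the inference entry point that can actually load this checkpoint."""
    fmt = checkpoint_format(ckpt_dir)
    if fmt == "peft":
        return "run_llm_benchmark.py --adapter-dir"
    if fmt == "full_state_dict":
        return "predict_fusion.py"
    raise ValueError(f"unrecognised checkpoint layout in {ckpt_dir}")
```

A check like this fails fast when a fusion checkpoint is handed to the PEFT loader, instead of surfacing later as a missing-key error.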

Stages 5–6: NTv3-only baselines

  • Stage 5: scripts/train_generation.py --head mdlm (NTv3-MDLM on T1).
  • Stage 6: scripts/train_ntv3_direct.py (NTv3-direct on T2).
  • These are the "no LLM" rows in Table 1, included to show that the LLM contributes signal.

Stage 7: Aggregator + final HF push

  • aggregate_results.py walks runs/, collapses (task, mode, prompt) and writes /dev/shm/dnathinker/results/h100_snapshot.md.
  • HF push of metrics + genqual + h100_snapshot.md.
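The walk-and-collapse step might look roughly like the sketch below. It is not aggregate_results.py: the directory pattern is inferred from the bench-grid output paths above, "collapse" is assumed to mean "keep one row per (task, mode, prompt), with later run stamps overwriting earlier ones", and the function names and markdown layout are hypothetical.

```python
import json
import re
from pathlib import Path
from typing import Dict, Tuple

Key = Tuple[str, str, str]  # (task, mode, prompt)

# Assumed layout: runs/exp_t1_grid_<stamp>/zs_raw/metrics.json
RUN_RE = re.compile(r"exp_(t\d)_grid_[^/]*/([a-z]+)_(raw|enriched)/metrics\.json$")


def collect(runs_dir: Path) -> Dict[Key, dict]:
    """Walk runs_dir and keep one metrics dict per (task, mode, prompt)."""
    rows: Dict[Key, dict] = {}
    # Sorted order means a lexicographically later stamp overwrites an earlier one.
    for path in sorted(runs_dir.glob("exp_*/*/metrics.json")):
        match = RUN_RE.search(path.as_posix())
        if match is None:
            continue
        task, mode, prompt = match.groups()
        rows[(task, mode, prompt)] = json.loads(path.read_text())
    return rows


def to_markdown(rows: Dict[Key, dict]) -> str:
    """Render the collapsed rows as a small markdown table."""
    lines = ["| task | mode | prompt | metrics |", "|---|---|---|---|"]
    for (task, mode, prompt), metrics in sorted(rows.items()):
        cells = ", ".join(f"{k}={v}" for k, v in sorted(metrics.items()))
        lines.append(f"| {task} | {mode} | {prompt} | {cells} |")
    return "\n".join(lines) + "\n"
```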

2. Where Loop-SFT fits

Loop-SFT (scripts/train_loop_sft.py) is not redundant with RFT. The two filter on different signals:

  • RFT (Stage 3c): filter by output objective. Generate K candidates and keep the ones whose DNA sequence satisfies budget + motif + activity-shift via the oracle. Improves the final answer.
  • Loop-SFT: filter by trajectory. Keep traces whose intermediate tool calls and reasoning chain are correct. Improves the reasoning chain that leads to the answer.
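To make the distinction concrete, here is a toy sketch of the two filters. Everything in it is illustrative: the Candidate record, the per-step correctness flags, and the within_budget stand-in (which checks only an edit budget, not motif or activity shift) are assumptions, not the project's oracle or trace format.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Candidate:
    sequence: str                                        # final DNA answer
    steps_ok: List[bool] = field(default_factory=list)   # per-step trace verdicts


def rft_filter(cands: List[Candidate], oracle_ok: Callable[[str], bool]) -> List[Candidate]:
    """Outcome filter: keep candidates whose final sequence passes the oracle."""
    return [c for c in cands if oracle_ok(c.sequence)]


def loop_sft_filter(cands: List[Candidate]) -> List[Candidate]:
    """Trajectory filter: keep candidates whose every intermediate step is correct."""
    return [c for c in cands if c.steps_ok and all(c.steps_ok)]


def within_budget(reference: str, max_edits: int) -> Callable[[str], bool]:
    """Toy stand-in for the oracle: same length and at most max_edits substitutions."""
    def check(seq: str) -> bool:
        if len(seq) != len(reference):
            return False
        return sum(a != b for a, b in zip(seq, reference)) <= max_edits
    return check
```

A candidate can pass one filter and fail the other: a correct trace can end in an over-budget sequence, and an in-budget sequence can come out of a flawed trace, which is why the two stages are complementary.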

The full T3 stack the paper aims for:

Fusion-SFT (heuristic)  →  Loop-SFT (trajectory-filtered)  →  RFT (oracle-filtered)  →  SV-GSPO (RL)
       Stage A                        Stage A'                       Stage B              Stage C

Stage A' (Loop-SFT) is deferred to a follow-up run because the trajectory-trace dataset (16K v9 in t3_evaluation_design.md §10) is the lab's, not the H100's. The H100 ships:

  • Stage A (the three run_fusion calls)
  • Stage A's reasoning-only ablation (3b), equivalent to a cold-start Loop-SFT with no traces; this ablation shows that losing the heuristic DNA target doesn't tank the model
  • Stage B (RFT, 3c)

When the lab finishes Loop-SFT on its side, the chain re-merges: both teams point at the same exp_t3_fusion_sft_*/final/pytorch_model.bin, the lab adds Loop-SFT on top, the H100 adds RFT on top, and we pick whichever path scores higher on eval_t3_oracle.py for the paper.

3. Job map (current state, 2026-04-27 UTC)

H100 NVL
├── PID 100474  launch_bench_vllm.sh (orchestrator)
│   └── PID 121129  vLLM bench T2 zs_enriched (in flight)
│       queued: T3 zs_raw, T3 zs_enriched
└── PID 100544  watcher → post_bench_pipeline.sh (idle until 100474 exits)
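The watcher's behaviour (idle until the orchestrator PID exits, then fire the pipeline) can be sketched as a polling loop. This is an illustration, not the actual watcher: the function names, the 30-second poll interval, and the injectable is_alive / sleep hooks are assumptions added to keep the sketch testable. Sending signal 0 with os.kill is a standard POSIX way to probe whether a PID exists without affecting it.

```python
import os
import time
from typing import Callable


def pid_alive(pid: int) -> bool:
    """Return True if a process with this PID exists (POSIX signal-0 probe)."""
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    return True


def watch_then_fire(
    pid: int,
    on_exit: Callable[[], None],
    is_alive: Callable[[int], bool] = pid_alive,
    poll_seconds: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Poll until `pid` is gone, then call `on_exit`; return the number of polls."""
    polls = 0
    while is_alive(pid):
        polls += 1
        sleep(poll_seconds)
    on_exit()
    return polls
```

In practice on_exit would launch post_bench_pipeline.sh, for example via subprocess.run. Note that a signal-0 probe still reports an exited but unreaped child as alive, so this approach suits a watcher that is not the orchestrator's parent.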

ETAs (rough, post-T2 enriched completion):

  • T3 raw + T3 enriched bench: ~5h each (10h total)
  • Stage 0 + 0c (genqual + T3 oracle on zs preds): ~30 min
  • Stages 1–3 fusion-SFT (3 × 35k × 1 epoch on H100 NVL): ~6–8h total
  • Stage 3b reasoning-only: ~3h
  • Stage 3c RFT generate + filter + re-SFT: ~5h
  • Stage 4 joint multitask 105k: ~10h
  • Stages 5–6 NTv3-only: ~2h each
  • Stage 7 aggregator + HF push: minutes

The ETAs above sum to ≈ 40 H100-hours: about 10h of remaining T3 bench plus roughly 30h of post-bench stages. Progress is logged to runs/post_bench_pipeline.log; tail -f it to check liveness.
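As a quick arithmetic check, the listed ETAs can be summed as (low, high) ranges. The figures are copied from the list above; Stage 7 is left out because it is quoted only as "minutes", and the two NTv3-only stages are entered as 2h each.

```python
# (low, high) ETA in hours for each item in the list above.
ETAS_HOURS = {
    "T3 raw + enriched bench": (10.0, 10.0),
    "Stage 0 + 0c scoring": (0.5, 0.5),
    "Stages 1-3 fusion-SFT": (6.0, 8.0),
    "Stage 3b reasoning-only": (3.0, 3.0),
    "Stage 3c RFT": (5.0, 5.0),
    "Stage 4 joint multitask": (10.0, 10.0),
    "Stages 5-6 NTv3-only": (4.0, 4.0),
}


def total_hours(etas, skip=()):
    """Sum the low and high bounds, optionally skipping named items."""
    kept = [rng for name, rng in etas.items() if name not in skip]
    return sum(lo for lo, _ in kept), sum(hi for _, hi in kept)


all_low, all_high = total_hours(ETAS_HOURS)
post_low, post_high = total_hours(ETAS_HOURS, skip=("T3 raw + enriched bench",))
print(f"including T3 bench: {all_low}-{all_high} h")
print(f"post-bench stages only: {post_low}-{post_high} h")
```

The all-inclusive range (38.5 to 40.5 hours) is the one that matches the roughly 40-hour figure; the post-bench stages alone come to about 30 hours.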

4. Paper-table → script map (cheat sheet)

| Table 1 row | Numbers come from | Per-cell breakdown? |
|---|---|---|
| Row 1 (zs_raw) | runs/exp_t{1,2,3}_grid_*/zs_raw/genqual/genqual.json | yes |
| Row 2 (zs_enriched) | .../zs_enriched/genqual/genqual.json | yes |
| Row 3 (LoRA, no NTv3) | DEFERRED, not in current pipeline | |
| Row 4 (Fusion-SFT, per-task) | runs/exp_t{1,2,3}_fusion_sft_*/predict_t{1,2,3}_{raw,enriched}/genqual/*.json | yes (T1/T3); T2 has no per-cell breakdown because pair_prediction is binary |
| Row 4b (T3 reasoning-only) | runs/exp_t3_fusion_sft_reasonly_*/predict_t3_*/genqual/... | yes |
| Row 4c (T3 RFT) | runs/exp_t3_fusion_sft_rft_*/predict_t3_*/genqual/... | yes |
| Headline (joint multitask) | runs/exp_joint_multitask_*/predict_t{1,2,3}_*/genqual/... | yes |
| Row 5 (Loop-SFT) | lab side, slurm | |
| Row 6 (SV-GSPO) | lab side, slurm | |

The T3-specific paper section uses the objective-satisfaction metrics from eval_t3_oracle.py (within_budget, length_preserved, objective_success_*, transfer_specificity, in_budget_at_{5,10,20}pct), not the heuristic-overlap genqual ones; see t3_evaluation_design.md §2 for why.

5. Reasoning-trace augmentation (OpenRouter / Nemotron, free)

scripts/build_reasoning_traces.py rewrites the assistant turn in any T1/T2/T3 SFT JSONL to include a single-shot rationale that wires the enriched evidence (TFBS scan, expression context, motif hits) to the gold answer. Output schema matches the parent project's existing pe_dataset_reasoning_expansion_*/jsonl/ files exactly:

<reasoning_start>RATIONALE</reasoning_end>
<enhancer_dna_start>SEQ</enhancer_dna_end>     # T1/T3
<pair_label>paired|not_paired</pair_label>     # T2
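A reader for this schema could look like the following sketch. The tag names are taken verbatim from the schema above (including the asymmetric _start / _end pairs); the ParsedTurn type, the function name, and the restriction of sequences to the A/C/G/T/N alphabet are assumptions for illustration, not the project's parser.

```python
import re
from typing import NamedTuple, Optional


class ParsedTurn(NamedTuple):
    rationale: Optional[str]
    sequence: Optional[str]    # present for T1/T3
    pair_label: Optional[str]  # present for T2


_RATIONALE = re.compile(r"<reasoning_start>(.*?)</reasoning_end>", re.DOTALL)
# Assumed alphabet: A/C/G/T plus N, upper or lower case.
_SEQ = re.compile(r"<enhancer_dna_start>\s*([ACGTNacgtn]+)\s*</enhancer_dna_end>")
_LABEL = re.compile(r"<pair_label>\s*(paired|not_paired)\s*</pair_label>")


def parse_assistant_turn(text: str) -> ParsedTurn:
    """Extract rationale, DNA sequence and pair label; missing parts are None."""
    def first(pattern: "re.Pattern[str]") -> Optional[str]:
        match = pattern.search(text)
        return match.group(1).strip() if match else None

    return ParsedTurn(first(_RATIONALE), first(_SEQ), first(_LABEL))
```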

Reuses regureasoner.loop.openrouter.OpenRouterClient (the same retry + backoff client that expand_loop_trajectories.py uses). There is a single API call per row: the teacher only writes the justification, not the answer, so small free-tier models (default nvidia/nemotron-nano-9b-v2:free; switch to nvidia/llama-3.1-nemotron-70b-instruct:free for richer rationales) stay reliable.

Resumable: appends to the output JSONL; on startup it scans every id already present and skips those rows in the source. Daily reruns accumulate without overlap.
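The resume logic described above amounts to "read the ids already written, skip them, append the rest". A minimal sketch, assuming each JSONL row carries an "id" field as the text implies; the function names are hypothetical.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, Iterator, Set


def seen_ids(out_path: Path) -> Set[str]:
    """Collect every id already present in the output JSONL (empty if no file)."""
    if not out_path.exists():
        return set()
    ids: Set[str] = set()
    with out_path.open() as fh:
        for line in fh:
            line = line.strip()
            if line:
                ids.add(json.loads(line)["id"])
    return ids


def pending_rows(source_rows: Iterable[Dict], out_path: Path) -> Iterator[Dict]:
    """Yield only the source rows whose id is not in the output yet."""
    done = seen_ids(out_path)
    for row in source_rows:
        if row["id"] not in done:
            yield row


def append_row(out_path: Path, row: Dict) -> None:
    """Append one finished row; earlier rows are never rewritten."""
    with out_path.open("a") as fh:
        fh.write(json.dumps(row) + "\n")
```

Keying on the row id rather than on line position is what lets the source file change between runs without producing duplicates.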

Budget: --max-requests (default 1000) is the per-invocation cap. OpenRouter free tier = 1000 req/day per key. Multiple keys can shard line-level via --shard-index/--num-shards.
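One plausible reading of line-level sharding with a per-invocation cap is sketched below. The modulo assignment (line i goes to shard i mod num_shards) is an assumption; the text only states that --shard-index/--num-shards shard at line level and that --max-requests caps one invocation. The function name is hypothetical.

```python
from typing import Iterable, Iterator, Tuple


def shard_lines(
    lines: Iterable[str],
    shard_index: int,
    num_shards: int,
    max_requests: int = 1000,
) -> Iterator[Tuple[int, str]]:
    """Yield (line_no, line) for this shard, stopping at the request cap."""
    if not 0 <= shard_index < num_shards:
        raise ValueError("shard_index must be in [0, num_shards)")
    taken = 0
    for line_no, line in enumerate(lines):
        if line_no % num_shards != shard_index:
            continue  # this line belongs to another key's shard
        if taken >= max_requests:
            break     # per-invocation budget exhausted
        taken += 1
        yield line_no, line
```

Under this scheme the shards are disjoint and together cover every line, so several keys can work on the same source file without coordinating.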

Daily-loop launcher: slurm/build_reasoning_traces_loop.sh sources OPENROUTER_API_KEY from /dev/shm/dnathinker/.env, walks T1/T2/T3 with PER_TASK=333 each (≈1000/day total), and with --daemon optionally goes into a 24h sleep loop. Zero GPU; runs alongside any training stage.

SFT integration: when ≥N augmented rows accumulate per task, point scripts/train_fusion_sft.py --train-jsonl at /dev/shm/dnathinker/data/reasoning_traces/train.<task>.reasoning.jsonl. Same collator, same trainer; the only difference is that the assistant turn now starts with <reasoning_start>...</reasoning_end>, so the trained model emits an explicit rationale plus the answer at inference time. This is the paper's "reasoning model" row in T3's table; the non-reasoning fusion-SFT runs (Stages 1–3) stay as the no-rationale comparison.

Per-task source JSONL (what the teacher justifies):

| Task | Source JSONL | Why |
|---|---|---|
| T1 | train.enhancer_generation.strat7c.n35k.jsonl (heuristic gold) | The heuristic gold is the empirical paired enhancer; teacher justifies why it pairs in this cell type. |
| T2 | train.pair_prediction.strat7c.n35k.jsonl (observed positive + pseudo-negative) | Teacher justifies the binary label using shared-TFBS / GC / expression evidence. |
| T3 | post-RFT train.t3_rft.jsonl | The heuristic gold for T3 is a synthetic motif-implant (not unique GT; see t3_evaluation_design.md §1). RFT (Stage 3c) replaces it with an oracle-validated candidate. Reasoning expansion must run on the post-RFT JSONL so the rationale justifies a sequence the oracle has actually scored, not the heuristic. Order: Fusion-SFT → RFT → reasoning expansion → reasoning-augmented Fusion-SFT. |

The launcher slurm/build_reasoning_traces_loop.sh defaults to the heuristic-gold JSONLs for all three tasks. For T3, override T3_SRC=/dev/shm/dnathinker/runs/exp_t3_fusion_sft_rft_${STAMP}/.../train.t3_rft.jsonl once Stage 3c finishes; the loop's resume logic handles a mid-run source swap because the augmented output JSONL keeps row ids.

6. Input sanitisation: applied globally before any model sees text

regureasoner/utils/input_sanitize.py (used by PromptBuilder.user() and build_reasoning_traces._format_user) strips three classes of issue at read-time, so we don't need to regenerate the prod JSONLs:

  1. Label leaks: peak_name=chr…, enhancer_peak_name=chr…, the "Peak coordinates parsed to chr…:…" sentence, the "Observed dataset row is a released paired/not_paired link …" sentence (T2's biggest leak), and label_source=… lines.
  2. Unexplained proxy scores: Evolution proxy score … (expression_stability_proxy_v1), promoter_likeness_score=…, quality_score / repeat_fraction / kmer_entropy_norm. These are ad-hoc internal scores the model can't ground, so we omit them rather than try to explain them in-prompt.
  3. Cell-type abbreviations expanded: cell_type=Ex → cell_type=Excitatory neuron (Ex), so the model knows the biology.

Applied before any model call. Idempotent: running it twice yields the same string. 12 unit tests cover every leak/score family + idempotency + cell-type expansion (tests/test_input_sanitize.py).
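The three classes of fix, and why the result is idempotent, can be illustrated with a toy sanitiser. This is not input_sanitize.py: the regular expressions are guesses at the field formats named above, the cell-type table holds only the one example from the text, and the gene= field in the usage example is invented for the demonstration.

```python
import re

# Only the example from the text; the real abbreviation table is not shown here.
CELL_TYPES = {"Ex": "Excitatory neuron"}

# Assumed formats for a subset of the leak / proxy-score fields listed above.
_REMOVE = [
    re.compile(r"\b(?:enhancer_)?peak_name=\S+\s*"),
    re.compile(r"\blabel_source=\S+\s*"),
    re.compile(r"Observed dataset row is a released (?:paired|not_paired) link[^.]*\.\s*"),
    re.compile(
        r"\b(?:promoter_likeness_score|quality_score|repeat_fraction|kmer_entropy_norm)=\S+\s*"
    ),
]
_CELL_TYPE = re.compile(r"\bcell_type=(\w+)")


def _expand(match: "re.Match[str]") -> str:
    abbr = match.group(1)
    full = CELL_TYPES.get(abbr)
    # Unknown tokens (including already-expanded names) are left untouched,
    # which is what makes a second pass a no-op.
    return f"cell_type={full} ({abbr})" if full else match.group(0)


def sanitize(text: str) -> str:
    """Strip leak and proxy-score fields, then expand cell-type abbreviations."""
    for pattern in _REMOVE:
        text = pattern.sub("", text)
    return _CELL_TYPE.sub(_expand, text).strip()


example = "cell_type=Ex peak_name=chr1:100-200 promoter_likeness_score=0.91 gene=SYT1"
once = sanitize(example)
assert sanitize(once) == once  # idempotent on this input
```

Doing the removal at read time, as the section describes, means the stored JSONLs stay untouched and the rule set can evolve without regenerating data.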

Why we don't run build_reasoning_traces.py inside post_bench_pipeline.sh: the script is IO-bound (no GPU), capped at 1000 req/day, and meant to run for multiple days in the background. Putting it in the GPU pipeline would either limit it to a single 1000-call day or block the rest of the pipeline while rows accumulate. The right pattern is to launch build_reasoning_traces_loop.sh --daemon once at the start of the campaign and let it accumulate rows independently. When a critical mass exists, fire a single fusion-SFT run on the augmented JSONL.