
Recipe: Async On-Policy Knowledge Distillation Trainer

Authors: Brilliant Hanabi, furunding

Last updated: 2025-11-08

1. Background

On-policy knowledge distillation (KD) trains a student policy to imitate a stronger teacher using samples drawn from the student's current policy. For each on-policy rollout the teacher returns soft, top-k token distributions, and the student is optimized with a token-wise sparse KL objective that focuses learning on the teacher's high-probability modes. Because training examples come from the student's own state distribution, KD reduces distributional mismatch relative to off-policy distillation or supervised fine-tuning (SFT), improving stability and sample efficiency. Compared with reinforcement learning, KD avoids high-variance reward-based optimization and complex reward design by providing dense, informative per-token targets, which typically yields faster convergence and simpler scaling. Recent empirical and implementation-focused write-ups (e.g., Thinking Machines' blog on on-policy distillation) also demonstrate that on-policy distillation can deliver high-quality behavior with substantially lower compute and data requirements than many alternative approaches.

Built on verl’s Ray-based single-controller components, we initially assembled a strictly on-policy KD pipeline where rollout generation, teacher knowledge acquisition, and policy optimization ran in lockstep. In practice, this synchronous design proved highly inefficient: the three stages had to wait for one another, creating pipeline bubbles and underutilized GPUs. To address this, we extend the asynchronous schedulers introduced by the One-Step-Off Policy pipeline to overlap these phases. This overlap preserves the same distillation objective while trading some strict on-policy guarantees for substantial gains in end-to-end throughput and hardware utilization.

2. Distillation Overview and Objective

This recipe centers on on-policy knowledge distillation: the student policy learns from a stronger teacher on samples generated by the current policy (on-policy). For each input prompt, the student (actor) generates responses; the teacher provides top-k token distributions, and the student is trained to match them token-wise.

Core components:

  1. Teacher signal: top-k log-probabilities and token indices per valid token position.
  2. Student objective: sparse, token-level KL divergence between student logits and teacher top-k distribution.

Objective: encourage the student probabilities $Q$ to cover the teacher modes $P$ using the token-wise $\mathrm{KL}(P \,\|\, Q)$ computed on the teacher's top-k support.
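The token-wise objective can be written out explicitly. This is a sketch that rests on two assumptions the text above does not state. First, the teacher's top-k probabilities are renormalized over the top-k set $\mathcal{K}_t$. Second, the loss is averaged over the valid response positions $\mathcal{T}$ selected by attention_mask. Here $\ell_{t,i}$ is the teacher top-k log-probability of token $i$ at position $t$, and $z_t$ are the student logits.

$$
\tilde P_t(i) = \frac{\exp(\ell_{t,i})}{\sum_{j \in \mathcal{K}_t} \exp(\ell_{t,j})}, \qquad Q_t = \mathrm{softmax}(z_t),
$$

$$
\mathcal{L}_{\mathrm{KD}} = \frac{1}{|\mathcal{T}|} \sum_{t \in \mathcal{T}} \sum_{i \in \mathcal{K}_t} \tilde P_t(i) \bigl( \log \tilde P_t(i) - \log Q_t(i) \bigr),
\qquad
\frac{\partial\, \mathrm{KL}(\tilde P_t \,\|\, Q_t)}{\partial z_{t,v}} = Q_t(v) - \tilde P_t(v)\, \mathbb{1}[v \in \mathcal{K}_t].
$$

Because $\sum_{i \in \mathcal{K}_t} \tilde P_t(i) = 1$, the gradient reduces to "student probabilities minus sparse teacher probabilities". This matches the gradient described for megatron_kl_loss.py in Section 5.5.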

3. Efficient System Design

3.1 Schedulers (One-Step / Two-Step Off-Policy)

The native (serial) on-policy distillation process is shown in the figure below.

Figure: Zero-Step-Off Scheduler

This recipe supports optional schedulers that overlap generation, teacher querying, and updates to improve throughput without changing the distillation objective.

3.1.1 One-Step-Off-Policy

Figure: One-Step-Off Scheduler

  • Warm-up: 2 steps.
  • Overlap pattern: rollout runs concurrently with the actor update; weight sync runs concurrently with teacher retrieval.
  • Timing keys: sync_rollout_weights, wait_prev_gen, wait_prev_teacher.

3.1.2 Two-Step-Off-Policy

Figure: Two-Step-Off Scheduler

  • Warm-up: 3 steps.
  • Overlap pattern: rollout and the actor update both run while the teacher is retrieving; weight sync is interleaved.
  • Timing keys: sync_rollout_weights, max(wait_prev_gen, wait_prev_prev_teacher).

Tip: Use two_step_off when teacher retrieval takes much longer than weight sync; use one_step_off for simpler overlapping.
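The overlap trade-off can be made concrete with a toy model. It shows how much policy staleness each scheduler introduces. The sketch uses hypothetical names (run_k_step_off, rollout_then_teacher) that do not come from the recipe. It also assumes that keeping `lag` batches in flight corresponds to one_step_off (lag=1) or two_step_off (lag=2). The real schedulers order weight sync, teacher retrieval, and warm-up differently, so treat this only as an illustration of "launch the next rollout before the update".

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple


def rollout_then_teacher(batch_id: int, weight_version: int) -> Dict[str, int]:
    # Hypothetical stand-in for async_generate_sequences followed by a teacher top-k query.
    return {"batch_id": batch_id, "weight_version": weight_version}


def run_k_step_off(num_steps: int, lag: int) -> List[Tuple[int, int, int]]:
    """Toy k-step-off loop (lag=1 ~ one-step-off, lag=2 ~ two-step-off).

    Returns (step, weight_version_used_for_rollout, actor_version_being_updated).
    """
    if lag < 1:
        raise ValueError("lag must be >= 1")
    actor_version = 0
    next_batch = 0
    pending = deque()
    history = []
    with ThreadPoolExecutor(max_workers=lag + 1) as pool:
        # Warm-up: fill the pipeline with `lag` rollouts generated from the initial weights.
        while next_batch < min(lag, num_steps):
            pending.append(pool.submit(rollout_then_teacher, next_batch, actor_version))
            next_batch += 1
        for step in range(num_steps):
            # Sync current weights and launch the next rollout *before* the update,
            # so generation overlaps the optimizer step below.
            if next_batch < num_steps:
                pending.append(pool.submit(rollout_then_teacher, next_batch, actor_version))
                next_batch += 1
            sample = pending.popleft().result()  # wait for the oldest in-flight batch
            history.append((step, sample["weight_version"], actor_version))
            actor_version += 1  # stand-in for update_actor (one optimizer step)
    return history
```

Usage, printing how many policy versions behind the actor each training batch was generated:

```python
for lag in (1, 2):
    print(lag, [actor - gen for _, gen, actor in run_k_step_off(5, lag)])
# 1 [0, 1, 1, 1, 1]
# 2 [0, 1, 2, 2, 2]
```

Under this model, deeper overlap hides more teacher latency, at the cost of training on samples that are one more policy version old.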

Practical details:

  • Inputs per batch: teacher_topk_logps, teacher_topk_indices, attention_mask (to select valid token positions).
  • Loss injection: last pipeline stage computes KL via a logits processor; earlier stages remain unchanged.
  • Optional dynamic micro-batching groups sequences by density to reduce padding overhead.

The pipeline:

  1. Actor parameters are synchronized to a rollout worker group (NCCL broadcast) with small latency.
  2. Rollout workers (vLLM-backed) generate sequences asynchronously (async_generate_sequences).
  3. A ZeroMQ-based teacher client service returns top-k log-probabilities and token indices for each sequence (via batched micro-requests), enabling KL-based guidance.
  4. Megatron actor performs a KL divergence computation between student logits and teacher top-k distributions (custom TP-aware kernel in megatron_kl_loss.py).
  5. Scheduling strategies (one_step_off_scheduler, two_step_off_scheduler) can optionally overlap these phases for higher throughput.

3.2 Weights sync between actor and rollout

We initially followed the weight synchronization path from the One-Step-Off-Policy recipe (Ray collective broadcast across all actor and rollout ranks, plus Megatron-side allgather of parameter shards). In practice this became the dominant bottleneck, so we made three changes:

  1. Batch-and-bulk load on the rollout side: instead of streaming tensors one by one (as in the One-Step-Off-Policy recipe), we stage a bundle of parameter tensors and issue a single batched load into the rollout engine. In our setup this reduced the weight-loading time by roughly 3×.
  2. Batch-and-bulk broadcast between the actor and rollout: instead of streaming tensors one by one (as in the One-Step-Off-Policy recipe), we stage a bundle of parameter tensors and issue a single batched broadcast between the actor and rollout workers.
  3. Replace allgather with gather-to-root in Megatron: parameter shards are gathered to actor rank 0 (rather than allgathered to every rank), and that root then serves as the single source for broadcasting to rollout ranks. On top of change 1, changes 2 and 3 together delivered an additional ~4× speedup in the synchronization phase.
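"Stage a bundle of parameter tensors" needs some bundling rule. The sketch below shows one plausible rule: greedy, order-preserving bucketing under a size cap, similar in spirit to the update_weights_bucket_megabytes knob mentioned in Section 8. The function name and the greedy policy are assumptions, not the recipe's actual code. Each resulting bucket would then be sent with one batched broadcast and applied with one batched load.

```python
from typing import Iterable, List, Tuple


def bucket_by_size(named_sizes: Iterable[Tuple[str, int]], bucket_megabytes: float) -> List[List[str]]:
    """Group parameter names (in order) into bundles whose total bytes stay under a cap.

    A single tensor larger than the cap gets a bucket of its own.
    """
    cap = int(bucket_megabytes * 1024 * 1024)
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, nbytes in named_sizes:
        # Flush the current bucket if adding this tensor would exceed the cap.
        if current and current_bytes + nbytes > cap:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets
```

Usage with made-up tensor sizes:

```python
MB = 1024 * 1024
sizes = [("embed", 300 * MB), ("layer0", 200 * MB), ("lm_head", 600 * MB), ("norm", 100 * MB)]
print(bucket_by_size(sizes, bucket_megabytes=512))
# [['embed', 'layer0'], ['lm_head'], ['norm']]
```

Larger buckets mean fewer collective calls and engine loads but higher peak staging memory, which is why the cap is worth exposing as a setting.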

4. High-Level Data & Control Flow

Driver (TaskRunner)
  ├─ Initialize Ray, tokenizer, datasets, worker groups
  ├─ Build ResourcePoolManager (actor vs rollout GPU layouts)
  ├─ Trainer.fit()
      ├─ init_workers(): build actor + rollout groups, broadcast weight metadata, create NCCL collective group
      ├─ continuous_iterator(): epochs → batches
      ├─ scheduler (see Section 3.1)
        • _async_gen_next_batch(): optional weight sync + non-blocking rollout
        • _async_get_teacher_knowledge(): submit teacher requests, store future
        ├─ For each step:
            • Sync rollout weights
            • Retrieve (batch, gen_output, teacher_output) from futures
            • Merge gen + teacher outputs → DataProto
            • Compute metrics (response length stats, timing, throughput)
            • Update actor (forward_backward_batch + KL loss + optimizer step)
            • (Optional) save checkpoint

Note: Schedulers are optional and are described in Section 3.1; the distillation objective is independent of how phases are overlapped.

5. Key Components

5.1 OnPolicyDistillTrainer (ray_trainer.py)

  • Creates GenerationBatchFuture objects holding rollout and (later) teacher futures.
  • Adds scheduling + teacher integration + modified metric emission (KL, timing, MFU).

5.2 Actor Worker (Megatron)

  • OnPolicyDistillActor.update_policy() orchestrates micro-batch forward/backward.
  • KL Loss injection via logits_processor during forward on pipeline last stage.

5.3 Rollout Worker (vLLM / SGLang)

  • Pure inference mode (init_model builds model; no optimizer).
  • async_generate_sequences returns a Ray future for overlapping.

5.4 Teacher Service (teacher/)

  • Proxy + worker architecture (ZMQ REQ/REP) for batched top-k retrieval.
  • TeacherClient.submit() returns a Future; aggregator composes micro-batches.
  • Configurable temperature, max tokens, only-response mode.
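The client-side pattern of "submit returns a Future; an aggregator composes micro-batches" can be sketched without the network layer. In the sketch below, ToyTeacherClient is hypothetical, and a plain `backend` callable stands in for the ZeroMQ REQ/REP round trip to the proxy and workers. The real TeacherClient's constructor and arguments are not shown here.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, List, Sequence


class ToyTeacherClient:
    """Hypothetical stand-in for TeacherClient.

    `backend` replaces the ZeroMQ round trip: it takes one micro-batch of sequences
    and returns one result (e.g. top-k logps/indices) per sequence.
    """

    def __init__(self, backend: Callable[[Sequence[Any]], List[Any]],
                 micro_batch_size: int, max_workers: int = 4) -> None:
        self.backend = backend
        self.micro_batch_size = micro_batch_size
        self._workers = ThreadPoolExecutor(max_workers=max_workers)
        self._aggregator = ThreadPoolExecutor(max_workers=1)

    def submit(self, sequences: Sequence[Any]) -> Future:
        # Split the batch into micro-requests that can be served in parallel.
        chunks = [sequences[i:i + self.micro_batch_size]
                  for i in range(0, len(sequences), self.micro_batch_size)]
        parts = [self._workers.submit(self.backend, chunk) for chunk in chunks]
        # Compose micro-batch results into one list, preserving request order.
        return self._aggregator.submit(lambda: [out for part in parts for out in part.result()])

    def close(self) -> None:
        self._aggregator.shutdown(wait=True)
        self._workers.shutdown(wait=True)
```

Usage with a dummy backend that returns each sequence's length:

```python
client = ToyTeacherClient(backend=lambda chunk: [len(s) for s in chunk], micro_batch_size=2)
future = client.submit(["a", "bb", "ccc", "dddd", "eeeee"])
# ... the trainer can do other work (e.g. an actor update) before blocking ...
print(future.result())  # [1, 2, 3, 4, 5]
client.close()
```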

5.5 KL Loss (megatron_kl_loss.py)

  • Performs normalization & stable per-token probability construction across TP shards.
  • Gradient is (student_probs - teacher_sparse_probs) scaled by upstream grad.
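A single-device, single-token reference for this loss can help when checking a custom kernel. The sketch below computes KL on the teacher's top-k support together with its gradient with respect to the student logits. It assumes the teacher's top-k log-probs are renormalized over the top-k set. The function name is hypothetical. The recipe's TP-aware kernel performs the max and sum reductions across tensor-parallel shards instead of over one Python list.

```python
import math
from typing import List, Sequence, Tuple


def sparse_kl_and_grad(
    student_logits: Sequence[float],
    teacher_topk_logps: Sequence[float],
    teacher_topk_indices: Sequence[int],
) -> Tuple[float, List[float]]:
    """KL(P || Q) on the teacher's top-k support for one token, plus dKL/d(student logits)."""
    # Student log-softmax over the full vocabulary (log-sum-exp shifted by the max for stability).
    m = max(student_logits)
    lse = m + math.log(sum(math.exp(z - m) for z in student_logits))
    student_logp = [z - lse for z in student_logits]

    # Renormalize the teacher's top-k log-probs over the top-k support (assumption).
    tm = max(teacher_topk_logps)
    t_lse = tm + math.log(sum(math.exp(lp - tm) for lp in teacher_topk_logps))

    kl = 0.0
    teacher_sparse_probs = [0.0] * len(student_logits)
    for idx, lp in zip(teacher_topk_indices, teacher_topk_logps):
        t_logp = lp - t_lse
        p = math.exp(t_logp)
        teacher_sparse_probs[idx] = p
        kl += p * (t_logp - student_logp[idx])

    # Gradient: student_probs - teacher_sparse_probs (teacher mass is zero off the top-k support).
    grad = [math.exp(lp) - p for lp, p in zip(student_logp, teacher_sparse_probs)]
    return kl, grad
```

For example, `sparse_kl_and_grad([2.0, 1.0, 0.1, -1.0], [math.log(0.7), math.log(0.2)], [0, 1])` returns a non-negative KL and a gradient whose entries sum to approximately zero. A central finite-difference check on the logits should agree with the analytic gradient.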

6. Configuration Highlights (on_policy_distill_trainer.yaml)

| Section | Purpose | Notable Keys |
|---|---|---|
| actor_rollout_ref.teacher | Teacher server | server_ip, server_port, n_server_workers |
| trainer | Global training control | total_epochs, save_freq, scheduler (one_step_off or two_step_off) |
| rollout | Resource split for rollout | n_gpus_per_node, nnodes |

Remember to set trainer.n_gpus_per_node, trainer.nnodes, rollout.n_gpus_per_node and rollout.nnodes to allocate GPU resources.

Dynamic Batch Size

Enable by:

actor_rollout_ref.actor.use_dynamic_bsz=True
actor_rollout_ref.actor.max_token_len=6000  # cap post-group token length

Improves utilization under variable sequence lengths.
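Dynamic batching packs sequences into micro-batches under a token budget rather than a fixed sequence count. The sketch below uses first-fit-decreasing packing against a max_token_len budget. The function name and the packing heuristic are assumptions for illustration. verl's own dynamic batching may use a different balancing algorithm.

```python
from typing import List


def pack_by_token_budget(seq_lens: List[int], max_token_len: int) -> List[List[int]]:
    """First-fit-decreasing packing of sequence indices into micro-batches
    whose summed token length stays within max_token_len."""
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    bins: List[List[int]] = []
    loads: List[int] = []
    for i in order:
        if seq_lens[i] > max_token_len:
            raise ValueError(f"sequence {i} has {seq_lens[i]} tokens > max_token_len={max_token_len}")
        # Place the sequence into the first micro-batch that still has room.
        for b, load in enumerate(loads):
            if load + seq_lens[i] <= max_token_len:
                bins[b].append(i)
                loads[b] += seq_lens[i]
                break
        else:
            bins.append([i])
            loads.append(seq_lens[i])
    return bins
```

Usage with the 6000-token cap from the example above:

```python
print(pack_by_token_budget([4000, 1500, 3000, 500, 2500], max_token_len=6000))
# [[0, 1, 3], [2, 4]]
```

Five sequences fit in two micro-batches (6000 and 5500 tokens). Padding every sequence to the longest length (4000) would instead cost 5 × 4000 = 20000 token slots.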

Resource Guidelines

  • Actor pool: trainer.nnodes * trainer.n_gpus_per_node GPUs.
  • Rollout pool: rollout.nnodes * rollout.n_gpus_per_node GPUs.
  • Size the teacher server (n_server_workers) so that it keeps pace with rollout; otherwise training stalls (monitor wait_prev_teacher).

7. Usage Examples

7.1 Launch Teacher Server

Before training, start a teacher server to provide log-probability information.

We provide a toy teacher server example backed by vLLM. start_server.sh uses telnet to check the proxy status and the python command to launch the processes. If telnet is not installed, you can delete the telnet check from start_server.sh. If your OS provides python3 rather than python, modify the command accordingly. You can also change the teacher port if you run into a port conflict.

Three arguments can be set for the vLLM backend in start_server.sh / worker.py: --tp-size, --n-logprobs, and --ckpt-path. Set them before you start the server.

We also provide a toy multi-node teacher server. Start the main node with start_server.sh and the secondary nodes with join_server.sh. Remember to set the arguments in join_server.sh as well, especially $PROXY_IP and $PROXY_BACKEND_PORT of the main node.

During training, the student automatically adopts the teacher's top-k (n-logprobs) as its own top-k argument (line 83 of recipe/gkd/megatron_kl_loss.py), so you don't need to set the student's top-k separately.

cd recipe/gkd/teacher
bash start_server.sh
# Exports ports and launches proxy + worker (default vLLM backend)

Verify with:

telnet localhost 15555
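If telnet is unavailable, a plain TCP connect performs the same reachability check. The helper below is a hypothetical convenience that uses only the Python standard library. It confirms that something is listening on the port; it does not validate the teacher protocol.

```python
import socket


def teacher_reachable(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds (what `telnet host port` checks)."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unresolvable.
        return False
```

For example, `teacher_reachable("localhost", 15555)` should return True once start_server.sh has brought up the proxy on the default port.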

7.2 Minimal Local (Megatron + vLLM) Run

python3 -m recipe.gkd.main_gkd \
  --config-path=recipe/gkd/config \
  --config-name=on_policy_distill_trainer \
  actor_rollout_ref.model.path=/path/to/MODEL \
  data.train_files=/path/to/train.parquet \
  trainer.total_epochs=2 \
  trainer.n_gpus_per_node=4 rollout.n_gpus_per_node=2 \
  actor_rollout_ref.teacher.server_ip=127.0.0.1 \
  actor_rollout_ref.teacher.server_port=15555 \
  trainer.scheduler=one_step_off

(Requires a running teacher server).

7.3 Ray Job Submission (Distilled 16B Example)

See run_moonlight_dsv3_training.sh for a full script including:

  • Dist ckpt path setup (dist_checkpointing_path)
  • Expert parallel sizing (EP / ETP)
  • Dynamic batch sizing
  • Two-step-off scheduling for deeper overlap.

Submit (after adjusting paths):

bash recipe/gkd/run_moonlight_dsv3_training.sh

8. Metrics & Monitoring

Emitted metrics include (prefixes may vary):

  • Timing: timing/wait_prev_gen, timing/sync_rollout_weights, timing/get_teacher_knowledge, timing/update_actor.
  • Sequence stats: response_seq_len/* (avg, max, min, counts).
  • Performance: perf/mfu/actor, perf/max_memory_allocated_gb, perf/cpu_memory_used_gb.
  • Distillation: actor/kl_loss, actor/grad_norm, actor/lr.

Interpretation Tips:

  • High wait_prev_teacher → scale n_server_workers and allocate more teacher GPUs, reduce the per-request batch size, or switch to two_step_off.
  • High wait_prev_gen with uniform lengths → allocate more rollout GPUs.
  • High sync_rollout_weights → check the NCCL environment / network congestion and try tuning actor_rollout_ref.rollout.update_weights_bucket_megabytes.
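The interpretation tips can be applied mechanically to a logged step by picking the largest wait or sync timing. The helper and the exact metric keys below are assumptions. Section 8 notes that prefixes may vary, and timing/wait_prev_teacher in particular is a guessed name. The advice strings paraphrase the tips above.

```python
from typing import Dict, Optional, Tuple

# Hypothetical metric keys mapped to the interpretation tips; adjust names to your logs.
ADVICE: Dict[str, str] = {
    "timing/wait_prev_teacher": "scale n_server_workers / teacher GPUs, reduce per-request batch size, or try two_step_off",
    "timing/wait_prev_gen": "allocate more rollout GPUs (if response lengths are uniform)",
    "timing/sync_rollout_weights": "check NCCL env / network; tune actor_rollout_ref.rollout.update_weights_bucket_megabytes",
}


def dominant_bottleneck(metrics: Dict[str, float]) -> Tuple[Optional[str], Optional[str]]:
    """Return the largest wait/sync timing among the keys in ADVICE, with its suggestion."""
    present = {k: v for k, v in metrics.items() if k in ADVICE}
    if not present:
        return None, None
    worst = max(present, key=present.get)
    return worst, ADVICE[worst]
```

For example, with `{"timing/wait_prev_gen": 1.2, "timing/wait_prev_teacher": 4.0, "timing/update_actor": 9.0}`, the helper points at the teacher. It ignores update_actor because that timing is compute, not waiting.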

9. Extensibility Notes

  • Add new schedulers by implementing the same interface: each step returns (epoch, batch, gen_output, teacher_output, timing_dict).
  • Integrate different distillation signals (e.g., hidden states, intermediate reasoning tokens) by extending teacher_utils.get_teacher_knowledge and modifying logits_processor.
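A minimal scheduler that satisfies the per-step tuple contract is the serial (zero-step-off) case. In the sketch below, it is written as a generator that takes stub `generate` and `get_teacher_knowledge` callables. The function name, argument names, and generator form are assumptions; only the yielded tuple layout comes from the text above. An overlapping scheduler would keep the same yield signature and change only when work is launched and awaited.

```python
import time
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple

# (epoch, batch, gen_output, teacher_output, timing_dict), as listed above.
StepOutput = Tuple[int, Any, Any, Any, Dict[str, float]]


def zero_step_off_scheduler(
    batches: Iterable[Tuple[int, Any]],
    generate: Callable[[Any], Any],
    get_teacher_knowledge: Callable[[Any], Any],
) -> Iterator[StepOutput]:
    """Serial scheduler: generation and teacher querying run back to back for each batch."""
    for epoch, batch in batches:
        timing: Dict[str, float] = {}
        start = time.perf_counter()
        gen_output = generate(batch)
        timing["wait_prev_gen"] = time.perf_counter() - start
        start = time.perf_counter()
        teacher_output = get_teacher_knowledge(gen_output)
        timing["wait_prev_teacher"] = time.perf_counter() - start
        yield epoch, batch, gen_output, teacher_output, timing
```

Usage with string stubs: `list(zero_step_off_scheduler([(0, "p0"), (1, "p1")], lambda b: b + "-gen", lambda g: g + "-teacher"))` yields one tuple per batch, for example `(1, "p1", "p1-gen", "p1-gen-teacher", {...})`.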

10. Functional Support Summary

| Category | Supported |
|---|---|
| Train engine | Megatron |
| Rollout engine | vLLM |
| Distillation signal | Teacher top-k logprobs & indices |
| Scheduling | one_step_off, two_step_off |

11. Quick Checklist Before Running

  • Teacher server reachable (telnet <ip> <port>).
  • actor_rollout_ref.model.path contains the correct Megatron/HF config artifacts.
  • train_files points to a parquet dataset compatible with this recipe's dataset loader.
  • NCCL environment vars set (see config/runtime_env.yaml).

Feel free to open issues or PRs to extend scheduler variants, add new distillation objectives, broaden engine support, or make other improvements.