fix token merging

#3 · opened by kashif (HF Staff)
Files changed (1)
  1. app/src/content/article.mdx +7 -7
app/src/content/article.mdx CHANGED
@@ -165,9 +165,9 @@ The first limitation we address is ULD's sequence alignment, which simply trunca

  This alignment error worsens as tokenization differences increase because a single mismatch at the start of a sequence can propagate and create a cascading semantic error throughout the text.

- Instead of truncating, our method identifies the token merges required to equalise the sequence lengths for both tokenizers. We then merge the logits at the corresponding positions by summing their log probabilities. This sum, which represents the log of the joint probability for the merged tokens, is then passed through a softmax.
+ Instead of truncating, our method identifies the token merges required to equalise the sequence lengths for both tokenizers. We then merge the probabilities at the corresponding positions by multiplying the marginal distribution by scalar conditional probabilities of the actual continuation tokens.

- We perform the token merge through summing the log probabilities to leverage the autoregressive nature of LLM sampling. Following the example in Figure 3, we want to merge Hugging and Face into one token for the sequence in blue. Using the conditional probabilities and the product rule[^f2], we can merge the probabilities and guarantee sequence alignment regardless of tokenizer discrepancies in the sequence dimension.
+ We perform the token merge through scalar multiplication to leverage the autoregressive nature of LLM sampling. Following the example in Figure 3, we want to merge "Hugging" and " Face" into one token for the sequence in blue. Using the conditional probabilities and the product rule[^f2], we can merge the probabilities and guarantee sequence alignment regardless of tokenizer discrepancies in the sequence dimension.

  <Image
  src={sequenceAlignment}
@@ -439,8 +439,8 @@ accelerate launch \
  --temperature 1.0 \
  --top_p 0.95 \
  --top_k 0 \
- --max_new_tokens 2048 \
- --max_prompt_length 512 \
+ --max_completion_length 2048 \
+ --max_length 2560 \
  --lmbda 0.25 \
  --beta 0.0 \
  --use_uld_loss \
@@ -460,7 +460,7 @@ accelerate launch \
  --trackio_project Qwen3-4B-GKD-Tulu \
  --seed 42 \
  --warmup_ratio 0.05 \
- --lr_scheduler_type cosine_with_min_lr
+ --lr_scheduler_type cosine_with_min_lr
  ```
  </Accordion>

@@ -468,7 +468,7 @@ accelerate launch \

  In this post, we introduced General On-Policy Logit Distillation (GOLD), a new method that enables effective on-policy knowledge distillation between models, even when the teacher and student do not share the same tokenizer vocabulary. This overcomes a significant limitation of existing on-policy methods like GKD, which require matched tokenizers.

- GOLD builds upon the offline ULD method but extends it to the on-policy setting and, critically, addresses its two main weaknesses. First, we replace ULD's naive sequence truncation with a token-merging strategy that sums log probabilities of mismatched tokens. Second, we implement a hybrid vocabulary alignment method that uses a direct-mapping loss for shared tokens and falls back to ULD's sorting method only for unmatched tokens.
+ GOLD builds upon the offline ULD method but extends it to the on-policy setting and, critically, addresses its two main weaknesses. First, we replace ULD's naive sequence truncation with a token-merging strategy that multiplies marginal distributions by scalar conditional probabilities. Second, we implement a hybrid vocabulary alignment method that uses a direct-mapping loss for shared tokens and falls back to ULD's sorting method only for unmatched tokens.

  Our experiments on the Countdown math task confirm GOLD's advantages. We showed that GOLD significantly outperforms the original offline ULD implementation, recovering 60% of the teacher's performance versus ULD's 10%. Furthermore, GOLD proved superior to other on-policy methods, outperforming a supervised fine-tuning baseline by 15% and a GRPO baseline by 2x. Even in the difficult cross-tokenizer scenario, GOLD still outperformed GRPO by 20%.

@@ -476,6 +476,6 @@ These findings demonstrate that GOLD is a powerful and flexible technique for mo

  [^f1]: The full GKD loss is then formally defined as: $$\mathcal{L}_{GKD} := (1-\lambda) \mathbb{E}_{(x,y)\sim (X,Y)}[\mathcal{D}_{JSD(\beta)}] + \lambda \mathbb{E}_{x \sim X}[\mathbb{E}_{y \sim p_{S}(.|x)}[\mathcal{D}_{JSD(\beta)}]].$$

- [^f2]: The details of why we can merge the softmax by adding the log probabilities from the merged positions in the sequence. $$P(\text{``Hugging Face"} \mid \text{``<think>"}) = P(\text{``Hugging"} \mid \text{``<think>"}) \times P(\text{``Face"} \mid \text{``<think> Hugging"})$$ and $$\log P(\text{``Hugging Face"} \mid \text{"<think>"}) = \log P(\text{``Hugging"} \mid \text{``<think>"}) + \log P(\text{``Face"} \mid \text{``<think> Hugging"})$$
+ [^f2]: The details of why we can merge the probabilities using the chain rule. For the merged distribution at position i: $$P_{\text{merged}}(y) = P(y \mid x) \times P(\text{token}_1 \mid x) \times P(\text{token}_2 \mid \text{token}_1, x) \times \dots$$ This correctly computes the joint probability of the actual generated sequence while providing a reasonable approximation for counterfactual tokens.

  [^f3]: The $\beta$ parameter then controls the generalized Jensen-Shannon divergence between the student (S) and teacher (T) distributions, calculated via the following loss summed over the sequence and averaged over the batch: $$\mathcal{D}_{\text{JSD}(\beta)}(p_S, p_T) = \beta \cdot D_{\text{KL}}(p_S \| \pi) + (1-\beta) \cdot D_{\text{KL}}(p_T \| \pi)$$ where $\pi = \beta \cdot p_S + (1-\beta) \cdot p_T$.
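The product-rule merge this PR switches to can be sketched numerically. This is a minimal illustration, not the PR's implementation: the probabilities are made up, the tiny vocabulary is hypothetical, and only the token strings follow the article's Figure 3.

```python
import math

# Suppose the teacher tokenizer emits two tokens where the student
# tokenizer emits one. Hypothetical conditional probabilities:
#   P("Hugging" | "<think>")         = 0.5
#   P(" Face"   | "<think> Hugging") = 0.8
p_hugging = 0.5
p_face_given_hugging = 0.8

# Product rule: the joint probability of the merged span
# P("Hugging Face" | "<think>") is the product of the conditionals.
p_joint = p_hugging * p_face_given_hugging

# Equivalently, in log space the merge is a sum of log probabilities
# at the merged positions.
logp_joint = math.log(p_hugging) + math.log(p_face_given_hugging)
assert math.isclose(p_joint, 0.4)
assert math.isclose(math.exp(logp_joint), p_joint)

# The "+" side of the diff phrases the same idea distribution-wide:
# scale the full next-token distribution at a position by the scalar
# conditional probability of the actual continuation token.
dist = [0.5, 0.3, 0.2]  # hypothetical P(y | x) over a 3-token vocab
merged = [p * p_face_given_hugging for p in dist]

# The scaled values sum to the scalar factor, i.e. they are no longer
# a normalized distribution on their own.
assert math.isclose(sum(merged), p_face_given_hugging)
```

The log-space and probability-space views are interchangeable here; the sketch only checks that scaling by the conditional of the actual continuation reproduces the chain-rule joint probability.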