fix token merging description
app/src/content/article.mdx
CHANGED
@@ -165,9 +165,9 @@ The first limitation we address is ULD's sequence alignment, which simply trunca
 
 This alignment error worsens as tokenization differences increase because a single mismatch at the start of a sequence can propagate and create a cascading semantic error throughout the text.
 
-Instead of truncating, our method identifies the token merges required to equalise the sequence lengths for both tokenizers. We then merge the
+Instead of truncating, our method identifies the token merges required to equalise the sequence lengths for both tokenizers. We then merge the probabilities at the corresponding positions by multiplying the marginal distribution by the scalar conditional probabilities of the actual continuation tokens.
 
-We perform the token merge through
+We perform the token merge through scalar multiplication to leverage the autoregressive nature of LLM sampling. Following the example in Figure 3, we want to merge "Hugging" and " Face" into one token for the sequence in blue. Using the conditional probabilities and the product rule[^f2], we can merge the probabilities and guarantee sequence alignment regardless of tokenizer discrepancies in the sequence dimension.
 
 <Image
 src={sequenceAlignment}

@@ -439,8 +439,8 @@ accelerate launch \
 --temperature 1.0 \
 --top_p 0.95 \
 --top_k 0 \
---
---
+--max_completion_length 2048 \
+--max_length 2560 \
 --lmbda 0.25 \
 --beta 0.0 \
 --use_uld_loss \

@@ -460,7 +460,7 @@ accelerate launch \
 --trackio_project Qwen3-4B-GKD-Tulu \
 --seed 42 \
 --warmup_ratio 0.05 \
---lr_scheduler_type cosine_with_min_lr
+--lr_scheduler_type cosine_with_min_lr
 ```
 </Accordion>

@@ -468,7 +468,7 @@ accelerate launch \
 
 In this post, we introduced General On-Policy Logit Distillation (GOLD), a new method that enables effective on-policy knowledge distillation between models, even when the teacher and student do not share the same tokenizer vocabulary. This overcomes a significant limitation of existing on-policy methods like GKD, which require matched tokenizers.
 
-GOLD builds upon the offline ULD method but extends it to the on-policy setting and, critically, addresses its two main weaknesses. First, we replace ULD's naive sequence truncation with a token-merging strategy that
+GOLD builds upon the offline ULD method but extends it to the on-policy setting and, critically, addresses its two main weaknesses. First, we replace ULD's naive sequence truncation with a token-merging strategy that multiplies marginal distributions by scalar conditional probabilities. Second, we implement a hybrid vocabulary alignment method that uses a direct-mapping loss for shared tokens and falls back to ULD's sorting method only for unmatched tokens.
 
 Our experiments on the Countdown math task confirm GOLD's advantages. We showed that GOLD significantly outperforms the original offline ULD implementation, recovering 60% of the teacher's performance versus ULD's 10%. Furthermore, GOLD proved superior to other on-policy methods, outperforming a supervised fine-tuning baseline by 15% and a GRPO baseline by 2x. Even in the difficult cross-tokenizer scenario, GOLD still outperformed GRPO by 20%.
 

@@ -476,6 +476,6 @@ These findings demonstrate that GOLD is a powerful and flexible technique for mo
 
 [^f1]: The full GKD loss is then formally defined as: $$\mathcal{L}_{GKD} := (1-\lambda) \mathbb{E}_{(x,y)\sim (X,Y)}[\mathcal{D}_{JSD(\beta)}] + \lambda \mathbb{E}_{x \sim X}[\mathbb{E}_{y \sim p_{S}(.|x)}[\mathcal{D}_{JSD(\beta)}]].$$
 
-[^f2]: The details of why we can merge the
+[^f2]: The details of why we can merge the probabilities follow from the chain rule. For the merged distribution at position $i$: $$P_{\text{merged}}(y) = P(y \mid x) \times P(\text{token}_1 \mid x) \times P(\text{token}_2 \mid \text{token}_1, x) \times \dots$$ This correctly computes the joint probability of the actual generated sequence while providing a reasonable approximation for counterfactual tokens.
 
 [^f3]: The $\beta$ parameter then controls the generalized Jensen-Shannon divergence between the student (S) and teacher (T) distributions, calculated via the following loss summed over the sequence and averaged over the batch: $$\mathcal{D}_{\text{JSD}(\beta)}(p_S, p_T) = \beta \cdot D_{\text{KL}}(p_S \| \pi) + (1-\beta) \cdot D_{\text{KL}}(p_T \| \pi)$$ where $\pi = \beta \cdot p_S + (1-\beta) \cdot p_T$.
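The product-rule token merge the commit describes can be sketched in a few lines of Python. This is a minimal illustration with a toy vocabulary and made-up probabilities, not the GOLD implementation: the function name and inputs are assumptions for the example.

```python
def merge_token_steps(first_step_dist, continuation_probs):
    """Merge several autoregressive steps into one aligned step.

    first_step_dist: marginal distribution over the vocabulary at the first
        position of the merged span, i.e. P(token | x).
    continuation_probs: scalar conditional probabilities of the tokens that
        were actually generated at the remaining positions of the span.

    Each entry of the result approximates the joint probability of that
    token followed by the actual continuation (the product rule).
    """
    merged = list(first_step_dist)
    for p in continuation_probs:
        merged = [q * p for q in merged]  # scalar multiply; vector shape unchanged
    return merged

# Toy example: merge the two steps "Hugging" + " Face" into one step so the
# sequence lines up with a tokenizer that emits "HuggingFace" as a single token.
vocab = ["Hugging", "Hug", "Hi"]
dist_first = [0.7, 0.2, 0.1]      # P(token | x) at the first merged position
p_face_given_hugging = 0.9        # P(" Face" | "Hugging", x), a scalar

merged = merge_token_steps(dist_first, [p_face_given_hugging])
# merged[0] ≈ 0.63 = P("Hugging" | x) * P(" Face" | "Hugging", x)
```

Because the continuation probabilities are scalars, the merged vector keeps the vocabulary dimension intact, so it can still be compared against the other tokenizer's single-step distribution.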
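The generalized Jensen-Shannon divergence from footnote [^f3] can likewise be written out directly for a single vocabulary position. A small sketch with toy distributions and hypothetical helper names (in the full loss this term is summed over the sequence and averaged over the batch):

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd_beta(p_s, p_t, beta):
    """Generalized JSD(beta) with mixture pi = beta * p_S + (1 - beta) * p_T."""
    mix = [beta * ps + (1 - beta) * pt for ps, pt in zip(p_s, p_t)]
    return beta * kl(p_s, mix) + (1 - beta) * kl(p_t, mix)

p_student = [0.6, 0.3, 0.1]
p_teacher = [0.4, 0.4, 0.2]

d = jsd_beta(p_student, p_teacher, beta=0.5)  # symmetric JSD at beta = 0.5
```

At `beta = 0.5` the divergence is symmetric in the two distributions; other values of `beta` weight the student and teacher KL terms differently.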