zwave committed on
Commit 4fa0566 · verified · 1 Parent(s): 58dd072

Upload README.md with huggingface_hub

  1. README.md +123 -1
README.md CHANGED
@@ -1,3 +1,125 @@
1
  ---
2
- license: apache-2.0
3
  ---
1
  ---
2
+ license: mit
3
+ tags:
4
+ - language-model
5
+ - multi-token-prediction
6
+ - push-forward-language-model
7
+ - text-generation
8
+ - distillation
9
+ datasets:
10
+ - lm1b
11
+ - openwebtext
12
+ arxiv: "2606.10820"
13
  ---
14
+
15
+ # K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2606.10820"><img src="https://img.shields.io/badge/arXiv-2606.10820-b31b1b.svg" alt="arXiv"></a>
19
+ <a href="https://github.com/Tangzw2020/K-Forcing"><img src="https://img.shields.io/badge/GitHub-Code-blue?logo=github" alt="GitHub"></a>
20
+ </p>
21
+
22
+ ## Overview
23
+
24
+ K-Forcing distills an autoregressive (AR) language model into a **push-forward language model (PFLM)** that generates **k tokens in one forward pass**. It maps k independent uniform noise variables to k future tokens jointly via an inverse-CDF construction, enabling fixed-length multi-token decoding that is fully compatible with standard KV-cache batch serving.
25
+
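To make the inverse-CDF idea in the overview concrete, here is a minimal sketch of how uniform noise can be turned into tokens deterministically. It is an illustration under stated assumptions, not the repository's implementation: `toy_conditional` is a hypothetical stand-in for a model's next-token distribution over a 3-token vocabulary, and the loop in `decode_k` is sequential, whereas the PFLM is described as producing all k tokens in a single forward pass.

```python
from bisect import bisect_right
from itertools import accumulate


def inverse_cdf_sample(probs, u):
    """Map u in [0, 1) to the token index whose CDF interval contains u."""
    cdf = list(accumulate(probs))
    # The clamp guards against floating-point round-off when u is close to 1.
    return min(bisect_right(cdf, u), len(probs) - 1)


def toy_conditional(prefix):
    """Hypothetical next-token distribution over a 3-token vocabulary."""
    if prefix and prefix[-1] == 0:
        return [0.1, 0.6, 0.3]
    return [0.5, 0.25, 0.25]


def decode_k(prefix, noise):
    """Turn k uniform noise values into k tokens, one inverse-CDF lookup each."""
    tokens = list(prefix)
    for u in noise:
        tokens.append(inverse_cdf_sample(toy_conditional(tokens), u))
    return tokens[len(prefix):]


if __name__ == "__main__":
    print(decode_k([], [0.2, 0.5, 0.9]))
```

Once the noise values are fixed, the mapping from noise to tokens is a deterministic function, which is the property a push-forward model can be trained to reproduce jointly for all k positions.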
26
+ **Key results**: ~2.4–3.5× higher batch-serving throughput, with modest quality degradation, on LM1B and OpenWebText using ~100M-parameter Transformers.
27
+
28
+ ## Checkpoints
29
+
30
+ This repository contains four checkpoints:
31
+
32
+ | File | Model | Dataset | Parameters | Description |
33
+ |------|-------|---------|------------|-------------|
34
+ | `ar_openwebtxt.ckpt` | AR | OpenWebText | ~100M | Autoregressive teacher model (GPT-2 tokenizer, seq_len=1024) |
35
+ | `ar_best_lm1b.ckpt` | AR | LM1B | ~100M | Autoregressive teacher model (custom tokenizer, seq_len=128) |
36
+ | `pflm_owt_k4.ckpt` | PFLM (k=4) | OpenWebText | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
37
+ | `pflm_lm1b_k4.ckpt` | PFLM (k=4) | LM1B | ~100M | Push-forward LM, decodes 4 tokens per forward pass |
38
+
39
+ All models share a 12-layer causal Transformer backbone (768 hidden dim, 12 heads), following the architecture from [MDLM](https://arxiv.org/abs/2406.07524) (Sahoo et al., 2024).
40
+
41
+ ## Download
42
+
43
+ ```python
44
+ from huggingface_hub import hf_hub_download
45
+
46
+ # Download a specific checkpoint
47
+ ckpt_path = hf_hub_download(
48
+ repo_id="zwave/K-Forcing",
49
+ filename="pflm_owt_k4.ckpt", # or: ar_openwebtxt.ckpt, ar_best_lm1b.ckpt, pflm_lm1b_k4.ckpt
50
+ )
51
+ ```
52
+
53
+ Or download all checkpoints at once:
54
+ ```python
55
+ from huggingface_hub import snapshot_download
56
+
57
+ snapshot_download(repo_id="zwave/K-Forcing", local_dir="./checkpoints")
58
+ ```
59
+
60
+ Or via CLI:
61
+ ```bash
62
+ huggingface-cli download zwave/K-Forcing --local-dir ./checkpoints
63
+ ```
64
+
65
+ ## Usage
66
+
67
+ Clone the [K-Forcing repository](https://github.com/Tangzw2020/K-Forcing) and follow the setup instructions there:
68
+
69
+ ```bash
70
+ git clone https://github.com/Tangzw2020/K-Forcing.git
71
+ cd K-Forcing
72
+
73
+ # Setup environment
74
+ mkdir -p wheels
75
+ wget -P wheels https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
76
+ uv sync
77
+ ```
78
+
79
+ ### AR Inference
80
+
81
+ ```bash
82
+ python batch_inference_with_prefix.py \
83
+ --model ar --task owt \
84
+ --ckpt_path ./checkpoints/ar_openwebtxt.ckpt \
85
+ --prefix_file assets/prefix_owt_examples.jsonl \
86
+ --batch_size 4 --n_per_prefix 1
87
+ ```
88
+
89
+ ### PFLM Inference (K=2 tokens per forward pass)
90
+
91
+ ```bash
92
+ python batch_inference_with_prefix.py \
93
+ --model pflm --task owt \
94
+ --ckpt_path ./checkpoints/pflm_owt_k4.ckpt \
95
+ --prefix_file assets/prefix_owt_examples.jsonl \
96
+ --batch_size 4 --n_per_prefix 1 --K 2 --freq_penalty 0.3
97
+ ```
98
+
99
+ A PFLM checkpoint trained with k=4 supports inference with any K ≤ 4, which is why the command above can pass `--K 2` to `pflm_owt_k4.ckpt`.
100
+
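As a rough guide to what the `--K` setting changes, the arithmetic below counts decoding steps for a given output length. `forward_passes` is a hypothetical helper written for this illustration and is not part of the K-Forcing repository; it assumes every step emits exactly K tokens and that any surplus tokens from the final step are discarded.

```python
def forward_passes(num_new_tokens: int, k: int) -> int:
    """Number of fixed-length decoding steps needed to emit num_new_tokens tokens."""
    if num_new_tokens < 0 or k < 1:
        raise ValueError("num_new_tokens must be >= 0 and k must be >= 1")
    # Ceiling division without going through floats.
    return -(-num_new_tokens // k)


if __name__ == "__main__":
    for k in (1, 2, 4):
        print(f"K={k}: {forward_passes(128, k)} forward passes for 128 tokens")
```

The pass count alone does not determine throughput: the speedup quoted in the overview (~2.4–3.5×) is a measured batch-serving figure, not the ratio of step counts.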
101
+ ## Architecture
102
+
103
+ - **Backbone**: 12-layer causal Transformer (~100M params), 768 hidden dim, 12 heads
104
+ - **Noise encoder**: sinusoidal + MLP, encodes each Uniform(0,1) noise variable into a token embedding
105
+ - **Fully causal design**: noise tokens attend causally — each zⱼ sees context + z₁..zⱼ
106
+ - **Shared prediction head**: same linear head as AR, applied at each noise-token position
107
+ - **Training**: progressive self-forcing distillation (AR → k=1 → k=2 → k=4)
108
+
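The "fully causal design" bullet above can be read as an ordinary causal attention pattern over the context tokens followed by the noise tokens. The sketch below spells out which positions each noise token may attend to. The labels `x1..xn` for context tokens and the function name `visible_positions` are assumptions made for this illustration; the actual masking code lives in the repository.

```python
def visible_positions(num_context: int, k: int) -> dict:
    """For each noise token z_j, list the positions it may attend to.

    Follows the rule stated above: z_j sees the whole context plus z_1..z_j.
    """
    context = [f"x{i}" for i in range(1, num_context + 1)]
    return {
        f"z{j}": context + [f"z{i}" for i in range(1, j + 1)]
        for j in range(1, k + 1)
    }


if __name__ == "__main__":
    for name, seen in visible_positions(num_context=3, k=4).items():
        print(name, "->", " ".join(seen))
```

Because no noise token looks ahead, the keys and values computed for earlier positions never change, which fits the overview's claim of compatibility with standard KV caching.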
109
+ ## Citation
110
+
111
+ ```bibtex
112
+ @misc{tang2026kforcingjointnextktokendecoding,
113
+ title={K-Forcing: Joint Next-K-Token Decoding via Push-Forward Language Modeling},
114
+ author={Zhiwei Tang and Yuanyu He and Yizheng Han and Wangbo Zhao and Jiasheng Tang and Fan Wang and Bohan Zhuang},
115
+ year={2026},
116
+ eprint={2606.10820},
117
+ archivePrefix={arXiv},
118
+ primaryClass={cs.LG},
119
+ url={https://arxiv.org/abs/2606.10820},
120
+ }
121
+ ```
122
+
123
+ ## License
124
+
125
+ This project is licensed under the MIT License.