Commit e69ea4f by PeterV09 · verified · 1 Parent(s): 978cf0e

upload ckpt

Files changed (1)
  1. README.md +280 -0
README.md ADDED
@@ -0,0 +1,280 @@
+ ---
+ library_name: transformers
+ pipeline_tag: text-generation
+ base_model: Qwen/Qwen3-14B-Base
+ tags:
+ - qwen3
+ - triton
+ - kernel-generation
+ - supervised-finetuning
+ - cold-start
+ - code
+ datasets:
+ - hkust-nlp/drkernel-coldstart-8k
+ ---
+
+ # DR.Kernel-14B-ColdStart
+
+ [![Model](https://img.shields.io/badge/🤗%20Model-hkust--nlp/drkernel--14b--coldstart-yellow)](https://huggingface.co/hkust-nlp/drkernel-14b-coldstart)
+ [![Paper](https://img.shields.io/badge/arXiv-2602.05885-b31b1b)](https://arxiv.org/abs/2602.05885)
+
+ `hkust-nlp/drkernel-14b-coldstart` is the **cold-start SFT checkpoint** for DR.Kernel.
+
+ This model is trained on multi-turn SFT data only and is intended as the initialization checkpoint for the subsequent RL stage (TRLOO/MRS/PR/PRS).
+
+ ## Model Summary
+
+ - Model type: `Qwen3ForCausalLM`
+ - Base model family: Qwen3-14B (`Qwen/Qwen3-14B-Base`)
+ - Stage: cold-start supervised fine-tuning (before RL)
+ - Main capability: structured kernel-optimization responses (`Model` -> `ModelNew`) in the DR.Kernel prompt format
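The `Model` -> `ModelNew` convention can be checked structurally before any kernel is compiled or run. The following is a minimal sketch, not part of the DR.Kernel codebase: the helper name `defines_model_new` is hypothetical, and it only verifies that a candidate source string parses and defines a top-level `ModelNew` class with a `forward` method. It says nothing about correctness or speed.

```python
import ast


def defines_model_new(source: str) -> bool:
    """Return True if `source` parses and defines a top-level class
    named `ModelNew` that has a `forward` method (hypothetical helper)."""
    try:
        tree = ast.parse(source)
    except SyntaxError:
        # Code that does not parse cannot satisfy the convention.
        return False
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == "ModelNew":
            return any(
                isinstance(item, ast.FunctionDef) and item.name == "forward"
                for item in node.body
            )
    return False
```

A check like this is cheap enough to run on every sampled response before handing the code to a GPU-backed evaluator.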
+
+ ## Training Stage
+
+ This checkpoint corresponds to:
+
+ 1. Cold-start SFT only
+    - Dataset: `hkust-nlp/drkernel-coldstart-8k`
+    - Multi-turn trajectories to teach kernel-generation/refinement behavior
+
+ Not included in this checkpoint:
+
+ - RL stage (TRLOO + MRS + PR + PRS)
+ - RL reward shaping / rejection sampling updates
+
+ Related script:
+
+ - `drkernel/kernel/scripts/sft/14b-coldstart.sh`
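Because the SFT data consists of multi-turn trajectories, a refinement round can be represented by appending the model's previous answer and a follow-up user turn to the chat history. The sketch below shows only that bookkeeping. The helper `add_refinement_turn` and the placeholder strings are hypothetical, and the exact feedback wording used in `hkust-nlp/drkernel-coldstart-8k` is not described in this README, so treat the second-turn content as an assumption.

```python
from typing import Dict, List

Message = Dict[str, str]


def add_refinement_turn(
    messages: List[Message], assistant_reply: str, feedback: str
) -> List[Message]:
    """Return a new chat history extended with the model's reply and a
    follow-up user turn; the input list is left unmodified."""
    return messages + [
        {"role": "assistant", "content": assistant_reply},
        {"role": "user", "content": feedback},
    ]


# Placeholder contents: the real first turn is the 1-shot prompt, and the
# feedback format here is an assumption, not the dataset's actual wording.
first_turn = [{"role": "user", "content": "<first-turn prompt>"}]
second_turn = add_refinement_turn(
    first_turn,
    assistant_reply="<previous ModelNew code>",
    feedback="<hypothetical feedback, e.g. a compile error or a timing result>",
)
```

The resulting list can be passed to `tokenizer.apply_chat_template` in the same way as a single-turn `messages` list.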
+
+ ## Intended Use
+
+ - As an initialization checkpoint for DR.Kernel RL training
+ - As a strong SFT baseline for kernel generation
+ - For ablations comparing cold-start vs post-RL checkpoints
+
+ ## Not Intended Use
+
+ - Final performance claims for DR.Kernel RL results
+ - Safety-critical production deployment without additional verification
+
+ ## Quick Start (Transformers)
+
+ Use the same fixed 1-shot first-turn prompt template as the DR.Kernel data (recommended):
+
+ ````python
+ import textwrap
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_id = "hkust-nlp/drkernel-14b-coldstart"
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ ref_code = textwrap.dedent(
+     """
+     import torch
+     import torch.nn as nn
+
+     class Model(nn.Module):
+         def __init__(self):
+             super().__init__()
+
+         def forward(self, x):
+             x = torch.abs(x)
+             x = x - 1.0
+             return x
+
+     def get_inputs():
+         return [torch.randn(64, 128)]
+
+     def get_init_inputs():
+         return []
+     """
+ ).strip()
+
+ example_ref_code = textwrap.dedent(
+     """
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+
+     class Model(nn.Module):
+         def __init__(self) -> None:
+             super().__init__()
+
+         def forward(self, a, b):
+             return a + b
+
+     def get_inputs():
+         # randomly generate input tensors based on the model architecture
+         a = torch.randn(1, 128).cuda()
+         b = torch.randn(1, 128).cuda()
+         return [a, b]
+
+     def get_init_inputs():
+         # randomly generate tensors required for initialization based on the model architecture
+         return []
+     """
+ ).strip()
+
+ example_kernel_code = textwrap.dedent(
+     '''
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     import triton
+     import triton.language as tl
+
+     @triton.jit
+     def add_kernel(
+         x_ptr,  # Pointer to first input
+         y_ptr,  # Pointer to second input
+         out_ptr,  # Pointer to output
+         n_elements,  # Total number of elements in input/output
+         BLOCK_SIZE: tl.constexpr,
+     ):
+         # Each program handles a contiguous block of data of size BLOCK_SIZE
+         block_start = tl.program_id(0) * BLOCK_SIZE
+         # Create a range of offsets [0..BLOCK_SIZE-1]
+         offsets = block_start + tl.arange(0, BLOCK_SIZE)
+         # Mask to ensure we don't go out of bounds
+         mask = offsets < n_elements
+         # Load input values
+         x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+         y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
+         # Perform the elementwise addition
+         out = x + y
+         # Store the result
+         tl.store(out_ptr + offsets, out, mask=mask)
+
+     def triton_add(x: torch.Tensor, y: torch.Tensor):
+         """
+         This function wraps the Triton kernel call. It:
+         1. Ensures the inputs are contiguous on GPU.
+         2. Calculates the grid (blocks) needed.
+         3. Launches the Triton kernel.
+         """
+         assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
+         x = x.contiguous()
+         y = y.contiguous()
+
+         # Prepare output tensor
+         out = torch.empty_like(x)
+
+         # Number of elements in the tensor
+         n_elements = x.numel()
+         BLOCK_SIZE = 128  # Tunable parameter for block size
+
+         # Determine the number of blocks needed
+         grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
+
+         # Launch the Triton kernel
+         add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+         return out
+
+     class ModelNew(nn.Module):
+         def __init__(self) -> None:
+             super().__init__()
+
+         def forward(self, a, b):
+             # Instead of "return a + b", call our Triton-based addition
+             return triton_add(a, b)
+     '''
+ ).strip()
+
+ prompt_template = textwrap.dedent(
+     """\
+     You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.
+
+     You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.
+
+     Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is:
+
+     ```python
+     {example_ref_code}
+     ```
+
+     The example new arch with custom Triton kernels looks like this:
+
+     ```python
+     {example_kernel_code}
+     ```
+
+     You are given the following architecture:
+     ```python
+     {ref_code}
+     ```
+
+     Optimize the architecture named Model with custom Triton operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Let's think step by step.
+     """
+ ).strip()
+
+ prompt = prompt_template.format(
+     example_ref_code=example_ref_code,
+     example_kernel_code=example_kernel_code,
+     ref_code=ref_code,
+ )
+ messages = [{"role": "user", "content": prompt}]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt",
+ ).to(model.device)
+
+ with torch.no_grad():
+     outputs = model.generate(
+         inputs,
+         max_new_tokens=2048,
+         do_sample=True,
+         temperature=1.0,
+         top_p=1.0,
+     )
+
+ # Only print newly generated tokens
+ print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=False))
+ ````
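The prompt asks the model to "think step by step" and to output the new code in codeblocks, so the decoded text usually mixes reasoning with fenced code. A minimal sketch of pulling out the final fenced block is given below. The helper `extract_last_code_block` is hypothetical (it is not taken from the DR.Kernel or KernelGYM code), and it assumes the answer uses ordinary triple-backtick fences, optionally tagged `python`.

```python
import re
from typing import Optional

# Build the fence programmatically so this snippet can itself live inside
# a fenced block without terminating it.
_FENCE = "`" * 3
_CODE_BLOCK = re.compile(
    _FENCE + r"(?:python)?[ \t]*\n(.*?)" + _FENCE,
    re.DOTALL,
)


def extract_last_code_block(response: str) -> Optional[str]:
    """Return the contents of the last fenced code block in `response`,
    or None if the response contains no complete fenced block."""
    blocks = _CODE_BLOCK.findall(response)
    return blocks[-1].strip() if blocks else None
```

Taking the last block rather than the first is a design choice under the assumption that earlier blocks, if any, are intermediate drafts produced while reasoning.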
+
+ ## Continue to RL Training
+
+ This checkpoint is the intended starting point for RL training:
+
+ - Script: `drkernel/kernel/scripts/rl/14b_trloo_mrs_pr_prs.sh`
+ - Typical model setting: `MODEL_PATH="hkust-nlp/drkernel-14b-coldstart"` (or a local path)
+ - RL datasets:
+   - `hkust-nlp/drkernel-rl-data`
+   - `hkust-nlp/drkernel-validation-data`
+
+ ## Data and Attribution
+
+ - Cold-start SFT data:
+   - [hkust-nlp/drkernel-coldstart-8k](https://huggingface.co/datasets/hkust-nlp/drkernel-coldstart-8k)
+ - Query/task source includes:
+   - [ByteDance-Seed/cudaLLM-data](https://huggingface.co/datasets/ByteDance-Seed/cudaLLM-data)
+ - Benchmark source:
+   - [KernelBench](https://github.com/ScalingIntelligence/KernelBench)
+
+ Please acknowledge the original dataset and benchmark authors when using this model.
+
+ ## Related Resources
+
+ - Final RL model: [hkust-nlp/drkernel-14b](https://huggingface.co/hkust-nlp/drkernel-14b)
+ - Paper: [Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations](https://arxiv.org/abs/2602.05885)
+ - Codebase: [KernelGYM](https://github.com/hkust-nlp/KernelGYM)
+ - Training docs: `drkernel/README.md`
+
+ ## Citation
+
+ ```bibtex
+ @article{liuetal2026,
+   title={Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations},
+   author={Wei Liu and Jiawei Xu and Yingru Li and Longtao Zheng and Tianjian Li and Qian Liu and Junxian He},
+   journal={arXiv:2602.05885},
+   year={2026}
+ }
+ ```