---
library_name: transformers
pipeline_tag: text-generation
base_model: Qwen/Qwen3-8B-Base
tags:
  - qwen3
  - triton
  - kernel-generation
  - reinforcement-learning
  - code
datasets:
  - hkust-nlp/drkernel-coldstart-8k
  - hkust-nlp/drkernel-rl-data
  - hkust-nlp/drkernel-validation-data
---

# DR.Kernel-8B

[![Model](https://img.shields.io/badge/🤗%20Model-hkust--nlp/drkernel--8b-yellow)](https://huggingface.co/hkust-nlp/drkernel-8b)
[![Paper](https://img.shields.io/badge/arXiv-2602.05885-b31b1b)](https://arxiv.org/abs/2602.05885)

`hkust-nlp/drkernel-8b` is a Qwen3-8B-based model specialized for GPU kernel generation and optimization (especially Triton) in the DR.Kernel framework.

It is trained for iterative optimization driven by execution feedback from KernelGYM, not just for single-shot code generation.

## Model Summary

- Model type: `Qwen3ForCausalLM`
- Parameter count: `8,190,735,360` (from `model.safetensors.index.json`)
- Weight dtype: BF16
- Base model family: Qwen3-8B
- Main capability: generate and iteratively refine optimized `ModelNew` kernel implementations from PyTorch reference tasks
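
As a rough sanity check on the figures above, the parameter count and the BF16 dtype together imply the size of the raw weights. The sketch below is plain arithmetic and assumes every parameter is stored in BF16; it ignores activations, the KV cache, and framework overhead, so real GPU memory use will be higher.

```python
# Back-of-the-envelope weight footprint from the figures in the summary above.
PARAM_COUNT = 8_190_735_360   # from model.safetensors.index.json
BYTES_PER_PARAM = 2           # BF16 stores each parameter in 16 bits

weight_bytes = PARAM_COUNT * BYTES_PER_PARAM
weight_gib = weight_bytes / 2**30  # binary gibibytes

print(f"weights: {weight_bytes:,} bytes (~{weight_gib:.2f} GiB)")
```

This puts the weights alone at a little over 15 GiB, which is a useful lower bound when choosing a GPU for inference.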


## Training Recipe (DR.Kernel)

The 8B model follows the two-stage DR.Kernel pipeline:

1. Cold-start SFT
   - Dataset: `hkust-nlp/drkernel-coldstart-8k`
   - Multi-turn trajectory warm-up for kernel generation/refinement
2. Multi-turn RL
   - Train dataset: `hkust-nlp/drkernel-rl-data`
   - Validation dataset: `hkust-nlp/drkernel-validation-data` (KernelBench Level 2 validation split)
   - Core methods: TRLOO + MRS + PR + PRS
   - Execution environment: KernelGYM with compilation/correctness/performance/profiling feedback

Related training scripts in this repo:

- `drkernel/kernel/scripts/sft/8b-coldstart.sh`
- `drkernel/kernel/scripts/rl/8b_trloo_mrs_pr_prs.sh`

## Intended Use

- Kernel generation research and benchmarking
- Triton kernel optimization with iterative feedback
- Multi-turn agentic code refinement under execution-based reward

## Not Intended Use

- Safety-critical production deployment without additional verification
- General-purpose coding assistant use where kernel-evaluation feedback is unavailable

## Prompting Format

The model is trained with kernel-optimization prompts that:

- Provide a PyTorch reference architecture (`Model`, `get_inputs`, `get_init_inputs`)
- Require returning an optimized `ModelNew`
- In multi-turn settings, append server feedback and request iterative improvement

For best results, keep the task style used in the DR.Kernel datasets and pass prompts as chat-format messages.
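
The multi-turn pattern described above (append feedback, then ask for an improved kernel) can be sketched as a small chat-history helper. Everything here is an assumption made for illustration: `KernelFeedback`, `render_feedback`, `next_turn`, and the feedback wording are hypothetical and do not reproduce the actual KernelGYM feedback format.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

Message = Dict[str, str]


@dataclass
class KernelFeedback:
    """Hypothetical feedback record; the real KernelGYM schema may differ."""
    compiled: bool
    correct: bool
    speedup: Optional[float] = None  # reference time / candidate time, if measured


def render_feedback(fb: KernelFeedback) -> str:
    """Turn a feedback record into a user-turn message (placeholder wording)."""
    if not fb.compiled:
        status = "The kernel failed to compile."
    elif not fb.correct:
        status = "The kernel compiled but produced incorrect outputs."
    elif fb.speedup is None:
        status = "The kernel is correct; no timing was recorded."
    else:
        status = f"The kernel is correct with a {fb.speedup:.2f}x speedup."
    return status + " Please improve ModelNew and output the full code again."


def next_turn(messages: List[Message], reply: str, fb: KernelFeedback) -> List[Message]:
    """Return a new chat history with the model reply and feedback appended."""
    return messages + [
        {"role": "assistant", "content": reply},
        {"role": "user", "content": render_feedback(fb)},
    ]


history = [{"role": "user", "content": "<first-turn prompt>"}]
history = next_turn(history, "<model reply with ModelNew>", KernelFeedback(True, True, 1.37))
print([m["role"] for m in history])  # ['user', 'assistant', 'user']
```

The resulting list can be passed to `tokenizer.apply_chat_template` in the same way as a single-turn prompt.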

## Quick Start (Transformers)

The example below uses the fixed one-shot first-turn prompt template from the DR.Kernel data (recommended):

````python
import textwrap
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hkust-nlp/drkernel-8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.abs(x)
            x = x - 1.0
            return x

    def get_inputs():
        return [torch.randn(64, 128)]

    def get_init_inputs():
        return []
    """
).strip()

example_ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            return a + b

    def get_inputs():
        # randomly generate input tensors based on the model architecture
        a = torch.randn(1, 128).cuda()
        b = torch.randn(1, 128).cuda()
        return [a, b]

    def get_init_inputs():
        # randomly generate tensors required for initialization based on the model architecture
        return []
    """
).strip()

example_kernel_code = textwrap.dedent(
    '''
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(
        x_ptr,  # Pointer to first input
        y_ptr,  # Pointer to second input
        out_ptr,  # Pointer to output
        n_elements,  # Total number of elements in input/output
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program handles a contiguous block of data of size BLOCK_SIZE
        block_start = tl.program_id(0) * BLOCK_SIZE
        # Create a range of offsets [0..BLOCK_SIZE-1]
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask to ensure we don't go out of bounds
        mask = offsets < n_elements
        # Load input values
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        # Perform the elementwise addition
        out = x + y
        # Store the result
        tl.store(out_ptr + offsets, out, mask=mask)

    def triton_add(x: torch.Tensor, y: torch.Tensor):
        """
        This function wraps the Triton kernel call. It:
          1. Ensures the inputs are contiguous on GPU.
          2. Calculates the grid (blocks) needed.
          3. Launches the Triton kernel.
        """
        assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
        x = x.contiguous()
        y = y.contiguous()

        # Prepare output tensor
        out = torch.empty_like(x)

        # Number of elements in the tensor
        n_elements = x.numel()
        BLOCK_SIZE = 128  # Tunable parameter for block size

        # Determine the number of blocks needed
        grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

        # Launch the Triton kernel
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        return out

    class ModelNew(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            # Instead of "return a + b", call our Triton-based addition
            return triton_add(a, b)
    '''
).strip()

prompt_template = textwrap.dedent(
    """\
    You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.

    You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.

    Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is:

    ```python
    {example_ref_code}
    ```

    The example new arch with custom Triton kernels looks like this:

    ```python
    {example_kernel_code}
    ```

    You are given the following architecture:
    ```python
    {ref_code}
    ```

    Optimize the architecture named Model with custom Triton operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Let's think step by step.
    """
).strip()

prompt = prompt_template.format(
    example_ref_code=example_ref_code,
    example_kernel_code=example_kernel_code,
    ref_code=ref_code,
)
messages = [{"role": "user", "content": prompt}]

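# return_dict=True yields input_ids and attention_mask together, so generate()
# receives an explicit attention mask.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
    )

# Decode only the newly generated tokens (everything after the prompt)
prompt_len = inputs["input_ids"].shape[-1]
print(tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False))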
````
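
Because the prompt asks the model to return its code in code blocks, the generated text usually needs a post-processing step before it can be executed or graded. The following is a minimal sketch of one way to do that; `extract_model_new` is a hypothetical helper, not part of the DR.Kernel codebase, and it assumes the reply uses standard Markdown fences.

```python
import re
from typing import Optional

FENCE = "`" * 3  # built programmatically so this snippet can sit inside a Markdown fence
CODE_BLOCK_RE = re.compile(FENCE + r"(?:python|py)?[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_model_new(text: str) -> Optional[str]:
    """Return the last fenced code block that defines ModelNew, or None."""
    blocks = CODE_BLOCK_RE.findall(text)
    for block in reversed(blocks):
        if re.search(r"^\s*class\s+ModelNew\b", block, re.MULTILINE):
            return block.strip()
    return None


sample = (
    "Let's think step by step.\n"
    + FENCE + "python\n"
    + "class ModelNew:\n    pass\n"
    + FENCE + "\nDone."
)
print(extract_model_new(sample))
```

Taking the last matching block is a design choice: when a reply contains several drafts, the final one is most likely the intended answer. Extracted code should still be compiled and checked against the reference `Model` before it is trusted.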

## Evaluation

Use KernelGYM-based evaluation scripts in this repo:

- `drkernel/kernel/scripts/eval/drkernel-14b-maxturns3.sh` (named for the 14B model; set the model path to `hkust-nlp/drkernel-8b`)
- `drkernel/kernel/scripts/eval/grading_common.sh` for custom evaluation runs

Validation data:

- `hkust-nlp/drkernel-validation-data` (KernelBench Level 2 validation tasks)
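
For orientation, KernelBench summarizes results with a `fast_p` metric: the fraction of tasks whose generated kernel is both correct and faster than the PyTorch reference by more than a factor `p`. The sketch below illustrates that definition only; whether the scripts listed above report exactly this number is an assumption, and the sample values are placeholders, not measured results.

```python
from typing import List, Tuple

# (correct, speedup) per task; speedup = reference time / candidate time.
Result = Tuple[bool, float]


def fast_p(results: List[Result], p: float) -> float:
    """Fraction of tasks that are correct AND faster than the reference by more than p."""
    if not results:
        return 0.0
    hits = sum(1 for correct, speedup in results if correct and speedup > p)
    return hits / len(results)


# Placeholder numbers for illustration only, not measured results.
toy_results = [(True, 1.4), (True, 0.9), (False, 2.0), (True, 1.0)]
print(fast_p(toy_results, 0.0))  # share of tasks with a correct kernel
print(fast_p(toy_results, 1.0))  # share that is correct and strictly faster
```

Note that an incorrect kernel never counts, however fast it is, which is why the third toy entry contributes nothing at any threshold.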

## Data and Attribution

- Query/task source includes:
  - [ByteDance-Seed/cudaLLM-data](https://huggingface.co/datasets/ByteDance-Seed/cudaLLM-data)
- SFT cold-start trajectories:
  - [hkust-nlp/drkernel-coldstart-8k](https://huggingface.co/datasets/hkust-nlp/drkernel-coldstart-8k)
- RL train data:
  - [hkust-nlp/drkernel-rl-data](https://huggingface.co/datasets/hkust-nlp/drkernel-rl-data)
- Validation/eval data:
  - [hkust-nlp/drkernel-validation-data](https://huggingface.co/datasets/hkust-nlp/drkernel-validation-data)
- Benchmark source:
  - [KernelBench](https://github.com/ScalingIntelligence/KernelBench)

Please acknowledge original dataset/benchmark authors when using this model.

## Related Resources

- Paper: [Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations](https://arxiv.org/abs/2602.05885)
- Codebase: [KernelGYM](https://github.com/hkust-nlp/KernelGYM)
- Training docs: `drkernel/README.md`

## Citation

```bibtex
@article{liuetal2026,
  title={Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations},
  author={Wei Liu and Jiawei Xu and Yingru Li and Longtao Zheng and Tianjian Li and Qian Liu and Junxian He},
  journal={arXiv preprint arXiv:2602.05885},
  year={2026}
}
```