---
library_name: transformers
pipeline_tag: text-generation
base_model: Qwen/Qwen3-8B-Base
tags:
- qwen3
- triton
- kernel-generation
- supervised-finetuning
- cold-start
- code
datasets:
- hkust-nlp/drkernel-coldstart-8k
---

# DR.Kernel-8B-ColdStart

[![Model](https://img.shields.io/badge/🤗%20Model-hkust--nlp/drkernel--8b--coldstart-yellow)](https://huggingface.co/hkust-nlp/drkernel-8b-coldstart)
[![Paper](https://img.shields.io/badge/arXiv-2602.05885-b31b1b)](https://arxiv.org/abs/2602.05885)

`hkust-nlp/drkernel-8b-coldstart` is the **cold-start SFT checkpoint** for DR.Kernel. It is trained on multi-turn SFT data only and is intended as the initialization checkpoint for the RL stage (TRLOO/MRS/PR/PRS).

## Model Summary

- Model type: `Qwen3ForCausalLM`
- Base model family: Qwen3-8B
- Stage: cold-start supervised fine-tuning (before RL)
- Main capability: structured kernel-optimization responses (`Model` -> `ModelNew`) in the DR.Kernel prompt format

## Training Stage

This checkpoint corresponds to:

1. Cold-start SFT only
   - Dataset: `hkust-nlp/drkernel-coldstart-8k`
   - Multi-turn trajectories that teach kernel generation and refinement behavior

Not included in this checkpoint:

- RL stage (TRLOO + MRS + PR + PRS)
- RL reward shaping / rejection sampling updates

Related script:

- `drkernel/kernel/scripts/sft/8b-coldstart.sh`

## Intended Use

- As an initialization checkpoint for DR.Kernel RL training
- As a strong SFT baseline for kernel generation
- For ablations comparing cold-start vs. post-RL checkpoints

## Out-of-Scope Use

- Final performance claims about DR.Kernel RL results
- Safety-critical production deployment without additional verification

## Quick Start (Transformers)

We recommend using the same fixed one-shot, first-turn prompt template as the DR.Kernel data:

````python
import textwrap

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hkust-nlp/drkernel-8b-coldstart"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = torch.abs(x)
            x = x - 1.0
            return x

    def get_inputs():
        return [torch.randn(64, 128)]

    def get_init_inputs():
        return []
    """
).strip()

example_ref_code = textwrap.dedent(
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            return a + b

    def get_inputs():
        # randomly generate input tensors based on the model architecture
        a = torch.randn(1, 128).cuda()
        b = torch.randn(1, 128).cuda()
        return [a, b]

    def get_init_inputs():
        # randomly generate tensors required for initialization based on the model architecture
        return []
    """
).strip()

example_kernel_code = textwrap.dedent(
    '''
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import triton
    import triton.language as tl


    @triton.jit
    def add_kernel(
        x_ptr,  # Pointer to first input
        y_ptr,  # Pointer to second input
        out_ptr,  # Pointer to output
        n_elements,  # Total number of elements in input/output
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program handles a contiguous block of data of size BLOCK_SIZE
        block_start = tl.program_id(0) * BLOCK_SIZE
        # Create a range of offsets [0..BLOCK_SIZE-1]
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask to ensure we don't go out of bounds
        mask = offsets < n_elements
        # Load input values
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        # Perform the elementwise addition
        out = x + y
        # Store the result
        tl.store(out_ptr + offsets, out, mask=mask)


    def triton_add(x: torch.Tensor, y: torch.Tensor):
        """
        This function wraps the Triton kernel call. It:
          1. Ensures the inputs are contiguous on GPU.
          2. Calculates the grid (blocks) needed.
          3. Launches the Triton kernel.
        """
        assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
        x = x.contiguous()
        y = y.contiguous()

        # Prepare output tensor
        out = torch.empty_like(x)

        # Number of elements in the tensor
        n_elements = x.numel()
        BLOCK_SIZE = 128  # Tunable parameter for block size

        # Determine the number of blocks needed
        grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

        # Launch the Triton kernel
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        return out


    class ModelNew(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            # Instead of "return a + b", call our Triton-based addition
            return triton_add(a, b)
    '''
).strip()

prompt_template = textwrap.dedent(
    """\
    You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.

    You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.

    Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is:

    ```python
    {example_ref_code}
    ```

    The example new arch with custom Triton kernels looks like this:

    ```python
    {example_kernel_code}
    ```

    You are given the following architecture:

    ```python
    {ref_code}
    ```

    Optimize the architecture named Model with custom Triton operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Let's think step by step.
    """
).strip()

prompt = prompt_template.format(
    example_ref_code=example_ref_code,
    example_kernel_code=example_kernel_code,
    ref_code=ref_code,
)

messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        inputs,
        max_new_tokens=2048,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
    )

# Only print newly generated tokens
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=False))
````

## Continue to RL Training

This checkpoint is intended as the starting point for RL training:

- Script: `drkernel/kernel/scripts/rl/8b_trloo_mrs_pr_prs.sh`
- Typical model setting: `MODEL_PATH="hkust-nlp/drkernel-8b-coldstart"` (or a local path)
- RL datasets:
  - `hkust-nlp/drkernel-rl-data`
  - `hkust-nlp/drkernel-validation-data`

## Data and Attribution

- Cold-start SFT data:
  - [hkust-nlp/drkernel-coldstart-8k](https://huggingface.co/datasets/hkust-nlp/drkernel-coldstart-8k)
- Query/task sources include:
  - [ByteDance-Seed/cudaLLM-data](https://huggingface.co/datasets/ByteDance-Seed/cudaLLM-data)
- Benchmark source:
  - [KernelBench](https://github.com/ScalingIntelligence/KernelBench)

Please acknowledge the original dataset and benchmark authors when using this model.

## Related Resources

- Final RL model: [hkust-nlp/drkernel-8b](https://huggingface.co/hkust-nlp/drkernel-8b)
- Paper: [Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations](https://arxiv.org/abs/2602.05885)
- Codebase: [KernelGYM](https://github.com/hkust-nlp/KernelGYM)
- Training docs: `drkernel/README.md`

## Citation

```bibtex
@article{liuetal2026,
  title={Dr.Kernel: Reinforcement Learning Done Right for Triton Kernel Generations},
  author={Wei Liu and Jiawei Xu and Yingru Li and Longtao Zheng and Tianjian Li and Qian Liu and Junxian He},
  journal={arXiv:2602.05885},
  year={2026}
}
```
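## Extracting the Generated Code (Sketch)

The Quick Start prompt asks the model to "Output the new code in codeblocks". A decoded response therefore typically mixes step-by-step reasoning with one or more fenced code blocks, and the `ModelNew` submission has to be pulled out before it can be compiled or benchmarked. The snippet below is a minimal, hypothetical sketch of that post-processing step. It is not the parser used by DR.Kernel or KernelGYM. It assumes the final answer is the last block fenced as `python` (or `py`). The names `extract_last_python_block` and `defines_model_new` are illustrative and are introduced only here.

````python
import re
from typing import Optional

# Hypothetical pattern: a fence opened with ```python or ```py, closed by the next ```.
# Assumes the model uses standard triple-backtick fences with a python/py tag.
_FENCE_RE = re.compile(r"```(?:python|py)[ \t]*\n(.*?)```", re.DOTALL)


def extract_last_python_block(response: str) -> Optional[str]:
    """Return the body of the last python-tagged fenced block, or None.

    Hypothetical helper: the model card does not prescribe how responses are parsed.
    The last block is taken because earlier blocks may be drafts or quoted code.
    """
    blocks = _FENCE_RE.findall(response)
    if not blocks:
        return None
    return blocks[-1].strip()


def defines_model_new(code: str) -> bool:
    """Cheap sanity check that the code declares a top-level `ModelNew` class."""
    return re.search(r"^class\s+ModelNew\s*\(", code, re.MULTILINE) is not None
````

After running the Quick Start, for example, `extract_last_python_block(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))` returns the candidate code string, or `None` if no tagged block was produced. The `defines_model_new` check only catches malformed answers early. Correctness and speedup still have to be measured by executing the code against the reference `Model` on a GPU.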