Infatoshi committed on
Commit 367dc36 · verified · 1 Parent(s): 4f641e6

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +51 -0
  2. kernrl_grpo_training.ipynb +621 -0
  3. train_kernrl.py +452 -0
README.md ADDED
@@ -0,0 +1,51 @@
+ ---
+ license: bsd-3-clause
+ tags:
+ - openenv
+ - cuda
+ - triton
+ - gpu-kernels
+ - reinforcement-learning
+ - grpo
+ ---
+
+ # kernrl Training Materials
+
+ Training resources for the kernrl GPU kernel optimization environment.
+
+ ## Overview
+
+ This repository contains:
+ - GRPO training notebook for training LLMs to write optimized GPU kernels
+ - Example scripts and configurations
+
+ ## Quick Start
+
+ ```python
+ from trl import GRPOConfig, GRPOTrainer
+ from kernrl import kernrl_env, KernelAction
+
+ # Connect to kernrl environment
+ env = kernrl_env(base_url="http://localhost:8000")
+
+ # Train with GRPO
+ trainer = GRPOTrainer(
+     model="Qwen/Qwen2.5-Coder-1.5B-Instruct",
+     reward_funcs=[reward_compilation, reward_correctness, reward_speedup],
+     train_dataset=dataset,
+     rollout_func=rollout_func,
+     args=GRPOConfig(use_vllm=True, vllm_mode="colocate"),
+ )
+ trainer.train()
+ ```
+
+ ## Files
+
+ - `kernrl_grpo_training.ipynb` - Complete GRPO training notebook
+ - `train_kernrl.py` - Standalone training script
+
+ ## Links
+
+ - [kernrl Environment](https://huggingface.co/spaces/Infatoshi/kernrl)
+ - [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)
+ - [TRL Documentation](https://huggingface.co/docs/trl)
kernrl_grpo_training.ipynb ADDED
@@ -0,0 +1,621 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "91b7681f",
+    "metadata": {},
+    "source": [
+     "# Training LLMs to Write Fast GPU Kernels with GRPO\n",
+     "\n",
+     "This notebook demonstrates how to train a language model to write optimized CUDA/Triton\n",
+     "kernels using TRL's GRPOTrainer and the kernrl OpenEnv environment.\n",
+     "\n",
+     "**What is kernrl?**\n",
+     "- An RL environment for GPU kernel optimization\n",
+     "- Agents receive PyTorch reference implementations\n",
+     "- Must write faster CUDA/Triton kernels that produce correct outputs\n",
+     "- Rewards based on compilation success, correctness, and speedup\n",
+     "\n",
+     "**What is GRPO?**\n",
+     "- Group Relative Policy Optimization\n",
+     "- Efficient RL algorithm for training LLMs\n",
+     "- Uses multiple generations per prompt to estimate advantages\n",
+     "- Works well with environment-based reward signals"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "1c818c9f",
+    "metadata": {},
+    "source": [
+     "## Installation\n",
+     "\n",
+     "First, install the required packages:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "03a24248",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "!pip install torch triton trl transformers accelerate\n",
+     "!pip install git+https://github.com/meta-pytorch/OpenEnv.git"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "a6bd7b19",
+    "metadata": {},
+    "source": [
+     "## Setup\n",
+     "\n",
+     "Import necessary libraries and configure the environment."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "409d8ec7",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from datasets import Dataset\n",
+     "from transformers import AutoTokenizer\n",
+     "from trl import GRPOConfig, GRPOTrainer\n",
+     "from trl.experimental.openenv import generate_rollout_completions\n",
+     "\n",
+     "# Import kernrl environment\n",
+     "from kernrl import kernrl_env, KernelAction, KernelObservation"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "1195d838",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Configuration\n",
+     "MODEL_ID = \"Qwen/Qwen2.5-Coder-1.5B-Instruct\"  # Good for code generation\n",
+     "ENV_URL = \"http://localhost:8000\"  # kernrl server URL\n",
+     "\n",
+     "# Initialize tokenizer\n",
+     "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
+     "if tokenizer.pad_token is None:\n",
+     "    tokenizer.pad_token = tokenizer.eos_token"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "0ba43b24",
+    "metadata": {},
+    "source": [
+     "## Connect to kernrl Environment\n",
+     "\n",
+     "The kernrl environment evaluates submitted kernels for:\n",
+     "1. **Compilation**: Does the code compile?\n",
+     "2. **Correctness**: Does output match reference (within tolerance)?\n",
+     "3. **Performance**: Is it faster than the PyTorch baseline?"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "d72ae756",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Connect to the kernrl server\n",
+     "# Option 1: Connect to running server\n",
+     "env = kernrl_env(base_url=ENV_URL)\n",
+     "\n",
+     "# Option 2: Load from HuggingFace Hub (requires GPU)\n",
+     "# env = kernrl_env.from_hub(\"Infatoshi/kernrl\")\n",
+     "\n",
+     "# Option 3: Local Docker\n",
+     "# env = kernrl_env.from_docker_image(\"kernrl:latest\")\n",
+     "\n",
+     "# Test the connection\n",
+     "obs = env.reset(problem_id=\"L1_23_Softmax\")\n",
+     "print(f\"Problem: {obs.problem_id}\")\n",
+     "print(f\"GPU: {obs.gpu_info}\")\n",
+     "print(f\"Max turns: {obs.max_turns}\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "004905fc",
+    "metadata": {},
+    "source": [
+     "## Reward Functions\n",
+     "\n",
+     "We define multiple reward signals to guide the model:\n",
+     "- **Compilation reward**: +0.1 for successful compilation\n",
+     "- **Correctness reward**: +0.3 for matching reference output\n",
+     "- **Speedup reward**: Scaled reward for beating baseline performance"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "39237d0e",
+    "metadata": {
+     "lines_to_next_cell": 1
+    },
+    "outputs": [],
+    "source": [
+     "import math\n",
+     "\n",
+     "def reward_compilation(completions: list[str], **kwargs) -> list[float]:\n",
+     "    \"\"\"Reward for successful compilation.\"\"\"\n",
+     "    compilation_success = kwargs.get(\"compilation_success\", [])\n",
+     "    return [0.1 if success else 0.0 for success in compilation_success]\n",
+     "\n",
+     "def reward_correctness(completions: list[str], **kwargs) -> list[float]:\n",
+     "    \"\"\"Reward for correct output.\"\"\"\n",
+     "    correctness_pass = kwargs.get(\"correctness_pass\", [])\n",
+     "    return [0.3 if correct else 0.0 for correct in correctness_pass]\n",
+     "\n",
+     "def reward_speedup(completions: list[str], **kwargs) -> list[float]:\n",
+     "    \"\"\"Reward scaled by speedup achieved.\"\"\"\n",
+     "    speedups = kwargs.get(\"speedup\", [])\n",
+     "    rewards = []\n",
+     "    for speedup in speedups:\n",
+     "        if speedup is None or speedup <= 0:\n",
+     "            rewards.append(0.0)\n",
+     "        elif speedup <= 1.0:\n",
+     "            # Below baseline: small penalty\n",
+     "            rewards.append(-0.1)\n",
+     "        else:\n",
+     "            # Above baseline: 0.3 base plus a log2-scaled bonus capped at 0.6\n",
+     "            # 2x speedup = 0.6, 4x = 0.9, beyond 4x stays capped at 0.9\n",
+     "            bonus = min(0.3 * math.log2(speedup), 0.6)\n",
+     "            rewards.append(0.3 + bonus)\n",
+     "    return rewards\n",
+     "\n",
+     "def reward_combined(completions: list[str], **kwargs) -> list[float]:\n",
+     "    \"\"\"Combined reward from all signals.\"\"\"\n",
+     "    comp_rewards = reward_compilation(completions, **kwargs)\n",
+     "    corr_rewards = reward_correctness(completions, **kwargs)\n",
+     "    speed_rewards = reward_speedup(completions, **kwargs)\n",
+     "    return [c + r + s for c, r, s in zip(comp_rewards, corr_rewards, speed_rewards)]"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "53307241",
+    "metadata": {},
+    "source": [
+     "## System Prompt\n",
+     "\n",
+     "The system prompt provides context about the task and expected output format."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "21d75bd3",
+    "metadata": {
+     "lines_to_next_cell": 1
+    },
+    "outputs": [],
+    "source": [
+     "SYSTEM_PROMPT = \"\"\"You are an expert GPU kernel engineer specializing in CUDA and Triton.\n",
+     "\n",
+     "Your task is to optimize PyTorch operations by writing custom GPU kernels.\n",
+     "\n",
+     "Guidelines:\n",
+     "1. Analyze the reference PyTorch implementation carefully\n",
+     "2. Identify optimization opportunities (memory access patterns, parallelism, fusion)\n",
+     "3. Write a Triton or CUDA kernel that computes the same result\n",
+     "4. Ensure numerical correctness (outputs must match within tolerance)\n",
+     "\n",
+     "Output format:\n",
+     "- Provide a complete Python file\n",
+     "- Include a Model class with the same interface as the reference\n",
+     "- The Model.forward() method should use your optimized kernel\n",
+     "- Include all necessary imports (torch, triton, triton.language)\n",
+     "\n",
+     "Focus on:\n",
+     "- Coalesced memory access\n",
+     "- Efficient use of shared memory\n",
+     "- Minimizing thread divergence\n",
+     "- Optimal block/grid dimensions\"\"\""
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "607299ce",
+    "metadata": {},
+    "source": [
+     "## Rollout Function\n",
+     "\n",
+     "The rollout function generates kernel code and evaluates it in the environment."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "5da951b3",
+    "metadata": {
+     "lines_to_next_cell": 1
+    },
+    "outputs": [],
+    "source": [
+     "def make_prompt(problem_description: str, feedback: str = \"\") -> str:\n",
+     "    \"\"\"Create the user prompt for the model.\"\"\"\n",
+     "    prompt = f\"{problem_description}\\n\"\n",
+     "    if feedback:\n",
+     "        prompt += f\"\\n## Previous Attempt Feedback\\n{feedback}\\n\"\n",
+     "    prompt += \"\\nProvide your optimized kernel implementation:\"\n",
+     "    return prompt\n",
+     "\n",
+     "def extract_code(completion: str) -> str:\n",
+     "    \"\"\"Extract code from model completion.\"\"\"\n",
+     "    # Handle markdown code blocks\n",
+     "    if \"```python\" in completion:\n",
+     "        start = completion.find(\"```python\") + 9\n",
+     "        end = completion.find(\"```\", start)\n",
+     "        if end > start:\n",
+     "            return completion[start:end].strip()\n",
+     "    if \"```\" in completion:\n",
+     "        start = completion.find(\"```\") + 3\n",
+     "        end = completion.find(\"```\", start)\n",
+     "        if end > start:\n",
+     "            return completion[start:end].strip()\n",
+     "    # Return as-is if no code blocks\n",
+     "    return completion.strip()\n",
+     "\n",
+     "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n",
+     "    \"\"\"\n",
+     "    Custom rollout function for kernrl environment.\n",
+     "\n",
+     "    Generates kernel code and evaluates it to get rewards.\n",
+     "    \"\"\"\n",
+     "    # Generate completions\n",
+     "    outputs = generate_rollout_completions(trainer, prompts)\n",
+     "\n",
+     "    completions_text = [\n",
+     "        tokenizer.decode(out[\"completion_ids\"], skip_special_tokens=True)\n",
+     "        for out in outputs\n",
+     "    ]\n",
+     "\n",
+     "    # Evaluate each completion in the environment\n",
+     "    compilation_success = []\n",
+     "    correctness_pass = []\n",
+     "    speedups = []\n",
+     "\n",
+     "    for completion in completions_text:\n",
+     "        # Reset environment for each evaluation\n",
+     "        obs = env.reset()\n",
+     "\n",
+     "        # Extract code and submit\n",
+     "        code = extract_code(completion)\n",
+     "        action = KernelAction(code=code)\n",
+     "\n",
+     "        try:\n",
+     "            result = env.step(action)\n",
+     "            obs = result.observation\n",
+     "\n",
+     "            compilation_success.append(obs.compilation_success)\n",
+     "            correctness_pass.append(obs.correctness_pass or False)\n",
+     "            speedups.append(obs.speedup)\n",
+     "        except Exception as e:\n",
+     "            print(f\"Evaluation error: {e}\")\n",
+     "            compilation_success.append(False)\n",
+     "            correctness_pass.append(False)\n",
+     "            speedups.append(None)\n",
+     "\n",
+     "    return {\n",
+     "        \"prompt_ids\": [out[\"prompt_ids\"] for out in outputs],\n",
+     "        \"completion_ids\": [out[\"completion_ids\"] for out in outputs],\n",
+     "        \"logprobs\": [out[\"logprobs\"] for out in outputs],\n",
+     "        # Pass reward signals to reward functions\n",
+     "        \"compilation_success\": compilation_success,\n",
+     "        \"correctness_pass\": correctness_pass,\n",
+     "        \"speedup\": speedups,\n",
+     "    }"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "dae933f9",
+    "metadata": {},
+    "source": [
+     "## Create Training Dataset\n",
+     "\n",
+     "We create a dataset from kernrl problems. Each problem becomes a training prompt."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "36c6f196",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def create_dataset(env: kernrl_env, levels: list[int] = [1, 2]) -> Dataset:\n",
+     "    \"\"\"Create training dataset from kernrl problems.\"\"\"\n",
+     "    prompts = []\n",
+     "    problem_ids = []\n",
+     "\n",
+     "    # Get all problem IDs\n",
+     "    all_problems = env.list_problems()\n",
+     "\n",
+     "    for problem_id in all_problems:\n",
+     "        # Filter by level\n",
+     "        level = int(problem_id.split(\"_\")[0][1:])  # Extract level from \"L1_...\"\n",
+     "        if level not in levels:\n",
+     "            continue\n",
+     "\n",
+     "        # Reset to get problem description\n",
+     "        obs = env.reset(problem_id=problem_id)\n",
+     "\n",
+     "        # Create prompt\n",
+     "        messages = [\n",
+     "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+     "            {\"role\": \"user\", \"content\": make_prompt(obs.problem_description)},\n",
+     "        ]\n",
+     "        prompt = tokenizer.apply_chat_template(\n",
+     "            messages,\n",
+     "            add_generation_prompt=True,\n",
+     "            tokenize=False,\n",
+     "        )\n",
+     "\n",
+     "        prompts.append(prompt)\n",
+     "        problem_ids.append(problem_id)\n",
+     "\n",
+     "    return Dataset.from_dict({\n",
+     "        \"prompt\": prompts,\n",
+     "        \"problem_id\": problem_ids,\n",
+     "    })\n",
+     "\n",
+     "# Create dataset from Level 1 and 2 problems\n",
+     "dataset = create_dataset(env, levels=[1, 2])\n",
+     "print(f\"Created dataset with {len(dataset)} problems\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "61dcd8db",
+    "metadata": {},
+    "source": [
+     "## Configure Training\n",
+     "\n",
+     "Set up GRPOTrainer with our custom rollout function and reward signals."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "6fd1d73d",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Training configuration\n",
+     "config = GRPOConfig(\n",
+     "    output_dir=\"./kernrl_grpo_output\",\n",
+     "\n",
+     "    # vLLM settings\n",
+     "    use_vllm=True,\n",
+     "    vllm_mode=\"colocate\",  # Use \"server\" mode for multi-GPU\n",
+     "\n",
+     "    # Generation settings\n",
+     "    num_generations=4,  # Generations per prompt\n",
+     "    max_completion_length=2048,  # Kernel code can be long\n",
+     "    temperature=0.7,\n",
+     "\n",
+     "    # Training settings\n",
+     "    num_train_epochs=3,\n",
+     "    per_device_train_batch_size=2,\n",
+     "    gradient_accumulation_steps=4,\n",
+     "    learning_rate=1e-5,\n",
+     "\n",
+     "    # Logging\n",
+     "    logging_steps=10,\n",
+     "    save_steps=100,\n",
+     "    report_to=\"wandb\",  # Optional: log to Weights & Biases\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "36db4292",
+    "metadata": {},
+    "source": [
+     "## Initialize Trainer"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3058bd91",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "trainer = GRPOTrainer(\n",
+     "    model=MODEL_ID,\n",
+     "    processing_class=tokenizer,\n",
+     "    reward_funcs=[\n",
+     "        reward_compilation,\n",
+     "        reward_correctness,\n",
+     "        reward_speedup,\n",
+     "    ],\n",
+     "    train_dataset=dataset,\n",
+     "    rollout_func=rollout_func,\n",
+     "    args=config,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "26d3cb0f",
+    "metadata": {},
+    "source": [
+     "## Train!\n",
+     "\n",
+     "Start the training loop. The model will learn to write faster kernels through\n",
+     "environment feedback."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "11157d97",
+    "metadata": {
+     "lines_to_next_cell": 1
+    },
+    "outputs": [],
+    "source": [
+     "# Start training\n",
+     "trainer.train()\n",
+     "\n",
+     "# Save the final model\n",
+     "trainer.save_model(\"./kernrl_trained_model\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "4ee87425",
+    "metadata": {},
+    "source": [
+     "## Evaluate the Trained Model\n",
+     "\n",
+     "Test the trained model on some problems to see how well it learned."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "82ed4e39",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:\n",
+     "    \"\"\"Evaluate a trained model on kernel optimization problems.\"\"\"\n",
+     "    from transformers import AutoModelForCausalLM\n",
+     "\n",
+     "    model = AutoModelForCausalLM.from_pretrained(model_path)\n",
+     "    model.eval()\n",
+     "\n",
+     "    results = []\n",
+     "\n",
+     "    for problem_id in problem_ids:\n",
+     "        obs = env.reset(problem_id=problem_id)\n",
+     "\n",
+     "        # Generate kernel code\n",
+     "        messages = [\n",
+     "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+     "            {\"role\": \"user\", \"content\": make_prompt(obs.problem_description)},\n",
+     "        ]\n",
+     "        prompt = tokenizer.apply_chat_template(\n",
+     "            messages,\n",
+     "            add_generation_prompt=True,\n",
+     "            tokenize=False,\n",
+     "        )\n",
+     "\n",
+     "        inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
+     "        with torch.no_grad():\n",
+     "            outputs = model.generate(\n",
+     "                **inputs,\n",
+     "                max_new_tokens=2048,\n",
+     "                temperature=0.3,  # Lower temp for evaluation\n",
+     "                do_sample=True,\n",
+     "            )\n",
+     "\n",
+     "        completion = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
+     "        code = extract_code(completion)\n",
+     "\n",
+     "        # Evaluate\n",
+     "        result = env.step(KernelAction(code=code))\n",
+     "        obs = result.observation\n",
+     "\n",
+     "        results.append({\n",
+     "            \"problem_id\": problem_id,\n",
+     "            \"compilation\": obs.compilation_success,\n",
+     "            \"correctness\": obs.correctness_pass,\n",
+     "            \"speedup\": obs.speedup,\n",
+     "        })\n",
+     "\n",
+     "        if obs.speedup:\n",
+     "            print(f\"{problem_id}: compile={obs.compilation_success}, \"\n",
+     "                  f\"correct={obs.correctness_pass}, speedup={obs.speedup:.2f}x\")\n",
+     "        else:\n",
+     "            print(f\"{problem_id}: compile={obs.compilation_success}\")\n",
+     "\n",
+     "    return results\n",
+     "\n",
+     "# Evaluate on a few problems\n",
+     "# eval_results = evaluate_model(\"./kernrl_trained_model\", [\"L1_23_Softmax\", \"L1_26_GELU_\"])"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "45d94da1",
+    "metadata": {},
+    "source": [
+     "## Running with Server Mode (Multi-GPU)\n",
+     "\n",
+     "For larger models or faster training, use vLLM in server mode. Note that the vLLM\n",
+     "server runs on port 8001 here so it does not collide with the kernrl server on 8000:\n",
+     "\n",
+     "```bash\n",
+     "# Terminal 1: Start vLLM server\n",
+     "CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-Coder-7B-Instruct --port 8001\n",
+     "\n",
+     "# Terminal 2: Start kernrl environment\n",
+     "CUDA_VISIBLE_DEVICES=1 uvicorn kernrl.server.app:app --host 0.0.0.0 --port 8000\n",
+     "\n",
+     "# Terminal 3: Run training\n",
+     "CUDA_VISIBLE_DEVICES=2 python train_kernrl.py --vllm-mode server --vllm-server-url http://localhost:8001\n",
+     "```\n",
+     "\n",
+     "Update the config:\n",
+     "```python\n",
+     "config = GRPOConfig(\n",
+     "    use_vllm=True,\n",
+     "    vllm_mode=\"server\",\n",
+     "    vllm_server_base_url=\"http://localhost:8001\",\n",
+     "    ...\n",
+     ")\n",
+     "```"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "464e71b0",
+    "metadata": {},
+    "source": [
+     "## Tips for Better Results\n",
+     "\n",
+     "1. **Start with simpler problems**: Level 1 problems (matmul, softmax) are easier\n",
+     "2. **Use code-focused models**: Qwen2.5-Coder and DeepSeek-Coder work well\n",
+     "3. **Increase generations**: More generations per prompt yield better advantage estimates\n",
+     "4. **Multi-turn training**: Let the model iterate based on feedback\n",
+     "5. **Curriculum learning**: Start with L1, add harder problems gradually"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "2a03608e",
+    "metadata": {},
+    "source": [
+     "## Resources\n",
+     "\n",
+     "- [kernrl HuggingFace Space](https://huggingface.co/spaces/Infatoshi/kernrl)\n",
+     "- [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)\n",
+     "- [TRL Documentation](https://huggingface.co/docs/trl)\n",
+     "- [Triton Tutorial](https://triton-lang.org/main/getting-started/tutorials/)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
train_kernrl.py ADDED
@@ -0,0 +1,452 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # text_representation:
5
+ # extension: .py
6
+ # format_name: percent
7
+ # kernelspec:
8
+ # display_name: Python 3
9
+ # language: python
10
+ # name: python3
11
+ # ---
12
+
13
+ # %% [markdown]
14
+ # # Training LLMs to Write Fast GPU Kernels with GRPO
15
+ #
16
+ # This notebook demonstrates how to train a language model to write optimized CUDA/Triton
17
+ # kernels using TRL's GRPOTrainer and the kernrl OpenEnv environment.
18
+ #
19
+ # **What is kernrl?**
20
+ # - An RL environment for GPU kernel optimization
21
+ # - Agents receive PyTorch reference implementations
22
+ # - Must write faster CUDA/Triton kernels that produce correct outputs
23
+ # - Rewards based on compilation success, correctness, and speedup
24
+ #
25
+ # **What is GRPO?**
26
+ # - Group Relative Policy Optimization
27
+ # - Efficient RL algorithm for training LLMs
28
+ # - Uses multiple generations per prompt to estimate advantages
29
+ # - Works well with environment-based reward signals
30
+
31
+ # %% [markdown]
32
+ # ## Installation
33
+ #
34
+ # First, install the required packages:
35
+
36
+ # %%
37
+ # !pip install torch triton trl transformers accelerate
38
+ # !pip install git+https://github.com/meta-pytorch/OpenEnv.git
39
+
40
+ # %% [markdown]
41
+ # ## Setup
42
+ #
43
+ # Import necessary libraries and configure the environment.
44
+
45
+ # %%
46
+ import torch
47
+ from datasets import Dataset
48
+ from transformers import AutoTokenizer
49
+ from trl import GRPOConfig, GRPOTrainer
50
+ from trl.experimental.openenv import generate_rollout_completions
51
+
52
+ # Import kernrl environment
53
+ from kernrl import kernrl_env, KernelAction, KernelObservation
54
+
55
+ # %%
56
+ # Configuration
57
+ MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct" # Good for code generation
58
+ ENV_URL = "http://localhost:8000" # kernrl server URL
59
+
60
+ # Initialize tokenizer
61
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
62
+ if tokenizer.pad_token is None:
63
+ tokenizer.pad_token = tokenizer.eos_token
64
+
65
+ # %% [markdown]
66
+ # ## Connect to kernrl Environment
67
+ #
68
+ # The kernrl environment evaluates submitted kernels for:
69
+ # 1. **Compilation**: Does the code compile?
70
+ # 2. **Correctness**: Does output match reference (within tolerance)?
71
+ # 3. **Performance**: Is it faster than PyTorch baseline?
72
+
73
+ # %%
74
+ # Connect to the kernrl server
75
+ # Option 1: Connect to running server
76
+ env = kernrl_env(base_url=ENV_URL)
77
+
78
+ # Option 2: Load from HuggingFace Hub (requires GPU)
79
+ # env = kernrl_env.from_hub("Infatoshi/kernrl")
80
+
81
+ # Option 3: Local Docker
82
+ # env = kernrl_env.from_docker_image("kernrl:latest")
83
+
84
+ # Test the connection
85
+ obs = env.reset(problem_id="L1_23_Softmax")
86
+ print(f"Problem: {obs.problem_id}")
87
+ print(f"GPU: {obs.gpu_info}")
88
+ print(f"Max turns: {obs.max_turns}")
89
+
90
+ # %% [markdown]
91
+ # ## Reward Functions
92
+ #
93
+ # We define multiple reward signals to guide the model:
94
+ # - **Compilation reward**: +0.1 for successful compilation
95
+ # - **Correctness reward**: +0.3 for matching reference output
96
+ # - **Speedup reward**: Scaled reward for beating baseline performance
97
+
98
+ # %%
99
+ import math
100
+
101
+ def reward_compilation(completions: list[str], **kwargs) -> list[float]:
102
+ """Reward for successful compilation."""
103
+ compilation_success = kwargs.get("compilation_success", [])
104
+ return [0.1 if success else 0.0 for success in compilation_success]
105
+
106
+ def reward_correctness(completions: list[str], **kwargs) -> list[float]:
107
+ """Reward for correct output."""
108
+ correctness_pass = kwargs.get("correctness_pass", [])
109
+ return [0.3 if correct else 0.0 for correct in correctness_pass]
110
+
111
+ def reward_speedup(completions: list[str], **kwargs) -> list[float]:
112
+ """Reward scaled by speedup achieved."""
113
+ speedups = kwargs.get("speedup", [])
114
+ rewards = []
115
+ for speedup in speedups:
116
+ if speedup is None or speedup <= 0:
117
+ rewards.append(0.0)
118
+ elif speedup <= 1.0:
119
+ # Below baseline: small penalty
120
+ rewards.append(-0.1)
121
+ else:
122
+ # Above baseline: reward scales with log2(speedup)
123
+ # 2x speedup = 0.3, 4x = 0.6, 8x = 0.9
124
+ bonus = min(0.3 * math.log2(speedup), 0.6)
125
+ rewards.append(0.3 + bonus)
126
+ return rewards
127
+
128
+ def reward_combined(completions: list[str], **kwargs) -> list[float]:
129
+ """Combined reward from all signals."""
130
+ comp_rewards = reward_compilation(completions, **kwargs)
131
+ corr_rewards = reward_correctness(completions, **kwargs)
132
+ speed_rewards = reward_speedup(completions, **kwargs)
133
+ return [c + r + s for c, r, s in zip(comp_rewards, corr_rewards, speed_rewards)]
134
+
135
+ # %% [markdown]
136
+ # ## System Prompt
137
+ #
138
+ # The system prompt provides context about the task and expected output format.
139
+
140
+ # %%
141
+ SYSTEM_PROMPT = """You are an expert GPU kernel engineer specializing in CUDA and Triton.
142
+
143
+ Your task is to optimize PyTorch operations by writing custom GPU kernels.
144
+
145
+ Guidelines:
146
+ 1. Analyze the reference PyTorch implementation carefully
147
+ 2. Identify optimization opportunities (memory access patterns, parallelism, fusion)
148
+ 3. Write a Triton or CUDA kernel that computes the same result
149
+ 4. Ensure numerical correctness (outputs must match within tolerance)
150
+
151
+ Output format:
152
+ - Provide a complete Python file
153
+ - Include a Model class with the same interface as the reference
154
+ - The Model.forward() method should use your optimized kernel
155
+ - Include all necessary imports (torch, triton, triton.language)
156
+
157
+ Focus on:
158
+ - Coalesced memory access
159
+ - Efficient use of shared memory
160
+ - Minimizing thread divergence
161
+ - Optimal block/grid dimensions"""
162
+
163
+ # %% [markdown]
164
+ # ## Rollout Function
165
+ #
166
+ # The rollout function generates kernel code and evaluates it in the environment.
167
+
168
+ # %%
169
+ def make_prompt(problem_description: str, feedback: str = "") -> str:
170
+ """Create the user prompt for the model."""
171
+ prompt = f"{problem_description}\n"
172
+ if feedback:
173
+ prompt += f"\n## Previous Attempt Feedback\n{feedback}\n"
174
+ prompt += "\nProvide your optimized kernel implementation:"
175
+ return prompt
176
+
177
+ def extract_code(completion: str) -> str:
178
+ """Extract code from model completion."""
179
+ # Handle markdown code blocks
180
+ if "```python" in completion:
181
+ start = completion.find("```python") + 9
182
+ end = completion.find("```", start)
183
+ if end > start:
184
+ return completion[start:end].strip()
185
+ if "```" in completion:
186
+ start = completion.find("```") + 3
187
+ end = completion.find("```", start)
188
+ if end > start:
189
+ return completion[start:end].strip()
190
+ # Return as-is if no code blocks
191
+ return completion.strip()
192
+
+ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
+     """
+     Custom rollout function for the kernrl environment.
+
+     Generates kernel code and evaluates it to get rewards.
+     """
+     # Generate completions
+     outputs = generate_rollout_completions(trainer, prompts)
+
+     completions_text = [
+         tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
+         for out in outputs
+     ]
+
+     # Evaluate each completion in the environment
+     compilation_success = []
+     correctness_pass = []
+     speedups = []
+
+     for completion in completions_text:
+         # Reset environment for each evaluation
+         obs = env.reset()
+
+         # Extract code and submit
+         code = extract_code(completion)
+         action = KernelAction(code=code)
+
+         try:
+             result = env.step(action)
+             obs = result.observation
+
+             compilation_success.append(obs.compilation_success)
+             correctness_pass.append(obs.correctness_pass or False)
+             speedups.append(obs.speedup)
+         except Exception as e:
+             print(f"Evaluation error: {e}")
+             compilation_success.append(False)
+             correctness_pass.append(False)
+             speedups.append(None)
+
+     return {
+         "prompt_ids": [out["prompt_ids"] for out in outputs],
+         "completion_ids": [out["completion_ids"] for out in outputs],
+         "logprobs": [out["logprobs"] for out in outputs],
+         # Pass reward signals to reward functions
+         "compilation_success": compilation_success,
+         "correctness_pass": correctness_pass,
+         "speedup": speedups,
+     }
242
+
+ # %% [markdown]
+ # ## Create Training Dataset
+ #
+ # We create a dataset from kernrl problems. Each problem becomes a training prompt.
+
+ # %%
+ def create_dataset(env: kernrl_env, levels: list[int] = [1, 2]) -> Dataset:
+     """Create training dataset from kernrl problems."""
+     prompts = []
+     problem_ids = []
+
+     # Get all problem IDs
+     all_problems = env.list_problems()
+
+     for problem_id in all_problems:
+         # Filter by level, extracted from IDs like "L1_..."
+         level = int(problem_id.split("_")[0][1:])
+         if level not in levels:
+             continue
+
+         # Reset to get problem description
+         obs = env.reset(problem_id=problem_id)
+
+         # Create prompt
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": make_prompt(obs.problem_description)},
+         ]
+         prompt = tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+
+         prompts.append(prompt)
+         problem_ids.append(problem_id)
+
+     return Dataset.from_dict({
+         "prompt": prompts,
+         "problem_id": problem_ids,
+     })
+
+ # Create dataset from Level 1 and 2 problems
+ dataset = create_dataset(env, levels=[1, 2])
+ print(f"Created dataset with {len(dataset)} problems")
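The level filter in `create_dataset` assumes problem IDs of the form `L<level>_<index>_<name>` (e.g. `L1_23_Softmax`). A minimal standalone sketch of that parsing, using hypothetical IDs:

```python
def parse_level(problem_id: str) -> int:
    """Extract the numeric level from an ID like 'L1_23_Softmax'."""
    # "L1_23_Softmax" -> "L1" -> "1" -> 1
    return int(problem_id.split("_")[0][1:])

print(parse_level("L1_23_Softmax"))  # -> 1
print(parse_level("L2_7_Fused"))     # -> 2
```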
288
+
+ # %% [markdown]
+ # ## Configure Training
+ #
+ # Set up GRPOTrainer with our custom rollout function and reward signals.
+
+ # %%
+ # Training configuration
+ config = GRPOConfig(
+     output_dir="./kernrl_grpo_output",
+
+     # vLLM settings
+     use_vllm=True,
+     vllm_mode="colocate",  # Use "server" mode for multi-GPU
+
+     # Generation settings
+     num_generations=4,  # Generations per prompt
+     max_completion_length=2048,  # Kernel code can be long
+     temperature=0.7,
+
+     # Training settings
+     num_train_epochs=3,
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=4,
+     learning_rate=1e-5,
+
+     # Logging
+     logging_steps=10,
+     save_steps=100,
+     report_to="wandb",  # Optional: log to Weights & Biases
+ )
319
+
+ # %% [markdown]
+ # ## Initialize Trainer
+
+ # %%
+ trainer = GRPOTrainer(
+     model=MODEL_ID,
+     processing_class=tokenizer,
+     reward_funcs=[
+         reward_compilation,
+         reward_correctness,
+         reward_speedup,
+     ],
+     train_dataset=dataset,
+     rollout_func=rollout_func,
+     args=config,
+ )
+
+ # %% [markdown]
+ # ## Train!
+ #
+ # Start the training loop. The model will learn to write faster kernels through
+ # environment feedback.
+
+ # %%
+ # Start training
+ trainer.train()
+
+ # Save the final model
+ trainer.save_model("./kernrl_trained_model")
+
350
+
+ # %% [markdown]
+ # ## Evaluate the Trained Model
+ #
+ # Test the trained model on some problems to see how well it learned.
+
+ # %%
+ def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:
+     """Evaluate a trained model on kernel optimization problems."""
+     from transformers import AutoModelForCausalLM
+
+     model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+     model.eval()
+
+     results = []
+
+     for problem_id in problem_ids:
+         obs = env.reset(problem_id=problem_id)
+
+         # Generate kernel code
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": make_prompt(obs.problem_description)},
+         ]
+         prompt = tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=2048,
+                 temperature=0.3,  # Lower temperature for evaluation
+                 do_sample=True,
+             )
+
+         completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         code = extract_code(completion)
+
+         # Evaluate
+         result = env.step(KernelAction(code=code))
+         obs = result.observation
+
+         results.append({
+             "problem_id": problem_id,
+             "compilation": obs.compilation_success,
+             "correctness": obs.correctness_pass,
+             "speedup": obs.speedup,
+         })
+
+         speedup_str = f", speedup={obs.speedup:.2f}x" if obs.speedup else ""
+         print(f"{problem_id}: compile={obs.compilation_success}, "
+               f"correct={obs.correctness_pass}{speedup_str}")
+
+     return results
+
+ # Evaluate on a few problems
+ # eval_results = evaluate_model("./kernrl_trained_model", ["L1_23_Softmax", "L1_26_GELU_"])
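To aggregate the per-problem dicts that `evaluate_model` returns, a small helper like the following could be used (a sketch, not part of the notebook's API; the example data is made up):

```python
# Hypothetical helper: aggregate per-problem eval dicts (with keys
# "compilation", "correctness", "speedup") into summary statistics.
def summarize_results(results: list[dict]) -> dict:
    """Compute compile/correctness rates and mean speedup over passing runs."""
    n = len(results)
    compiled = sum(1 for r in results if r["compilation"])
    correct = sum(1 for r in results if r["correctness"])
    speedups = [r["speedup"] for r in results if r["speedup"]]
    return {
        "compile_rate": compiled / n if n else 0.0,
        "correct_rate": correct / n if n else 0.0,
        "mean_speedup": sum(speedups) / len(speedups) if speedups else None,
    }

# Made-up results for illustration only.
example = [
    {"problem_id": "L1_23_Softmax", "compilation": True, "correctness": True, "speedup": 1.8},
    {"problem_id": "L1_26_GELU", "compilation": True, "correctness": False, "speedup": None},
]
print(summarize_results(example))
```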
410
+
+ # %% [markdown]
+ # ## Running with Server Mode (Multi-GPU)
+ #
+ # For larger models or faster training, use vLLM in server mode. Note that the
+ # kernrl environment already occupies port 8000, so run the vLLM server on a
+ # different port (8001 here):
+ #
+ # ```bash
+ # # Terminal 1: Start vLLM server
+ # CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-Coder-7B-Instruct --port 8001
+ #
+ # # Terminal 2: Start kernrl environment
+ # CUDA_VISIBLE_DEVICES=1 uvicorn kernrl.server.app:app --host 0.0.0.0 --port 8000
+ #
+ # # Terminal 3: Run training
+ # CUDA_VISIBLE_DEVICES=2 python train_kernrl.py --vllm-mode server --vllm-server-url http://localhost:8001
+ # ```
+ #
+ # Update the config:
+ # ```python
+ # config = GRPOConfig(
+ #     use_vllm=True,
+ #     vllm_mode="server",
+ #     vllm_server_base_url="http://localhost:8001",
+ #     ...
+ # )
+ # ```
436
+
+ # %% [markdown]
+ # ## Tips for Better Results
+ #
+ # 1. **Start with simpler problems**: Level 1 problems (matmul, softmax) are easier
+ # 2. **Use code-focused models**: Qwen2.5-Coder, DeepSeek-Coder work well
+ # 3. **Increase generations**: More generations per prompt = better advantage estimates
+ # 4. **Multi-turn training**: Let the model iterate based on feedback
+ # 5. **Curriculum learning**: Start with L1, add harder problems gradually
+
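Tip 5 could be implemented with a helper like this (hypothetical, assuming the `L<level>_...` ID convention): sort problems easiest-first, then feed them to `create_dataset` in stages.

```python
def curriculum_order(problem_ids: list[str]) -> list[str]:
    """Sort problem IDs by level so training sees easier kernels first."""
    # "L3_1_Attention" -> level 3, "L1_23_Softmax" -> level 1, etc.
    return sorted(problem_ids, key=lambda pid: int(pid.split("_")[0][1:]))

print(curriculum_order(["L3_1_Attention", "L1_23_Softmax", "L2_4_Conv"]))
# -> ['L1_23_Softmax', 'L2_4_Conv', 'L3_1_Attention']
```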
446
+ # %% [markdown]
+ # ## Resources
+ #
+ # - [kernrl HuggingFace Space](https://huggingface.co/spaces/Infatoshi/kernrl)
+ # - [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)
+ # - [TRL Documentation](https://huggingface.co/docs/trl)
+ # - [Triton Tutorial](https://triton-lang.org/main/getting-started/tutorials/)