# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---
# %% [markdown]
# # Training LLMs to Write Fast GPU Kernels with GRPO
#
# This notebook demonstrates how to train a language model to write optimized CUDA/Triton
# kernels using TRL's GRPOTrainer and the kernrl OpenEnv environment.
#
# **What is kernrl?**
# - An RL environment for GPU kernel optimization
# - Agents receive PyTorch reference implementations
# - Must write faster CUDA/Triton kernels that produce correct outputs
# - Rewards based on compilation success, correctness, and speedup
#
# **What is GRPO?**
# - Group Relative Policy Optimization
# - Efficient RL algorithm for training LLMs
# - Uses multiple generations per prompt to estimate advantages
# - Works well with environment-based reward signals
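#
# As a sketch of the core idea (GRPOTrainer computes this internally; the
# numbers here are made up), each completion's advantage is its reward
# normalized against the other generations for the same prompt:
#
# ```python
# rewards = [0.1, 0.4, 0.9, 0.4]        # 4 generations for one prompt
# mean = sum(rewards) / len(rewards)    # 0.45
# std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
# advantages = [(r - mean) / (std + 1e-8) for r in rewards]
# ```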
# %% [markdown]
# ## Installation
#
# First, install the required packages:
# %%
# !pip install torch triton trl transformers accelerate
# !pip install git+https://github.com/meta-pytorch/OpenEnv.git
# %% [markdown]
# ## Setup
#
# Import necessary libraries and configure the environment.
# %%
import torch
from datasets import Dataset
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
# Import kernrl environment
from kernrl import kernrl_env, KernelAction, KernelObservation
# %%
# Configuration
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct" # Good for code generation
ENV_URL = "http://localhost:8000" # kernrl server URL
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# %% [markdown]
# ## Connect to kernrl Environment
#
# The kernrl environment evaluates submitted kernels for:
# 1. **Compilation**: Does the code compile?
# 2. **Correctness**: Does output match reference (within tolerance)?
# 3. **Performance**: Is it faster than PyTorch baseline?
# %%
# Connect to the kernrl server
# Option 1: Connect to running server
env = kernrl_env(base_url=ENV_URL)
# Option 2: Load from HuggingFace Hub (requires GPU)
# env = kernrl_env.from_hub("Infatoshi/kernrl")
# Option 3: Local Docker
# env = kernrl_env.from_docker_image("kernrl:latest")
# Test the connection
obs = env.reset(problem_id="L1_23_Softmax")
print(f"Problem: {obs.problem_id}")
print(f"GPU: {obs.gpu_info}")
print(f"Max turns: {obs.max_turns}")
# %% [markdown]
# ## Reward Functions
#
# We define multiple reward signals to guide the model:
# - **Compilation reward**: +0.1 for successful compilation
# - **Correctness reward**: +0.3 for matching reference output
# - **Speedup reward**: Scaled reward for beating baseline performance
# %%
import math
def reward_compilation(completions: list[str], **kwargs) -> list[float]:
"""Reward for successful compilation."""
compilation_success = kwargs.get("compilation_success", [])
return [0.1 if success else 0.0 for success in compilation_success]
def reward_correctness(completions: list[str], **kwargs) -> list[float]:
"""Reward for correct output."""
correctness_pass = kwargs.get("correctness_pass", [])
return [0.3 if correct else 0.0 for correct in correctness_pass]
def reward_speedup(completions: list[str], **kwargs) -> list[float]:
"""Reward scaled by speedup achieved."""
speedups = kwargs.get("speedup", [])
rewards = []
for speedup in speedups:
if speedup is None or speedup <= 0:
rewards.append(0.0)
elif speedup <= 1.0:
# Below baseline: small penalty
rewards.append(-0.1)
else:
# Above baseline: reward scales with log2(speedup)
            # 2x speedup -> 0.6 total, 4x -> 0.9, capped at 0.9 beyond 4x
bonus = min(0.3 * math.log2(speedup), 0.6)
rewards.append(0.3 + bonus)
return rewards
def reward_combined(completions: list[str], **kwargs) -> list[float]:
"""Combined reward from all signals."""
comp_rewards = reward_compilation(completions, **kwargs)
corr_rewards = reward_correctness(completions, **kwargs)
speed_rewards = reward_speedup(completions, **kwargs)
return [c + r + s for c, r, s in zip(comp_rewards, corr_rewards, speed_rewards)]
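# %% [markdown]
# Quick sanity check of the reward shaping with toy values (no environment
# needed): one attempt that only compiles, one correct 2x speedup, one
# correct 4x speedup.
# %%
# Expected: [0.1, 0.1 + 0.3 + 0.6, 0.1 + 0.3 + 0.9] -> roughly [0.1, 1.0, 1.3]
toy_signals = {
    "compilation_success": [True, True, True],
    "correctness_pass": [False, True, True],
    "speedup": [None, 2.0, 4.0],
}
print(reward_combined(["a", "b", "c"], **toy_signals))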
# %% [markdown]
# ## System Prompt
#
# The system prompt provides context about the task and expected output format.
# %%
SYSTEM_PROMPT = """You are an expert GPU kernel engineer specializing in CUDA and Triton.
Your task is to optimize PyTorch operations by writing custom GPU kernels.
Guidelines:
1. Analyze the reference PyTorch implementation carefully
2. Identify optimization opportunities (memory access patterns, parallelism, fusion)
3. Write a Triton or CUDA kernel that computes the same result
4. Ensure numerical correctness (outputs must match within tolerance)
Output format:
- Provide a complete Python file
- Include a Model class with the same interface as the reference
- The Model.forward() method should use your optimized kernel
- Include all necessary imports (torch, triton, triton.language)
Focus on:
- Coalesced memory access
- Efficient use of shared memory
- Minimizing thread divergence
- Optimal block/grid dimensions"""
# %% [markdown]
# ## Rollout Function
#
# The rollout function generates kernel code and evaluates it in the environment.
# %%
def make_prompt(problem_description: str, feedback: str = "") -> str:
"""Create the user prompt for the model."""
prompt = f"{problem_description}\n"
if feedback:
prompt += f"\n## Previous Attempt Feedback\n{feedback}\n"
prompt += "\nProvide your optimized kernel implementation:"
return prompt
def extract_code(completion: str) -> str:
"""Extract code from model completion."""
# Handle markdown code blocks
if "```python" in completion:
start = completion.find("```python") + 9
end = completion.find("```", start)
if end > start:
return completion[start:end].strip()
if "```" in completion:
start = completion.find("```") + 3
end = completion.find("```", start)
if end > start:
return completion[start:end].strip()
# Return as-is if no code blocks
return completion.strip()
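# %% [markdown]
# For example, `extract_code` pulls the kernel body out of a fenced completion:
# %%
demo_completion = "Here is my kernel:\n```python\nimport torch\nprint('ok')\n```\nDone."
print(extract_code(demo_completion))  # -> import torch\nprint('ok')
# %% [markdown]
# The rollout function below generates completions, evaluates each one in the
# environment, and passes the resulting signals through to the reward functions.
# %%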
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
"""
Custom rollout function for kernrl environment.
Generates kernel code and evaluates it to get rewards.
"""
# Generate completions
outputs = generate_rollout_completions(trainer, prompts)
completions_text = [
tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
for out in outputs
]
# Evaluate each completion in the environment
compilation_success = []
correctness_pass = []
speedups = []
for completion in completions_text:
# Reset environment for each evaluation
obs = env.reset()
# Extract code and submit
code = extract_code(completion)
action = KernelAction(code=code)
try:
result = env.step(action)
obs = result.observation
compilation_success.append(obs.compilation_success)
correctness_pass.append(obs.correctness_pass or False)
speedups.append(obs.speedup)
except Exception as e:
print(f"Evaluation error: {e}")
compilation_success.append(False)
correctness_pass.append(False)
speedups.append(None)
return {
"prompt_ids": [out["prompt_ids"] for out in outputs],
"completion_ids": [out["completion_ids"] for out in outputs],
"logprobs": [out["logprobs"] for out in outputs],
# Pass reward signals to reward functions
"compilation_success": compilation_success,
"correctness_pass": correctness_pass,
"speedup": speedups,
}
# %% [markdown]
# ## Create Training Dataset
#
# We create a dataset from kernrl problems. Each problem becomes a training prompt.
# %%
def create_dataset(env: kernrl_env, levels: list[int] = [1, 2]) -> Dataset:
"""Create training dataset from kernrl problems."""
prompts = []
problem_ids = []
# Get all problem IDs
all_problems = env.list_problems()
for problem_id in all_problems:
# Filter by level
level = int(problem_id.split("_")[0][1:]) # Extract level from "L1_..."
if level not in levels:
continue
# Reset to get problem description
obs = env.reset(problem_id=problem_id)
# Create prompt
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": make_prompt(obs.problem_description)},
]
prompt = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
prompts.append(prompt)
problem_ids.append(problem_id)
return Dataset.from_dict({
"prompt": prompts,
"problem_id": problem_ids,
})
# Create dataset from Level 1 and 2 problems
dataset = create_dataset(env, levels=[1, 2])
print(f"Created dataset with {len(dataset)} problems")
# %% [markdown]
# ## Configure Training
#
# Set up GRPOTrainer with our custom rollout function and reward signals.
# %%
# Training configuration
config = GRPOConfig(
output_dir="./kernrl_grpo_output",
# vLLM settings
use_vllm=True,
vllm_mode="colocate", # Use "server" mode for multi-GPU
# Generation settings
num_generations=4, # Generations per prompt
max_completion_length=2048, # Kernel code can be long
temperature=0.7,
# Training settings
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=1e-5,
# Logging
logging_steps=10,
save_steps=100,
report_to="wandb", # Optional: log to Weights & Biases
)
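# %% [markdown]
# Note on batch sizing: GRPO needs complete groups of `num_generations`
# completions per prompt, so TRL requires the effective batch
# (`per_device_train_batch_size` × `gradient_accumulation_steps` × number of
# processes; 2 × 4 = 8 on a single GPU here) to be divisible by
# `num_generations` (4). The exact constraint varies slightly across TRL versions.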
# %% [markdown]
# ## Initialize Trainer
# %%
trainer = GRPOTrainer(
model=MODEL_ID,
processing_class=tokenizer,
reward_funcs=[
reward_compilation,
reward_correctness,
reward_speedup,
],
train_dataset=dataset,
rollout_func=rollout_func,
args=config,
)
# %% [markdown]
# ## Train!
#
# Start the training loop. The model will learn to write faster kernels through
# environment feedback.
# %%
# Start training
trainer.train()
# Save the final model
trainer.save_model("./kernrl_trained_model")
# %% [markdown]
# ## Evaluate the Trained Model
#
# Test the trained model on some problems to see how well it learned.
# %%
def evaluate_model(model_path: str, problem_ids: list[str]) -> dict:
"""Evaluate a trained model on kernel optimization problems."""
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_path)
model.eval()
results = []
for problem_id in problem_ids:
obs = env.reset(problem_id=problem_id)
# Generate kernel code
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": make_prompt(obs.problem_description)},
]
prompt = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=2048,
temperature=0.3, # Lower temp for evaluation
do_sample=True,
)
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
code = extract_code(completion)
# Evaluate
result = env.step(KernelAction(code=code))
obs = result.observation
results.append({
"problem_id": problem_id,
"compilation": obs.compilation_success,
"correctness": obs.correctness_pass,
"speedup": obs.speedup,
})
print(f"{problem_id}: compile={obs.compilation_success}, "
f"correct={obs.correctness_pass}, speedup={obs.speedup:.2f}x"
if obs.speedup else f"{problem_id}: compile={obs.compilation_success}")
return results
# Evaluate on a few problems
# eval_results = evaluate_model("./kernrl_trained_model", ["L1_23_Softmax", "L1_26_GELU_"])
# %% [markdown]
# ## Running with Server Mode (Multi-GPU)
#
# For larger models or faster training, use vLLM in server mode:
#
# ```bash
# # Terminal 1: Start vLLM server (port 8001, so it doesn't clash with kernrl)
# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-Coder-7B-Instruct --port 8001
#
# # Terminal 2: Start kernrl environment
# CUDA_VISIBLE_DEVICES=1 uvicorn kernrl.server.app:app --host 0.0.0.0 --port 8000
#
# # Terminal 3: Run training
# CUDA_VISIBLE_DEVICES=2 python train_kernrl.py --vllm-mode server --vllm-server-url http://localhost:8001
# ```
#
# Update the config:
# ```python
# config = GRPOConfig(
#     use_vllm=True,
#     vllm_mode="server",
#     vllm_server_base_url="http://localhost:8001",
#     ...
# )
# ```
# %% [markdown]
# ## Tips for Better Results
#
# 1. **Start with simpler problems**: Level 1 problems (matmul, softmax) are easier
# 2. **Use code-focused models**: Qwen2.5-Coder, DeepSeek-Coder work well
# 3. **Increase generations**: More generations per prompt = better advantage estimates
# 4. **Multi-turn training**: Let the model iterate based on feedback
# 5. **Curriculum learning**: Start with L1, add harder problems gradually (see the sketch below)
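#
# A minimal curriculum sketch, reusing `create_dataset` from above (the phase
# split is illustrative, not a prescribed schedule):
#
# ```python
# # Phase 1: train on Level-1 problems only
# dataset = create_dataset(env, levels=[1])
# # ... run trainer.train() ...
#
# # Phase 2: add Level-2 problems and resume from the checkpoint
# dataset = create_dataset(env, levels=[1, 2])
# ```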
# %% [markdown]
# ## Resources
#
# - [kernrl HuggingFace Space](https://huggingface.co/spaces/Infatoshi/kernrl)
# - [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)
# - [TRL Documentation](https://huggingface.co/docs/trl)
# - [Triton Tutorial](https://triton-lang.org/main/getting-started/tutorials/)