# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Training LLMs to Write Fast GPU Kernels with GRPO
#
# This notebook demonstrates how to train a language model to write optimized CUDA/Triton
# kernels using TRL's GRPOTrainer and the kernrl OpenEnv environment.
#
# **What is kernrl?**
# - An RL environment for GPU kernel optimization
# - Agents receive PyTorch reference implementations
# - They must write faster CUDA/Triton kernels that produce correct outputs
# - Rewards are based on compilation success, correctness, and speedup
#
# **What is GRPO?**
# - Group Relative Policy Optimization
# - An RL algorithm for training LLMs that needs no separate value model
# - Samples multiple generations per prompt and scores each one relative to its
#   group to estimate advantages (see the toy illustration after the Setup cells below)
# - Works well with environment-based reward signals

# %% [markdown]
# ## Installation
#
# First, install the required packages:

# %%
# !pip install torch triton trl transformers accelerate
# !pip install git+https://github.com/meta-pytorch/OpenEnv.git

# %% [markdown]
# ## Setup
#
# Import the necessary libraries and configure the environment.

# %%
import torch
from datasets import Dataset
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions

# Import the kernrl environment
from kernrl import kernrl_env, KernelAction, KernelObservation

# %%
# Configuration
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"  # Strong small model for code generation
ENV_URL = "http://localhost:8000"  # kernrl server URL

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
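# %% [markdown]
# As a toy illustration of the GRPO idea above (a minimal sketch, not TRL's exact
# internals): each prompt gets several sampled completions, and a completion's
# advantage is its reward centered on the group mean and scaled by the group std.

# %%
# Made-up rewards for 4 generations of a single prompt
group_rewards = torch.tensor([0.1, 0.4, 0.0, 0.9])

# Group-relative advantage: completions above the group mean get a positive signal
advantages = (group_rewards - group_rewards.mean()) / (group_rewards.std() + 1e-8)
print(advantages)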
# %% [markdown]
# ## Connect to the kernrl Environment
#
# The kernrl environment evaluates submitted kernels on three axes:
# 1. **Compilation**: Does the code compile?
# 2. **Correctness**: Does the output match the reference (within tolerance)?
# 3. **Performance**: Is it faster than the PyTorch baseline?

# %%
# Connect to the kernrl server

# Option 1: Connect to a running server
env = kernrl_env(base_url=ENV_URL)

# Option 2: Load from the Hugging Face Hub (requires a GPU)
# env = kernrl_env.from_hub("Infatoshi/kernrl")

# Option 3: Local Docker
# env = kernrl_env.from_docker_image("kernrl:latest")

# Test the connection
obs = env.reset(problem_id="L1_23_Softmax")
print(f"Problem: {obs.problem_id}")
print(f"GPU: {obs.gpu_info}")
print(f"Max turns: {obs.max_turns}")

# %% [markdown]
# ## Reward Functions
#
# We define multiple reward signals to guide the model:
# - **Compilation reward**: +0.1 for successful compilation
# - **Correctness reward**: +0.3 for matching the reference output
# - **Speedup reward**: a log-scaled reward for beating the baseline, with a small
#   penalty for kernels that run but are slower than the baseline

# %%
import math


def reward_compilation(completions: list[str], **kwargs) -> list[float]:
    """Reward for successful compilation."""
    compilation_success = kwargs.get("compilation_success", [])
    return [0.1 if success else 0.0 for success in compilation_success]


def reward_correctness(completions: list[str], **kwargs) -> list[float]:
    """Reward for correct output."""
    correctness_pass = kwargs.get("correctness_pass", [])
    return [0.3 if correct else 0.0 for correct in correctness_pass]


def reward_speedup(completions: list[str], **kwargs) -> list[float]:
    """Reward scaled by the speedup achieved."""
    speedups = kwargs.get("speedup", [])
    rewards = []
    for speedup in speedups:
        if speedup is None or speedup <= 0:
            rewards.append(0.0)
        elif speedup <= 1.0:
            # At or below baseline: small penalty
            rewards.append(-0.1)
        else:
            # Above baseline: base 0.3 plus a log2-scaled bonus
            # (2x speedup -> +0.3 bonus, 4x -> +0.6; the bonus is capped at +0.6)
            bonus = min(0.3 * math.log2(speedup), 0.6)
            rewards.append(0.3 + bonus)
    return rewards


def reward_combined(completions: list[str], **kwargs) -> list[float]:
    """Combined reward from all signals (handy for offline analysis;
    GRPOTrainer itself sums the individual reward functions)."""
    comp_rewards = reward_compilation(completions, **kwargs)
    corr_rewards = reward_correctness(completions, **kwargs)
    speed_rewards = reward_speedup(completions, **kwargs)
    return [c + r + s for c, r, s in zip(comp_rewards, corr_rewards, speed_rewards)]
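# %% [markdown]
# A quick sanity check of the reward shaping, using mock evaluation results
# (the values below are made up for illustration):

# %%
mock_kwargs = {
    "compilation_success": [True, True, False],
    "correctness_pass": [True, False, False],
    "speedup": [2.0, 0.8, None],
}
# Kernel 1 compiles, is correct, and is 2x faster: 0.1 + 0.3 + (0.3 + 0.3) = 1.0
# Kernel 2 compiles but is wrong and slower:       0.1 + 0.0 - 0.1         = 0.0
# Kernel 3 fails to compile:                       0.0
print(reward_combined([""] * 3, **mock_kwargs))  # -> [1.0, 0.0, 0.0]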
# %% [markdown]
# ## System Prompt
#
# The system prompt provides context about the task and the expected output format.

# %%
SYSTEM_PROMPT = """You are an expert GPU kernel engineer specializing in CUDA and Triton.

Your task is to optimize PyTorch operations by writing custom GPU kernels.

Guidelines:
1. Analyze the reference PyTorch implementation carefully
2. Identify optimization opportunities (memory access patterns, parallelism, fusion)
3. Write a Triton or CUDA kernel that computes the same result
4. Ensure numerical correctness (outputs must match within tolerance)

Output format:
- Provide a complete Python file
- Include a Model class with the same interface as the reference
- The Model.forward() method should use your optimized kernel
- Include all necessary imports (torch, triton, triton.language)

Focus on:
- Coalesced memory access
- Efficient use of shared memory
- Minimizing thread divergence
- Optimal block/grid dimensions"""

# %% [markdown]
# ## Rollout Function
#
# The rollout function generates kernel code and evaluates it in the environment.

# %%
def make_prompt(problem_description: str, feedback: str = "") -> str:
    """Create the user prompt for the model."""
    prompt = f"{problem_description}\n"
    if feedback:
        prompt += f"\n## Previous Attempt Feedback\n{feedback}\n"
    prompt += "\nProvide your optimized kernel implementation:"
    return prompt


def extract_code(completion: str) -> str:
    """Extract code from a model completion."""
    # Prefer fenced Python blocks, then any fenced block
    if "```python" in completion:
        start = completion.find("```python") + len("```python")
        end = completion.find("```", start)
        if end > start:
            return completion[start:end].strip()
    if "```" in completion:
        start = completion.find("```") + len("```")
        end = completion.find("```", start)
        if end > start:
            return completion[start:end].strip()
    # No code fences: return the completion as-is
    return completion.strip()


def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """Custom rollout function for the kernrl environment.

    Generates kernel code and evaluates it to get reward signals.
    """
    # Generate completions with the trainer's generation backend
    outputs = generate_rollout_completions(trainer, prompts)
    completions_text = [
        tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
        for out in outputs
    ]

    # Evaluate each completion in the environment
    compilation_success = []
    correctness_pass = []
    speedups = []

    for completion in completions_text:
        # Reset the environment for each evaluation
        # (NOTE: reset() without a problem_id relies on the environment's default
        # problem selection; ideally each prompt's problem_id would be threaded through)
        obs = env.reset()

        # Extract the code and submit it
        code = extract_code(completion)
        action = KernelAction(code=code)

        try:
            result = env.step(action)
            obs = result.observation
            compilation_success.append(obs.compilation_success)
            correctness_pass.append(obs.correctness_pass or False)
            speedups.append(obs.speedup)
        except Exception as e:
            print(f"Evaluation error: {e}")
            compilation_success.append(False)
            correctness_pass.append(False)
            speedups.append(None)

    return {
        "prompt_ids": [out["prompt_ids"] for out in outputs],
        "completion_ids": [out["completion_ids"] for out in outputs],
        "logprobs": [out["logprobs"] for out in outputs],
        # Extra keys are forwarded to the reward functions as kwargs
        "compilation_success": compilation_success,
        "correctness_pass": correctness_pass,
        "speedup": speedups,
    }

# %% [markdown]
# ## Create Training Dataset
#
# We create a dataset from kernrl problems; each problem becomes one training prompt.

# %%
def create_dataset(env: kernrl_env, levels: list[int] = [1, 2]) -> Dataset:
    """Create a training dataset from kernrl problems."""
    prompts = []
    problem_ids = []

    # Get all problem IDs
    all_problems = env.list_problems()

    for problem_id in all_problems:
        # Filter by level, parsed from IDs of the form "L1_..."
        level = int(problem_id.split("_")[0][1:])
        if level not in levels:
            continue

        # Reset to get the problem description
        obs = env.reset(problem_id=problem_id)

        # Build the chat-formatted prompt
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        prompts.append(prompt)
        problem_ids.append(problem_id)

    return Dataset.from_dict({
        "prompt": prompts,
        "problem_id": problem_ids,
    })


# Create a dataset from Level 1 and Level 2 problems
dataset = create_dataset(env, levels=[1, 2])
print(f"Created dataset with {len(dataset)} problems")
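# %% [markdown]
# Inspect one training example to confirm the chat template was applied as expected
# (the slice length below is arbitrary):

# %%
sample = dataset[0]
print(f"Problem: {sample['problem_id']}")
print(sample["prompt"][:500])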
# %% [markdown]
# ## Configure Training
#
# Set up GRPOTrainer with our custom rollout function and reward signals.

# %%
# Training configuration
config = GRPOConfig(
    output_dir="./kernrl_grpo_output",
    # vLLM settings
    use_vllm=True,
    vllm_mode="colocate",  # Use "server" mode for multi-GPU setups
    # Generation settings
    num_generations=4,  # Generations per prompt
    max_completion_length=2048,  # Kernel code can be long
    temperature=0.7,
    # Training settings
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    # Logging
    logging_steps=10,
    save_steps=100,
    report_to="wandb",  # Optional: log to Weights & Biases
)

# %% [markdown]
# ## Initialize Trainer

# %%
trainer = GRPOTrainer(
    model=MODEL_ID,
    processing_class=tokenizer,
    reward_funcs=[
        reward_compilation,
        reward_correctness,
        reward_speedup,
    ],
    train_dataset=dataset,
    rollout_func=rollout_func,
    args=config,
)

# %% [markdown]
# ## Train!
#
# Start the training loop. The model learns to write faster kernels from
# environment feedback.

# %%
# Start training
trainer.train()

# Save the final model
trainer.save_model("./kernrl_trained_model")

# %% [markdown]
# ## Evaluate the Trained Model
#
# Test the trained model on a few problems to see how well it learned.

# %%
def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:
    """Evaluate a trained model on kernel optimization problems."""
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()

    results = []
    for problem_id in problem_ids:
        obs = env.reset(problem_id=problem_id)

        # Generate kernel code
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.3,  # Lower temperature for evaluation
                do_sample=True,
            )

        completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
        code = extract_code(completion)

        # Evaluate in the environment
        result = env.step(KernelAction(code=code))
        obs = result.observation

        results.append({
            "problem_id": problem_id,
            "compilation": obs.compilation_success,
            "correctness": obs.correctness_pass,
            "speedup": obs.speedup,
        })

        speedup_str = f"{obs.speedup:.2f}x" if obs.speedup else "n/a"
        print(f"{problem_id}: compile={obs.compilation_success}, "
              f"correct={obs.correctness_pass}, speedup={speedup_str}")

    return results


# Evaluate on a few problems
# eval_results = evaluate_model("./kernrl_trained_model", ["L1_23_Softmax", "L1_26_GELU_"])

# %% [markdown]
# ## Running with Server Mode (Multi-GPU)
#
# For larger models or faster training, run vLLM in server mode. Note that the vLLM
# server and the kernrl environment need distinct ports (kernrl already uses 8000 here):
#
# ```bash
# # Terminal 1: Start the vLLM server on port 8001
# CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-Coder-7B-Instruct --port 8001
#
# # Terminal 2: Start the kernrl environment on port 8000
# CUDA_VISIBLE_DEVICES=1 uvicorn kernrl.server.app:app --host 0.0.0.0 --port 8000
#
# # Terminal 3: Run training
# CUDA_VISIBLE_DEVICES=2 python train_kernrl.py --vllm-mode server --vllm-server-url http://localhost:8001
# ```
#
# Update the config accordingly:
# ```python
# config = GRPOConfig(
#     use_vllm=True,
#     vllm_mode="server",
#     vllm_server_base_url="http://localhost:8001",
#     ...
# )
# ```

# %% [markdown]
# ## Tips for Better Results
#
# 1. **Start with simpler problems**: Level 1 problems (matmul, softmax) are easier
# 2. **Use code-focused models**: Qwen2.5-Coder and DeepSeek-Coder work well
# 3. **Increase generations**: more generations per prompt give better advantage estimates
# 4. **Multi-turn training**: let the model iterate on its kernels using environment feedback
# 5. **Curriculum learning**: start with L1 and add harder problems gradually (see the sketch below)
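# %% [markdown]
# A minimal sketch of the curriculum idea in tip 5, reusing `create_dataset` from
# above. It assumes the environment also exposes Level 3 problems and that you
# re-create the trainer for each stage; the stage boundaries are arbitrary:

# %%
# Each stage adds harder problem levels
curriculum = [[1], [1, 2], [1, 2, 3]]

for stage, levels in enumerate(curriculum, start=1):
    stage_dataset = create_dataset(env, levels=levels)
    print(f"Stage {stage}: {len(stage_dataset)} problems from levels {levels}")
    # Continue training on the harder mix, e.g.:
    # trainer = GRPOTrainer(model=MODEL_ID, processing_class=tokenizer,
    #                       reward_funcs=[reward_compilation, reward_correctness, reward_speedup],
    #                       train_dataset=stage_dataset, rollout_func=rollout_func, args=config)
    # trainer.train()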
# %% [markdown]
# ## Resources
#
# - [kernrl HuggingFace Space](https://huggingface.co/spaces/Infatoshi/kernrl)
# - [OpenEnv Repository](https://github.com/meta-pytorch/OpenEnv)
# - [TRL Documentation](https://huggingface.co/docs/trl)
# - [Triton Tutorials](https://triton-lang.org/main/getting-started/tutorials/)