|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from datasets import Dataset |
|
|
from transformers import AutoTokenizer |
|
|
from trl import GRPOConfig, GRPOTrainer |
|
|
from trl.experimental.openenv import generate_rollout_completions |
|
|
|
|
|
|
|
|
from kernrl import kernrl_env, KernelAction, KernelObservation |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face model ID of the policy model to fine-tune.
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

# Base URL of the running kernrl evaluation server.
ENV_URL = "http://localhost:8000"
|
|
|
|
|
|
|
|
# Tokenizer for the policy model; reused below for prompt construction,
# decoding rollout completions, and as the trainer's processing class.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Some tokenizers ship without an explicit pad token; batched
# generation/training needs one, so fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single shared kernrl environment client, used for dataset construction,
# rollouts, and post-training evaluation.
env = kernrl_env(base_url=ENV_URL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Smoke test: reset the environment on a known problem and print its
# metadata, confirming the kernrl server is reachable before training.
obs = env.reset(problem_id="L1_23_Softmax")

print(f"Problem: {obs.problem_id}")
print(f"GPU: {obs.gpu_info}")
print(f"Max turns: {obs.max_turns}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
def reward_compilation(completions: list[str], **kwargs) -> list[float]:
    """Reward successful compilation with a small fixed bonus.

    Args:
        completions: Generated completions (used for length alignment).
        **kwargs: Extra rollout columns; reads ``compilation_success``, the
            list of booleans produced by ``rollout_func``.

    Returns:
        One reward per completion: 0.1 if the kernel compiled, else 0.0.
    """
    flags = list(kwargs.get("compilation_success") or [])
    # Align to exactly one flag per completion: pad missing entries with
    # False and truncate extras, so the reward list length always matches
    # len(completions) even if the rollout column is absent or short.
    flags = (flags + [False] * len(completions))[: len(completions)]
    return [0.1 if ok else 0.0 for ok in flags]
|
|
|
|
|
def reward_correctness(completions: list[str], **kwargs) -> list[float]:
    """Reward numerically correct kernel output with a fixed bonus.

    Args:
        completions: Generated completions (used for length alignment).
        **kwargs: Extra rollout columns; reads ``correctness_pass``, the
            list of booleans produced by ``rollout_func``.

    Returns:
        One reward per completion: 0.3 if the output matched the reference,
        else 0.0.
    """
    flags = list(kwargs.get("correctness_pass") or [])
    # Align to exactly one flag per completion (pad with False / truncate)
    # so the reward list length always matches len(completions).
    flags = (flags + [False] * len(completions))[: len(completions)]
    return [0.3 if ok else 0.0 for ok in flags]
|
|
|
|
|
def reward_speedup(completions: list[str], **kwargs) -> list[float]:
    """Reward scaled by the speedup the kernel achieved.

    Per-rollout mapping:
      * missing / non-positive speedup -> 0.0
      * speedup in (0, 1]              -> -0.1 (slower than the reference)
      * speedup > 1                    -> 0.3 + min(0.3 * log2(speedup), 0.6)
    """

    def score(ratio) -> float:
        # No measurement (failed run) earns nothing.
        if ratio is None or ratio <= 0:
            return 0.0
        # At or below parity: small penalty for a non-improving kernel.
        if ratio <= 1.0:
            return -0.1
        # Log-scaled bonus on top of the base reward, capped at +0.6.
        return 0.3 + min(0.3 * math.log2(ratio), 0.6)

    return [score(ratio) for ratio in kwargs.get("speedup", [])]
|
|
|
|
|
def reward_combined(completions: list[str], **kwargs) -> list[float]:
    """Sum the compilation, correctness, and speedup rewards per completion."""
    per_signal = (
        reward_compilation(completions, **kwargs),
        reward_correctness(completions, **kwargs),
        reward_speedup(completions, **kwargs),
    )
    # zip truncates to the shortest signal list, matching the original
    # element-wise c + r + s combination.
    return [sum(parts) for parts in zip(*per_signal)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt sent with every problem; instructs the policy to emit a
# complete Python file containing an optimized Triton/CUDA kernel that
# matches the reference Model interface.
SYSTEM_PROMPT = """You are an expert GPU kernel engineer specializing in CUDA and Triton.

Your task is to optimize PyTorch operations by writing custom GPU kernels.

Guidelines:
1. Analyze the reference PyTorch implementation carefully
2. Identify optimization opportunities (memory access patterns, parallelism, fusion)
3. Write a Triton or CUDA kernel that computes the same result
4. Ensure numerical correctness (outputs must match within tolerance)

Output format:
- Provide a complete Python file
- Include a Model class with the same interface as the reference
- The Model.forward() method should use your optimized kernel
- Include all necessary imports (torch, triton, triton.language)

Focus on:
- Coalesced memory access
- Efficient use of shared memory
- Minimizing thread divergence
- Optimal block/grid dimensions"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_prompt(problem_description: str, feedback: str = "") -> str:
    """Build the user prompt for the model.

    Args:
        problem_description: The kernrl problem statement.
        feedback: Optional feedback from a previous attempt; when non-empty
            it is inserted as its own section.

    Returns:
        The assembled prompt string.
    """
    sections = [f"{problem_description}\n"]
    if feedback:
        sections.append(f"\n## Previous Attempt Feedback\n{feedback}\n")
    sections.append("\nProvide your optimized kernel implementation:")
    return "".join(sections)
|
|
|
|
|
def extract_code(completion: str) -> str:
    """Pull source code out of a model completion.

    Prefers the first ```python fenced block, then any ``` fenced block,
    and finally falls back to the whole (stripped) completion when no
    properly closed fence is found.
    """
    for fence in ("```python", "```"):
        if fence not in completion:
            continue
        body_start = completion.find(fence) + len(fence)
        body_end = completion.find("```", body_start)
        if body_end > body_start:
            return completion[body_start:body_end].strip()
    return completion.strip()
|
|
|
|
|
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """
    Custom rollout function for kernrl environment.

    Generates kernel code and evaluates it to get rewards.

    Returns the standard TRL rollout fields (prompt_ids, completion_ids,
    logprobs) plus three extra columns that the reward functions read from
    their **kwargs: compilation_success, correctness_pass, speedup.
    """
    # Generate one completion per prompt through the trainer's own
    # generation path.
    outputs = generate_rollout_completions(trainer, prompts)

    completions_text = [
        tokenizer.decode(out["completion_ids"], skip_special_tokens=True)
        for out in outputs
    ]

    # Per-completion evaluation signals, collected in prompt order.
    compilation_success = []
    correctness_pass = []
    speedups = []

    for completion in completions_text:
        # NOTE(review): reset() is called without a problem_id, so each
        # completion is evaluated against the environment's default/current
        # problem rather than the problem its prompt was built from — the
        # dataset carries a "problem_id" column; confirm this is intended.
        obs = env.reset()

        # Strip markdown fences, then submit the kernel for evaluation.
        code = extract_code(completion)
        action = KernelAction(code=code)

        try:
            result = env.step(action)
            obs = result.observation

            compilation_success.append(obs.compilation_success)
            # Coerce a possibly-None correctness flag to a plain bool.
            correctness_pass.append(obs.correctness_pass or False)
            speedups.append(obs.speedup)
        except Exception as e:
            # Treat any evaluation failure as a fully failed attempt rather
            # than aborting the whole rollout batch.
            print(f"Evaluation error: {e}")
            compilation_success.append(False)
            correctness_pass.append(False)
            speedups.append(None)

    return {
        # Standard TRL rollout fields, passed through from generation.
        "prompt_ids": [out["prompt_ids"] for out in outputs],
        "completion_ids": [out["completion_ids"] for out in outputs],
        "logprobs": [out["logprobs"] for out in outputs],
        # Extra columns forwarded to the reward functions as kwargs.
        "compilation_success": compilation_success,
        "correctness_pass": correctness_pass,
        "speedup": speedups,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(env: kernrl_env, levels: tuple[int, ...] = (1, 2)) -> Dataset:
    """Create a training dataset from kernrl problems.

    Args:
        env: Connected kernrl environment client.
        levels: Difficulty levels to include. The default is an immutable
            tuple (a mutable ``[1, 2]`` default would be shared across
            calls); any iterable of ints is accepted.

    Returns:
        A ``Dataset`` with "prompt" (chat-templated string) and
        "problem_id" columns.
    """
    # Set for O(1) membership tests; also normalizes any iterable input.
    wanted_levels = set(levels)

    prompts = []
    problem_ids = []

    for problem_id in env.list_problems():
        # Problem IDs look like "L1_23_Softmax": the level is the digits
        # after the leading 'L' in the first underscore-separated token.
        level = int(problem_id.split("_")[0][1:])
        if level not in wanted_levels:
            continue

        # Reset to this problem to fetch its description.
        obs = env.reset(problem_id=problem_id)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        prompts.append(prompt)
        problem_ids.append(problem_id)

    return Dataset.from_dict({
        "prompt": prompts,
        "problem_id": problem_ids,
    })
|
|
|
|
|
|
|
|
# Build prompts for every level-1 and level-2 problem the server exposes.
dataset = create_dataset(env, levels=[1, 2])
print(f"Created dataset with {len(dataset)} problems")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# GRPO training configuration.
config = GRPOConfig(
    output_dir="./kernrl_grpo_output",

    # Generation backend: vLLM in "colocate" mode — presumably sharing the
    # training GPUs with generation; confirm against TRL's vLLM docs.
    use_vllm=True,
    vllm_mode="colocate",

    # Rollout shape: 4 sampled kernels per prompt, up to 2048 new tokens.
    num_generations=4,
    max_completion_length=2048,
    temperature=0.7,

    # Optimization schedule.
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,

    # Logging and checkpointing.
    logging_steps=10,
    save_steps=100,
    report_to="wandb",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# GRPO trainer. The model is instantiated from the hub ID; the reward
# functions read the extra columns ("compilation_success",
# "correctness_pass", "speedup") that rollout_func returns with each batch.
trainer = GRPOTrainer(
    model=MODEL_ID,
    processing_class=tokenizer,
    reward_funcs=[
        reward_compilation,
        reward_correctness,
        reward_speedup,
    ],
    train_dataset=dataset,
    rollout_func=rollout_func,
    args=config,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run GRPO training.
trainer.train()

# Persist the fine-tuned policy for later evaluation/inference.
trainer.save_model("./kernrl_trained_model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_model(model_path: str, problem_ids: list[str]) -> list[dict]:
    """Evaluate a trained model on kernel optimization problems.

    Args:
        model_path: Path or hub ID of the (fine-tuned) causal LM.
        problem_ids: kernrl problem IDs to evaluate on.

    Returns:
        One dict per problem with keys "problem_id", "compilation",
        "correctness", and "speedup". (The original annotation said
        ``dict`` but a list was returned — fixed.)
    """
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()

    results = []

    for problem_id in problem_ids:
        # Fetch the problem description from the environment.
        obs = env.reset(problem_id=problem_id)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": make_prompt(obs.problem_description)},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Move inputs to the model's device so evaluation also works when
        # the model was loaded onto a GPU.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.3,
                do_sample=True,
                # Explicit pad token avoids generate() warnings when the
                # tokenizer's pad token was defaulted to EOS.
                pad_token_id=tokenizer.pad_token_id,
            )

        completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
        code = extract_code(completion)

        # Submit the kernel for compilation/correctness/speedup evaluation.
        result = env.step(KernelAction(code=code))
        obs = result.observation

        results.append({
            "problem_id": problem_id,
            "compilation": obs.compilation_success,
            "correctness": obs.correctness_pass,
            "speedup": obs.speedup,
        })

        # Readable per-problem summary; guard the ".2f" format against a
        # missing speedup instead of relying on truthiness.
        speedup_str = f"{obs.speedup:.2f}x" if obs.speedup is not None else "n/a"
        print(f"{problem_id}: compile={obs.compilation_success}, "
              f"correct={obs.correctness_pass}, speedup={speedup_str}")

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|