File size: 2,913 Bytes
c780f59
 
 
 
 
 
 
 
 
 
 
 
 
d5c6f39
 
 
 
c780f59
 
 
 
 
 
 
 
 
 
 
d5c6f39
c780f59
 
d5c6f39
 
 
c780f59
 
 
 
 
 
 
 
 
d5c6f39
 
 
 
 
 
c780f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c6f39
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from env_server import KernelOptimization_env, TASKS
from trl import GRPOConfig, GRPOTrainer
from models import Action
from typing import List
from datasets import Dataset
import os

class KernelOptTool:
    """Tool wrapper around KernelOptimization_env for GRPO rollouts.

    Keeps the most recent reward and the episode-done flag on the instance
    so the trainer's reward function can read them after each tool call.
    """

    def __init__(self):
        self.env = KernelOptimization_env()
        self.reward = 0.0
        self.done = False

    def reset(self, **kwargs) -> str:
        """Start a new episode and return the initial observation as text.

        The task id may arrive directly (task_id=...) or nested inside a
        dataset sample dict (sample={"task_id": ...}).
        """
        chosen_task = kwargs.get("task_id")
        if chosen_task is None:
            sample = kwargs.get("sample")
            if isinstance(sample, dict):
                chosen_task = sample.get("task_id")
        obs = self.env.reset(task_id=chosen_task)["observation"]
        self.reward = 0.0
        self.done = False
        parts = [
            f"Task: {obs['task_name']}",
            f"Baseline CUDA kernel:\n{obs['baseline_code']}",
            f"Pending checks: {obs['pending_checks']}",
            "Use tools to submit improved code.",
        ]
        return "\n".join(parts)

    def submit_optimization(self, optimized_code: str, strategy: str = "", expected_speedup: float | None = None) -> str:
        """Submit a candidate kernel and return a one-line progress summary.

        Raises ValueError if the episode has already finished.
        """
        if self.done:
            raise ValueError("Episode is already done.")
        action = Action(
            optimized_code=optimized_code,
            strategy=strategy,
            expected_speedup=expected_speedup,
        )
        step = self.env.step(action)
        self.reward = step.reward.value
        self.done = step.done
        observation = step.observation
        return (
            f"reward={step.reward.value:.4f}, "
            f"best_speedup={observation.current_best_speedup:.3f}x, "
            f"pending_checks={observation.pending_checks}, done={step.done}"
        )

    # Backward-compatible alias for a historical misspelling of the tool name.
    def submit_optiization(self, optimized_code: str, strategy: str = "") -> str:
        return self.submit_optimization(optimized_code=optimized_code, strategy=strategy)

def reward_func(environments, **kwargs) -> List[float]:
    """Collect the latest scalar reward from each rollout environment."""
    rewards: List[float] = []
    for environment in environments:
        rewards.append(environment.reward)
    return rewards

def build_dataset(repeats_per_task: int = 32) -> Dataset:
    """Build a chat-prompt dataset with every kernel task repeated repeats_per_task times."""
    rows = {"prompt": [], "task_id": []}
    for tid, task in TASKS.items():
        for _ in range(repeats_per_task):
            rows["prompt"].append(
                [{"role": "user", "content": f"Optimize CUDA kernel task: {task['name']}"}]
            )
            rows["task_id"].append(tid)
    return Dataset.from_dict(rows)

def main():
    """Entry point: configure and launch GRPO training over the kernel tasks."""
    model_name = os.getenv("TRAIN_MODEL", "Qwen/Qwen3-0.6B")
    config = GRPOConfig(
        chat_template_kwargs={"enable_thinking": False},
        max_completion_length=2048,
        num_generations=4,
        log_completions=True,
    )
    trainer = GRPOTrainer(
        model=model_name,
        train_dataset=build_dataset(),
        reward_funcs=reward_func,
        environment_factory=KernelOptTool,
        args=config,
    )
    trainer.train()


if __name__ == "__main__":
    main()