File size: 4,614 Bytes
c780f59
 
c024cd7
 
c780f59
c024cd7
 
 
 
c780f59
 
c024cd7
 
b1c9f54
c024cd7
aebc8f0
c024cd7
 
 
 
 
 
 
c780f59
 
 
 
 
 
 
 
 
 
c024cd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bf3992
 
 
 
 
 
 
 
c024cd7
 
 
 
 
 
 
0bf3992
 
 
 
 
 
 
 
 
 
 
 
 
 
c780f59
c024cd7
aebc8f0
c780f59
c024cd7
 
 
 
 
 
c780f59
c024cd7
 
 
 
 
 
 
 
 
 
aebc8f0
c024cd7
 
 
 
 
 
 
 
 
 
 
aebc8f0
c024cd7
 
 
c780f59
aebc8f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c024cd7
c780f59
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import sys
from typing import List, Optional

from dotenv import load_dotenv
from openai import OpenAI

from env_server import KernelOptimization_env, TASKS, grade_episode
from models import Action

# Called before the os.getenv() reads below; presumably this loads variables
# from a local .env file into the environment (python-dotenv) — TODO confirm.
load_dotenv()

# Chat-completions endpoint and model name; both can be overridden via env vars.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.5")
# First non-empty credential wins: OPENAI_API_KEY, then HF_TOKEN, then API_KEY.
# Stays None/empty when none is set, in which case main() runs without a client.
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Optional single task to run; main() runs every task in TASKS when this is
# unset or not a known task id.
TASK_NAME = os.getenv("TASK_ID")
# Label printed as `env=` in the [START] log line.
BENCHMARK = "kernel_optimization"


def one_line(text: str) -> str:
    """Collapse *text* onto a single line.

    Every run of whitespace (spaces, tabs, newlines) is replaced by one space
    and leading/trailing whitespace is removed. Falsy input such as ``None``
    or ``""`` yields an empty string.
    """
    if not text:
        return ""
    words = text.split()
    return " ".join(words)


def extract_code(text: str) -> str:
    """Return the source code carried by a model reply.

    If *text* contains no Markdown code fence (three backticks) it is returned
    unchanged. Otherwise the content of the FIRST fenced block is returned:

    * the block ends at the next fence, or at the end of *text* when the
      reply was cut off before the closing fence;
    * when the opening fence is followed by a language tag on its own line
      (``cuda``, ``cpp``, ``c++``, ``python``, or nothing at all), that line
      is dropped so the tag never ends up inside the code;
    * a fence with no newline inside it (inline ``` ```code``` ``` style) is
      returned as-is, since there is no separate tag line to remove.

    Args:
        text: Raw reply text, possibly containing fenced code.

    Returns:
        The extracted code. Surrounding whitespace is not trimmed here.
    """
    fence = "```"
    opener = text.find(fence)
    if opener == -1:
        return text

    body_start = opener + len(fence)
    closer = text.find(fence, body_start)
    # A missing closing fence means the reply was truncated; keep what we got
    # instead of slicing to an empty string.
    chunk = text[body_start:] if closer == -1 else text[body_start:closer]

    first_line, newline, remainder = chunk.partition("\n")
    tag = first_line.strip()
    # Treat the first line as a language tag only when it is a single token
    # made of characters used in language names; real code on the fence line
    # (which contains spaces, parentheses, semicolons, ...) is preserved.
    if newline and all(ch.isalnum() or ch in "+-#_." for ch in tag):
        return remainder
    return chunk


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log.

    Args:
        task: Task identifier shown as ``task=``.
        env: Benchmark/environment label shown as ``env=``.
        model: Model name shown as ``model=``.
    """
    header = f"[START] task={task} env={env} model={model}"
    # Flush immediately so the line is visible even if the process dies later.
    print(header, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: Step number shown as ``step=``.
        action: Action text; whitespace is collapsed so it fits on one line.
        reward: Step reward, printed with two decimals.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Error message (collapsed to one line), or a falsy value to
            print the literal ``null``.
    """
    if error:
        error_text = one_line(error)
    else:
        error_text = "null"
    done_text = str(done).lower()
    action_text = one_line(action)
    line = f"[STEP] step={step} action={action_text} reward={reward:.2f} done={done_text} error={error_text}"
    print(line, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes an episode's log.

    Args:
        success: Episode outcome; printed as ``true``/``false``.
        steps: Number of steps taken.
        score: Final score, printed with two decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals
            each (an empty list prints an empty ``rewards=`` field).
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    rewards_field = ",".join(formatted_rewards)
    success_field = str(success).lower()
    line = f"[END] success={success_field} steps={steps} score={score:.2f} rewards={rewards_field}"
    print(line, flush=True)


def fallback_action(observation: dict) -> Action:
    """Build the action used when no model proposal is available.

    The observation's ``current_best_code`` is submitted unchanged and the
    action is tagged with the ``"fallback"`` strategy.
    """
    # Deterministic, compile-safe choice when the remote model is unavailable.
    unchanged_code = observation["current_best_code"]
    return Action(optimized_code=unchanged_code, strategy="fallback")


def choose_action(client: Optional[OpenAI], observation: dict) -> Action:
    """Pick the next action, asking the model when a client is available.

    With no client, or when anything goes wrong while requesting or parsing
    the model's reply, the fallback action (resubmit the current best code)
    is returned so the episode can continue. Failures are reported on stderr
    so they are visible without disturbing the [START]/[STEP]/[END] lines
    printed on stdout.

    Args:
        client: OpenAI client, or ``None`` to skip the model entirely.
        observation: Observation dict; reads ``task_name``,
            ``pending_checks`` and ``current_best_code``.

    Returns:
        An ``Action`` with strategy ``"llm_proposed"`` when the model produced
        a reply, otherwise the ``"fallback"`` action.
    """
    if client is None:
        return fallback_action(observation)
    prompt = (
        "Optimize this CUDA kernel.\n"
        f"Task: {observation['task_name']}\n"
        f"Pending checks: {observation['pending_checks']}\n"
        f"Current code:\n{observation['current_best_code']}\n"
        "Return only optimized CUDA code."
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=0.0,
            messages=[
                {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
                {"role": "user", "content": prompt},
            ],
        )
        content = (completion.choices[0].message.content or "").strip()
        # An empty reply (or an empty fenced block) keeps the current best code.
        code = extract_code(content).strip() or observation["current_best_code"]
        return Action(optimized_code=code, strategy="llm_proposed")
    except Exception as exc:
        # Deliberately best-effort: never let a model/API problem abort the
        # episode, but say what happened instead of hiding it.
        print(
            f"[WARN] model request failed, using fallback action: "
            f"{type(exc).__name__}: {one_line(str(exc))}",
            file=sys.stderr,
            flush=True,
        )
        return fallback_action(observation)


def run_episode(client: Optional[OpenAI], task_id: str) -> None:
    """Run one episode of *task_id* and print its START/STEP/END log lines.

    A fresh environment is created and reset for the task, then actions from
    ``choose_action`` are stepped until the environment reports ``done``.
    Afterwards the episode is graded with ``grade_episode``, the score is
    clamped to ``[0.0, 1.0]`` and the run counts as a success when the score
    is at least ``0.1``. Any exception is reported as an extra ``[STEP]``
    line with ``action=error``; the ``[END]`` line is printed in all cases.

    Args:
        client: OpenAI client passed through to ``choose_action`` (may be
            ``None``).
        task_id: Identifier of the task to run.
    """
    env = KernelOptimization_env()
    episode_rewards: List[float] = []
    # Initialised up front because the `finally` clause reports them even
    # when the episode fails before any step completes.
    steps_taken, score, success = 0, 0.0, False

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        observation = env.reset(task_id=task_id)["observation"]
        while True:
            action = choose_action(client, observation)
            result = env.step(action)
            finished = result.done
            observation = result.observation.model_dump()
            step_reward = result.reward.value
            episode_rewards.append(step_reward)
            steps_taken = observation["step_count"]
            log_step(
                step=steps_taken,
                action=action.optimized_code,
                reward=step_reward,
                done=finished,
                error=None,
            )
            if finished:
                break

        raw_score = grade_episode(
            task_id,
            env.state.completed_checks,
            env.state.best_speedup,
            env.state.step_count,
            env.state.max_steps,
        )
        score = min(max(raw_score, 0.0), 1.0)
        success = score >= 0.1
    except Exception as exc:
        failed_step = max(1, steps_taken + 1)
        log_step(step=failed_step, action="error", reward=0.0, done=True, error=str(exc))
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=episode_rewards)


def main() -> int:
    """Run the configured task, or every task in ``TASKS``, and return 0.

    An OpenAI client is created only when ``API_KEY`` is set; if that is not
    the case, or construction fails, episodes run with ``client=None``. When
    ``TASK_NAME`` is set and is a key of ``TASKS`` only that task is run;
    otherwise all tasks are run in ``TASKS`` order.
    """
    client: Optional[OpenAI]
    try:
        client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL) if API_KEY else None
    except Exception:
        client = None

    single_task_requested = bool(TASK_NAME) and TASK_NAME in TASKS
    selected_tasks = [TASK_NAME] if single_task_requested else list(TASKS.keys())

    for task_id in selected_tasks:
        run_episode(client, task_id)
    return 0


# Script entry point: run main() only when executed directly (not on import)
# and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())