import os
import sys
from typing import List, Optional
from dotenv import load_dotenv
from openai import OpenAI
from env_server import KernelOptimization_env, TASKS, grade_episode
from models import Action
# Load environment variables from a local .env file, if present.
load_dotenv()
# OpenAI-compatible endpoint and model name; both overridable via environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.5")
# First credential found wins; if all are unset the client stays disabled
# and fallback actions are used instead (see main / choose_action).
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Optional single-task selection; unset (or not in TASKS) runs every task.
TASK_NAME = os.getenv("TASK_ID")
BENCHMARK = "kernel_optimization"
def one_line(text: str) -> str:
    """Collapse all whitespace runs (including newlines) into single spaces."""
    words = (text or "").split()
    return " ".join(words)
def extract_code(text: str) -> str:
    """Return the code inside the outermost ``` fence pair in *text*.

    Falls back to the raw text when no fence is present. A leading
    ``cuda``/``cpp`` language tag on the opening fence is stripped.
    """
    if "```" not in text:
        return text
    start = text.find("```")
    end = text.rfind("```")
    if end <= start:
        # Only one fence in the text: take everything after it. Previously
        # the slice text[start+3:end] collapsed to "" and content was lost.
        chunk = text[start + 3 :]
    else:
        chunk = text[start + 3 : end]
    if chunk.startswith(("cuda", "cpp")):
        # Drop the language-tag line; guard against a tag with no trailing
        # newline, which previously raised IndexError on split(...)[1].
        parts = chunk.split("\n", 1)
        return parts[1] if len(parts) == 2 else ""
    return chunk
def log_start(task: str, env: str, model: str) -> None:
    """Announce the start of an episode on stdout (flushed immediately)."""
    message = f"[START] task={task} env={env} model={model}"
    print(message, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] progress line on stdout (flushed immediately).

    Multi-line payloads are collapsed via one_line so each step stays on
    a single log line.
    """
    fields = " ".join(
        (
            f"step={step}",
            f"action={one_line(action)}",
            f"reward={reward:.2f}",
            f"done={str(done).lower()}",
            f"error={one_line(error) if error else 'null'}",
        )
    )
    print(f"[STEP] {fields}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the final [END] summary line for an episode (flushed immediately)."""
    formatted = ",".join(f"{value:.2f}" for value in rewards)
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={formatted}"
    )
    print(summary, flush=True)
def fallback_action(observation: dict) -> Action:
    """Return a safe no-op action that resubmits the current best code.

    Deterministic and compile-safe; used whenever the remote model is
    unavailable or a model call fails.
    """
    current_code = observation["current_best_code"]
    return Action(optimized_code=current_code, strategy="fallback")
def choose_action(client: Optional[OpenAI], observation: dict) -> Action:
    """Ask the LLM for an optimized kernel; degrade to the no-op fallback.

    Returns a fallback action when no client is configured, when the API
    call raises, or when the model returns no usable code.
    """
    if client is None:
        return fallback_action(observation)
    user_prompt = "\n".join(
        (
            "Optimize this CUDA kernel.",
            f"Task: {observation['task_name']}",
            f"Pending checks: {observation['pending_checks']}",
            f"Current code:\n{observation['current_best_code']}",
            "Return only optimized CUDA code.",
        )
    )
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=0.0,
            messages=[
                {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
                {"role": "user", "content": user_prompt},
            ],
        )
        raw_reply = (response.choices[0].message.content or "").strip()
        candidate = extract_code(raw_reply).strip()
        if not candidate:
            candidate = observation["current_best_code"]
        return Action(optimized_code=candidate, strategy="llm_proposed")
    except Exception:
        # Any API or parsing failure falls back to resubmitting current code.
        return fallback_action(observation)
def run_episode(client: Optional[OpenAI], task_id: str) -> None:
    """Run one optimization episode for *task_id* and log its progress.

    Steps the environment until done, then grades the episode; any
    exception is logged as an error step, and the [END] summary is always
    emitted.
    """
    env = KernelOptimization_env()
    collected_rewards: List[float] = []
    step_count = 0
    final_score = 0.0
    episode_success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        obs = env.reset(task_id=task_id)["observation"]
        finished = False
        while not finished:
            action = choose_action(client, obs)
            result = env.step(action)
            finished = result.done
            obs = result.observation.model_dump()
            step_reward = result.reward.value
            collected_rewards.append(step_reward)
            step_count = obs["step_count"]
            log_step(step=step_count, action=action.optimized_code, reward=step_reward, done=finished, error=None)
        raw_score = grade_episode(
            task_id,
            env.state.completed_checks,
            env.state.best_speedup,
            env.state.step_count,
            env.state.max_steps,
        )
        # Clamp to [0, 1]; success is any non-trivial score.
        final_score = min(max(raw_score, 0.0), 1.0)
        episode_success = final_score >= 0.1
    except Exception as exc:
        log_step(step=max(1, step_count + 1), action="error", reward=0.0, done=True, error=str(exc))
    finally:
        log_end(success=episode_success, steps=step_count, score=final_score, rewards=collected_rewards)
def main() -> int:
    """Run every selected task once; always returns exit code 0.

    Builds an OpenAI client when credentials are available; otherwise (or
    on client construction failure) episodes run with fallback actions.
    """
    client: Optional[OpenAI] = None
    if API_KEY:
        try:
            client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
        except Exception:
            # Proceed without a client; choose_action degrades gracefully.
            client = None
    selected = [TASK_NAME] if TASK_NAME and TASK_NAME in TASKS else list(TASKS.keys())
    for task_id in selected:
        run_episode(client, task_id)
    return 0
if __name__ == "__main__":
    # SystemExit(main()) is equivalent to sys.exit(main()).
    raise SystemExit(main())