aaloksan commited on
Commit
c024cd7
·
1 Parent(s): 98a22be

fix:update infernnce to support validator

Browse files
Files changed (1) hide show
  1. inference.py +103 -45
inference.py CHANGED
@@ -1,13 +1,26 @@
1
  import os
2
- from openai import OpenAI, AuthenticationError
3
- from typing import Dict
4
- from env_server import TASKS, KernelOptimization_env, grade_episode
5
- from models import Action
6
- import json
7
  import sys
 
 
8
  from dotenv import load_dotenv
 
 
 
 
9
 
10
  load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
11
  def extract_code(text: str) -> str:
12
  if "```" not in text:
13
  return text
@@ -18,56 +31,101 @@ def extract_code(text: str) -> str:
18
  return chunk.split("\n", 1)[1]
19
  return chunk
20
 
21
- def choose_action(client: OpenAI, model: str, observation: Dict) -> Action:
22
- prompt = f"""Optimize this CUDA kernel.
23
- Task: {observation['task_name']}
24
- Pending checks: {observation['pending_checks']}
25
- Baseline:
26
- {observation['baseline_code']}
27
- Current best speedup: {observation['current_best_speedup']}x
28
- Return only optimized CUDA code.
29
- """
30
- response = client.chat.completions.create(
31
- model=model,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  temperature=0.0,
33
  messages=[
34
  {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
35
  {"role": "user", "content": prompt},
36
  ],
37
  )
38
- text = (response.choices[0].message.content or "").strip()
39
- code = extract_code(text).strip() or observation["current_best_code"]
40
  return Action(optimized_code=code, strategy="llm_proposed")
41
 
42
- def run_task(client: OpenAI, model: str, task_id: str) -> float:
 
 
43
  env = KernelOptimization_env()
44
- obs = env.reset(task_id=task_id)["observation"]
45
- done = False
46
- while not done:
47
- action = choose_action(client, model, obs)
48
- step_result = env.step(action)
49
- obs = step_result.observation.model_dump()
50
- done = step_result.done
51
- return grade_episode(task_id, env.state.completed_checks, env.state.best_speedup, env.state.step_count, env.state.max_steps)
52
- def main()->int:
53
- if not os.getenv("OPENAI_API_KEY"):
54
- print("openai key not set")
55
-
56
- model =os.getenv("MODEL_NAME", "gemma-3-4b")
57
- client =OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url =os.getenv("API_BASE_URL", "https://api.openai.com/v1"))
58
-
59
- scores: Dict[str, float] = {}
60
  try:
61
- for task_id in TASKS:
62
- scores[task_id] = run_task(client, model, task_id)
63
- print(f"[TASK] {task_id} score={scores[task_id]:.4f}")
64
- except AuthenticationError:
65
- print("ERROR: OpenAI authentication failed. Check OPENAI_API_KEY.", file=sys.stderr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return 1
 
 
 
67
 
68
- avg = sum(scores.values()) / len(scores)
69
- print(f"[BASELINE] model={model} average_score={avg:.4f}")
70
- print(json.dumps({"scores": scores, "average": round(avg, 4)}))
71
- return 0
72
- if __name__=="__main__":
73
  sys.exit(main())
 
1
  import os
 
 
 
 
 
2
  import sys
3
+ from typing import List, Optional
4
+
5
  from dotenv import load_dotenv
6
+ from openai import OpenAI
7
+
8
+ from env_server import KernelOptimization_env, TASKS, grade_episode
9
+ from models import Action
10
 
11
  load_dotenv()
12
+
13
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
14
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
15
+ API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
16
+ TASK_NAME = os.getenv("TASK_ID", "vector_add_easy")
17
+ BENCHMARK = "kernel_optimization"
18
+
19
+
20
+ def one_line(text: str) -> str:
21
+ return " ".join((text or "").split())
22
+
23
+
24
  def extract_code(text: str) -> str:
25
  if "```" not in text:
26
  return text
 
31
  return chunk.split("\n", 1)[1]
32
  return chunk
33
 
34
+
35
+ def log_start(task: str, env: str, model: str) -> None:
36
+ print(f"[START] task={task} env={env} model={model}", flush=True)
37
+
38
+
39
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
40
+ error_val = one_line(error) if error else "null"
41
+ done_val = str(done).lower()
42
+ action_val = one_line(action)
43
+ print(
44
+ f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}",
45
+ flush=True,
46
+ )
47
+
48
+
49
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
50
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
51
+ print(
52
+ f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
53
+ flush=True,
54
+ )
55
+
56
+
57
+ def choose_action(client: OpenAI, observation: dict) -> Action:
58
+ prompt = (
59
+ "Optimize this CUDA kernel.\n"
60
+ f"Task: {observation['task_name']}\n"
61
+ f"Pending checks: {observation['pending_checks']}\n"
62
+ f"Current code:\n{observation['current_best_code']}\n"
63
+ "Return only optimized CUDA code."
64
+ )
65
+ completion = client.chat.completions.create(
66
+ model=MODEL_NAME,
67
  temperature=0.0,
68
  messages=[
69
  {"role": "system", "content": "You are a CUDA optimization expert. Return code only."},
70
  {"role": "user", "content": prompt},
71
  ],
72
  )
73
+ content = (completion.choices[0].message.content or "").strip()
74
+ code = extract_code(content).strip() or observation["current_best_code"]
75
  return Action(optimized_code=code, strategy="llm_proposed")
76
 
77
+
78
+ def main() -> int:
79
+ task_id = TASK_NAME if TASK_NAME in TASKS else "vector_add_easy"
80
  env = KernelOptimization_env()
81
+ rewards: List[float] = []
82
+ steps_taken = 0
83
+ score = 0.0
84
+ success = False
85
+
86
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
87
+
 
 
 
 
 
 
 
 
 
88
  try:
89
+ if not API_KEY:
90
+ raise RuntimeError("Missing OPENAI_API_KEY")
91
+
92
+ client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
93
+ obs = env.reset(task_id=task_id)["observation"]
94
+ done = False
95
+
96
+ while not done:
97
+ action = choose_action(client, obs)
98
+ action_str = action.optimized_code
99
+ step_result = env.step(action)
100
+ done = step_result.done
101
+ obs = step_result.observation.model_dump()
102
+ reward = step_result.reward.value
103
+ rewards.append(reward)
104
+ steps_taken = obs["step_count"]
105
+ log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)
106
+
107
+ score = grade_episode(
108
+ task_id,
109
+ env.state.completed_checks,
110
+ env.state.best_speedup,
111
+ env.state.step_count,
112
+ env.state.max_steps,
113
+ )
114
+ score = min(max(score, 0.0), 1.0)
115
+ success = score >= 0.1
116
+ return 0
117
+ except Exception as exc:
118
+ log_step(
119
+ step=max(1, steps_taken + 1),
120
+ action="error",
121
+ reward=0.0,
122
+ done=True,
123
+ error=str(exc),
124
+ )
125
  return 1
126
+ finally:
127
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
128
+
129
 
130
+ if __name__ == "__main__":
 
 
 
 
131
  sys.exit(main())