100XZX001 committed on
Commit
bc1b1a6
·
verified ·
1 Parent(s): f90a861

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +36 -18
inference.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
2
  import json
3
  import textwrap
 
4
  from typing import List
5
 
6
  from openai import OpenAI
@@ -14,6 +16,9 @@ MAX_STEPS = 5
14
  TEMPERATURE = 0.2
15
  MAX_TOKENS = 200
16
  FALLBACK_ACTION = "skip"
 
 
 
17
 
18
  SYSTEM_PROMPT = textwrap.dedent(
19
  """
@@ -29,7 +34,6 @@ SYSTEM_PROMPT = textwrap.dedent(
29
  """
30
  ).strip()
31
 
32
-
33
  def build_user_prompt(step: int, obs: Observation, history: List[str]) -> str:
34
  newline = "\n"
35
  comments_str = newline.join(obs.comments) if obs.comments else "No existing comments"
@@ -55,7 +59,6 @@ def build_user_prompt(step: int, obs: Observation, history: List[str]) -> str:
55
  ).strip()
56
  return prompt
57
 
58
-
59
  def parse_model_action(response_text: str) -> Action:
60
  if not response_text:
61
  return Action(action_type=FALLBACK_ACTION)
@@ -85,17 +88,27 @@ def parse_model_action(response_text: str) -> Action:
85
  return Action(action_type="propose_fix", fix_code=fix)
86
  return Action(action_type="write_comment", comment_text=raw)
87
 
 
 
 
 
 
 
88
 
89
- def main() -> None:
 
 
 
 
 
90
  if not API_BASE_URL or not API_KEY or not MODEL_NAME:
91
- # This error message is printed to stderr? It must not appear on stdout.
92
- # Use sys.stderr to avoid polluting structured output.
93
- import sys
94
  print("Error: API_BASE_URL, HF_TOKEN/API_KEY, and MODEL_NAME must be set.", file=sys.stderr)
95
  return
96
 
97
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
98
- env = CodeReviewEnv()
 
 
99
  tasks = ["easy", "medium", "hard", "harder", "hardest"]
100
 
101
  for task in tasks:
@@ -104,10 +117,10 @@ def main() -> None:
104
  history: List[str] = []
105
  done = False
106
  step = 0
107
- final_reward = 0.0
 
108
 
109
- # START marker
110
- print(f"[START] task={task}", flush=True)
111
 
112
  while not done and step < MAX_STEPS:
113
  step += 1
@@ -125,23 +138,28 @@ def main() -> None:
125
  )
126
  response_text = completion.choices[0].message.content or ""
127
  except Exception as exc:
128
- # Print error to stderr only
129
- import sys
130
  print(f"Request failed: {exc}. Using fallback.", file=sys.stderr)
131
  response_text = FALLBACK_ACTION
132
 
133
  action = parse_model_action(response_text)
134
  obs, reward, done, info = env.step(action)
135
- final_reward = reward.value
 
136
 
137
- # STEP marker with required key=value format
138
- print(f"[STEP] step={step} reward={final_reward:.3f}", flush=True)
139
 
140
  history.append(f"Step {step}: {action.action_type}")
141
 
142
- # END marker
143
- print(f"[END] task={task} score={final_reward:.3f} steps={step}", flush=True)
 
 
 
 
 
 
144
 
 
145
 
146
  if __name__ == "__main__":
147
- main()
 
1
  import os
2
+ import sys
3
  import json
4
  import textwrap
5
+ import asyncio
6
  from typing import List
7
 
8
  from openai import OpenAI
 
16
  TEMPERATURE = 0.2
17
  MAX_TOKENS = 200
18
  FALLBACK_ACTION = "skip"
19
+ TASK_NAME = "code_review" # or "easy", "medium", etc. – but we'll vary per task
20
+ ENV_NAME = "code_review_env"
21
+ SUCCESS_THRESHOLD = 0.5 # if final score >= 0.5, consider success
22
 
23
  SYSTEM_PROMPT = textwrap.dedent(
24
  """
 
34
  """
35
  ).strip()
36
 
 
37
  def build_user_prompt(step: int, obs: Observation, history: List[str]) -> str:
38
  newline = "\n"
39
  comments_str = newline.join(obs.comments) if obs.comments else "No existing comments"
 
59
  ).strip()
60
  return prompt
61
 
 
62
  def parse_model_action(response_text: str) -> Action:
63
  if not response_text:
64
  return Action(action_type=FALLBACK_ACTION)
 
88
  return Action(action_type="propose_fix", fix_code=fix)
89
  return Action(action_type="write_comment", comment_text=raw)
90
 
91
def log_start(task: str, env: str, model: str):
    """Print the ``[START]`` marker line for one task run.

    The line has the form ``[START] task=<task> env=<env> model=<model>``
    and is written to stdout with an immediate flush.

    Args:
        task: Name of the task being started.
        env: Name of the environment the task runs in.
        model: Name of the model used for the run.
    """
    start_line = f"[START] task={task} env={env} model={model}"
    print(start_line, flush=True)
93
+
94
def log_step(step: int, action: str, reward: float, done: bool, error: str = None):
    """Print the ``[STEP]`` marker line for a single environment step.

    The line has the form
    ``[STEP] step=<step> action=<action> reward=<reward> done=<done>``
    with the reward formatted to three decimal places. When ``error`` is
    truthy, `` error=<error>`` is appended; ``None`` and the empty string
    add nothing. Output goes to stdout and is flushed immediately.

    Args:
        step: Step counter value to report.
        action: Label of the action taken in this step.
        reward: Reward value for this step.
        done: Whether the episode finished on this step.
        error: Optional error text to include in the line.
    """
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.3f}",
        f"done={done}",
    ]
    # The error field is only emitted when there is something to report.
    if error:
        fields.append(f"error={error}")
    print("[STEP] " + " ".join(fields), flush=True)
97
 
98
def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Print the ``[END]`` marker line summarising one task run.

    The line has the form
    ``[END] success=<success> steps=<steps> score=<score> rewards=<rewards>``
    where the score is formatted to three decimal places and the rewards
    are serialised as a compact JSON array (no spaces after separators).
    Output goes to stdout and is flushed immediately.

    Args:
        success: Whether the run counts as successful.
        steps: Number of steps to report.
        score: Overall score for the run.
        rewards: Per-step reward values, in order.
    """
    # Compact separators keep the array free of spaces, e.g. "[0.5,1.0]".
    compact_rewards = json.dumps(rewards, separators=(",", ":"))
    summary = (
        f"[END] success={success} steps={steps} "
        f"score={score:.3f} rewards={compact_rewards}"
    )
    print(summary, flush=True)
102
+
103
+ async def main() -> None:
104
  if not API_BASE_URL or not API_KEY or not MODEL_NAME:
 
 
 
105
  print("Error: API_BASE_URL, HF_TOKEN/API_KEY, and MODEL_NAME must be set.", file=sys.stderr)
106
  return
107
 
108
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
109
+ env = CodeReviewEnv() # synchronous environment – we'll wrap calls in async? The sample uses async, but we can run sync and just not await.
110
+ # Since our env is sync, we'll call methods directly. The logging is the key part.
111
+
112
  tasks = ["easy", "medium", "hard", "harder", "hardest"]
113
 
114
  for task in tasks:
 
117
  history: List[str] = []
118
  done = False
119
  step = 0
120
+ rewards = []
121
+ final_score = 0.0
122
 
123
+ log_start(task=task, env=ENV_NAME, model=MODEL_NAME)
 
124
 
125
  while not done and step < MAX_STEPS:
126
  step += 1
 
138
  )
139
  response_text = completion.choices[0].message.content or ""
140
  except Exception as exc:
 
 
141
  print(f"Request failed: {exc}. Using fallback.", file=sys.stderr)
142
  response_text = FALLBACK_ACTION
143
 
144
  action = parse_model_action(response_text)
145
  obs, reward, done, info = env.step(action)
146
+ reward_val = reward.value
147
+ rewards.append(reward_val)
148
 
149
+ log_step(step=step, action=action.action_type, reward=reward_val, done=done)
 
150
 
151
  history.append(f"Step {step}: {action.action_type}")
152
 
153
+ # Compute the overall score (average reward or final reward?). The sample uses sum(rewards)/MAX_TOTAL_REWARD.
154
+ # We'll use average reward per step as score, clamped to [0,1].
155
+ if rewards:
156
+ avg_reward = sum(rewards) / len(rewards)
157
+ score = max(0.0, min(1.0, avg_reward))
158
+ else:
159
+ score = 0.0
160
+ success = score >= SUCCESS_THRESHOLD
161
 
162
+ log_end(success=success, steps=step, score=score, rewards=rewards)
163
 
164
  if __name__ == "__main__":
165
+ asyncio.run(main())