100XZX001 committed on
Commit
114e5cf
·
verified ·
1 Parent(s): bc1b1a6

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +13 -38
inference.py CHANGED
@@ -1,8 +1,6 @@
1
  import os
2
  import sys
3
- import json
4
  import textwrap
5
- import asyncio
6
  from typing import List
7
 
8
  from openai import OpenAI
@@ -16,9 +14,6 @@ MAX_STEPS = 5
16
  TEMPERATURE = 0.2
17
  MAX_TOKENS = 200
18
  FALLBACK_ACTION = "skip"
19
- TASK_NAME = "code_review" # or "easy", "medium", etc. – but we'll vary per task
20
- ENV_NAME = "code_review_env"
21
- SUCCESS_THRESHOLD = 0.5 # if final score >= 0.5, consider success
22
 
23
  SYSTEM_PROMPT = textwrap.dedent(
24
  """
@@ -88,27 +83,14 @@ def parse_model_action(response_text: str) -> Action:
88
  return Action(action_type="propose_fix", fix_code=fix)
89
  return Action(action_type="write_comment", comment_text=raw)
90
 
91
- def log_start(task: str, env: str, model: str):
92
- print(f"[START] task={task} env={env} model={model}", flush=True)
93
-
94
- def log_step(step: int, action: str, reward: float, done: bool, error: str = None):
95
- error_str = f" error={error}" if error else ""
96
- print(f"[STEP] step={step} action={action} reward={reward:.3f} done={done}{error_str}", flush=True)
97
-
98
- def log_end(success: bool, steps: int, score: float, rewards: List[float]):
99
- # rewards as JSON array to match sample (though sample shows `rewards=rewards` which might be a list)
100
- rewards_str = json.dumps(rewards, separators=(',', ':'))
101
- print(f"[END] success={success} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
102
-
103
- async def main() -> None:
104
  if not API_BASE_URL or not API_KEY or not MODEL_NAME:
 
105
  print("Error: API_BASE_URL, HF_TOKEN/API_KEY, and MODEL_NAME must be set.", file=sys.stderr)
106
  return
107
 
108
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
109
- env = CodeReviewEnv() # synchronous environment – we'll wrap calls in async? The sample uses async, but we can run sync and just not await.
110
- # Since our env is sync, we'll call methods directly. The logging is the key part.
111
-
112
  tasks = ["easy", "medium", "hard", "harder", "hardest"]
113
 
114
  for task in tasks:
@@ -117,10 +99,10 @@ async def main() -> None:
117
  history: List[str] = []
118
  done = False
119
  step = 0
120
- rewards = []
121
- final_score = 0.0
122
 
123
- log_start(task=task, env=ENV_NAME, model=MODEL_NAME)
 
124
 
125
  while not done and step < MAX_STEPS:
126
  step += 1
@@ -138,28 +120,21 @@ async def main() -> None:
138
  )
139
  response_text = completion.choices[0].message.content or ""
140
  except Exception as exc:
 
141
  print(f"Request failed: {exc}. Using fallback.", file=sys.stderr)
142
  response_text = FALLBACK_ACTION
143
 
144
  action = parse_model_action(response_text)
145
  obs, reward, done, info = env.step(action)
146
- reward_val = reward.value
147
- rewards.append(reward_val)
148
 
149
- log_step(step=step, action=action.action_type, reward=reward_val, done=done)
 
150
 
151
  history.append(f"Step {step}: {action.action_type}")
152
 
153
- # Compute overall score (e.g., average reward or final reward? Sample uses sum(rewards)/MAX_TOTAL_REWARD.
154
- # We'll use average reward per step as score, clamped to [0,1].
155
- if rewards:
156
- avg_reward = sum(rewards) / len(rewards)
157
- score = max(0.0, min(1.0, avg_reward))
158
- else:
159
- score = 0.0
160
- success = score >= SUCCESS_THRESHOLD
161
-
162
- log_end(success=success, steps=step, score=score, rewards=rewards)
163
 
164
  if __name__ == "__main__":
165
- asyncio.run(main())
 
1
  import os
2
  import sys
 
3
  import textwrap
 
4
  from typing import List
5
 
6
  from openai import OpenAI
 
14
  TEMPERATURE = 0.2
15
  MAX_TOKENS = 200
16
  FALLBACK_ACTION = "skip"
 
 
 
17
 
18
  SYSTEM_PROMPT = textwrap.dedent(
19
  """
 
83
  return Action(action_type="propose_fix", fix_code=fix)
84
  return Action(action_type="write_comment", comment_text=raw)
85
 
86
+ def main() -> None:
 
 
 
 
 
 
 
 
 
 
 
 
87
  if not API_BASE_URL or not API_KEY or not MODEL_NAME:
88
+ # Print error to stderr (won't affect stdout)
89
  print("Error: API_BASE_URL, HF_TOKEN/API_KEY, and MODEL_NAME must be set.", file=sys.stderr)
90
  return
91
 
92
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
93
+ env = CodeReviewEnv()
 
 
94
  tasks = ["easy", "medium", "hard", "harder", "hardest"]
95
 
96
  for task in tasks:
 
99
  history: List[str] = []
100
  done = False
101
  step = 0
102
+ final_reward = 0.0
 
103
 
104
+ # START marker
105
+ print(f"[START] task={task}", flush=True)
106
 
107
  while not done and step < MAX_STEPS:
108
  step += 1
 
120
  )
121
  response_text = completion.choices[0].message.content or ""
122
  except Exception as exc:
123
+ # Print error to stderr only
124
  print(f"Request failed: {exc}. Using fallback.", file=sys.stderr)
125
  response_text = FALLBACK_ACTION
126
 
127
  action = parse_model_action(response_text)
128
  obs, reward, done, info = env.step(action)
129
+ final_reward = reward.value
 
130
 
131
+ # STEP marker
132
+ print(f"[STEP] step={step} reward={final_reward:.3f}", flush=True)
133
 
134
  history.append(f"Step {step}: {action.action_type}")
135
 
136
+ # END marker
137
+ print(f"[END] task={task} score={final_reward:.3f} steps={step}", flush=True)
 
 
 
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
+ main()