databoysu committed
Commit a1e4e94 · 1 Parent(s): b4f37fd

fix reward hacking

Files changed (1)
  1. inference.py +98 -38
inference.py CHANGED
@@ -46,7 +46,7 @@ TASK_NAME = os.getenv("TASK_NAME", "tracefix_rl")
 BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
 MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
 SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
-THINKING_TOKEN_LIMIT = int(os.getenv("THINKING_TOKEN_LIMIT", "512"))
+THINKING_TOKEN_LIMIT = int(os.getenv("THINKING_TOKEN_LIMIT", "1000"))
 MAX_PARSE_RETRIES = 3

 SYSTEM_PROMPT = (
@@ -54,9 +54,40 @@ SYSTEM_PROMPT = (
     "Return only JSON for one action.\n"
     'Allowed action_type values: VIEW_CODE, RUN_TESTS, REPLACE_LINES, UNDO_EDIT, RESET_TO_ORIGINAL, SUBMIT.\n'
     "For REPLACE_LINES include start_line, end_line, new_code_block.\n"
+    "If available, include a 'thought' field explaining what you observed and why this is the next best action.\n"
     "Prefer RUN_TESTS after edits and SUBMIT only when all tests pass."
 )

+ACTION_JSON_SCHEMA: dict[str, Any] = {
+    "type": "json_schema",
+    "json_schema": {
+        "name": "CodeAction",
+        "strict": True,
+        "schema": {
+            "type": "object",
+            "properties": {
+                "thought": {"type": ["string", "null"]},
+                "action_type": {
+                    "type": "string",
+                    "enum": [
+                        "VIEW_CODE",
+                        "RUN_TESTS",
+                        "REPLACE_LINES",
+                        "UNDO_EDIT",
+                        "RESET_TO_ORIGINAL",
+                        "SUBMIT",
+                    ],
+                },
+                "start_line": {"type": ["integer", "null"]},
+                "end_line": {"type": ["integer", "null"]},
+                "new_code_block": {"type": ["string", "null"]},
+            },
+            "required": ["action_type"],
+            "additionalProperties": False,
+        },
+    },
+}
+

 def log_start(task: str, env: str, model: str) -> None:
     print(f"[START] task={task} env={env} model={model}", flush=True)
@@ -114,26 +145,36 @@ def _build_observation_text(observation: Any) -> str:
     )


-def _get_model_action(
+def _get_model_response(
     client: OpenAI, observation: Any, history: list[str]
-) -> tuple[dict[str, Any], str]:
+) -> str:
     obs_text = _build_observation_text(observation)
     user_prompt = (
         "Pick the single best next action and return only JSON.\n\n"
         f"{obs_text}\n\n"
         f"history:\n{chr(10).join(history[-5:]) if history else 'none'}"
     )
-    completion = client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=[
+    request_kwargs = {
+        "model": MODEL_NAME,
+        "messages": [
             {"role": "system", "content": SYSTEM_PROMPT},
             {"role": "user", "content": user_prompt},
         ],
-        temperature=0.0,
-        max_tokens=THINKING_TOKEN_LIMIT,
-        stream=False,
-    )
-    response_text = (completion.choices[0].message.content or "").strip()
+        "temperature": 0.0,
+        "max_tokens": THINKING_TOKEN_LIMIT,
+        "stream": False,
+    }
+    try:
+        completion = client.chat.completions.create(
+            **request_kwargs,
+            response_format=ACTION_JSON_SCHEMA,
+        )
+    except Exception:
+        completion = client.chat.completions.create(**request_kwargs)
+    return (completion.choices[0].message.content or "").strip()
+
+
+def _parse_model_action(response_text: str) -> dict[str, Any]:
     action = _extract_json(response_text)

     if action.get("action_type") not in {
@@ -146,7 +187,7 @@ def _get_model_action(
     }:
         raise ValueError("Invalid action_type in model response.")

-    return action, response_text
+    return action


 def _to_code_action(action_dict: dict[str, Any]) -> CodeAction:
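_parse_model_action still delegates raw parsing to _extract_json, which this commit leaves unchanged and the diff does not show. For orientation only, a tolerant implementation along the following lines would satisfy the call site; the regex-based approach is an assumption, not the repository's actual code:

import json
import re
from typing import Any


def _extract_json(text: str) -> dict[str, Any]:
    # Take the outermost {...} span so surrounding prose or markdown
    # fences in the model response do not break json.loads.
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in model response.")
    return json.loads(match.group(0))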
@@ -160,6 +201,13 @@ def _to_code_action(action_dict: dict[str, Any]) -> CodeAction:
     return CodeAction(**payload)


+def _print_thought(action_dict: dict[str, Any], raw_response: str) -> None:
+    thought = action_dict.get("thought")
+    thought_text = thought.strip() if isinstance(thought, str) else ""
+    print("[THOUGHT]", file=sys.stderr, flush=True)
+    print(thought_text if thought_text else raw_response, file=sys.stderr, flush=True)
+
+
 def _compute_score(step_result: Any, rewards: list[float]) -> float:
     meta = step_result.observation.metadata or {}
     raw = meta.get("final_score")
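CodeAction itself is imported from the task environment package and is never defined in this file. Inferred purely from its usage in this diff, its shape resembles the stand-in below; the dataclass form and the optional defaults are assumptions:

from dataclasses import dataclass
from typing import Optional


@dataclass
class CodeAction:  # illustrative stand-in, not the environment's class
    action_type: str
    thought: Optional[str] = None
    start_line: Optional[int] = None
    end_line: Optional[int] = None
    new_code_block: Optional[str] = None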
@@ -203,31 +251,40 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> None:

         action: Optional[CodeAction] = None
         model_response = ""
-
-        for attempt in range(1, MAX_PARSE_RETRIES + 1):
-            try:
-                action_dict, model_response = _get_model_action(client, result.observation, history)
-                action = _to_code_action(action_dict)
-                if show_thought:
-                    history.append(f"thought={action.thought}")
-                break
-            except Exception as exc:
-                cause = str(exc).replace("\n", " ")
-                history.append(
-                    (
-                        f"parse_failure attempt={attempt} cause={cause}. "
-                        "Error: Invalid JSON or schema. Return a complete valid JSON object "
-                        "with fields: thought, action_type, start_line, end_line, new_code_block."
-                    )
-                )
-                if model_response:
-                    history.append(f"raw_response={model_response[:500]}")
-
-        if action is None:
+        if step == 1:
             action = CodeAction(
-                action_type="RUN_TESTS",
-                thought="Fallback after repeated invalid JSON/schema responses.",
+                action_type="VIEW_CODE",
+                thought="First step policy: inspect source before testing or editing.",
             )
+            if show_thought:
+                print("[THOUGHT]", file=sys.stderr, flush=True)
+                print(action.thought, file=sys.stderr, flush=True)
+        else:
+            for attempt in range(1, MAX_PARSE_RETRIES + 1):
+                try:
+                    model_response = _get_model_response(client, result.observation, history)
+                    action_dict = _parse_model_action(model_response)
+                    if show_thought:
+                        _print_thought(action_dict, model_response)
+                    action = _to_code_action(action_dict)
+                    break
+                except Exception as exc:
+                    cause = str(exc).replace("\n", " ")
+                    history.append(
+                        (
+                            f"parse_failure attempt={attempt} cause={cause}. "
+                            "Error: Invalid JSON or schema. Return a complete valid JSON object "
+                            "with fields: thought, action_type, start_line, end_line, new_code_block."
+                        )
+                    )
+                    if model_response:
+                        history.append(f"raw_response={model_response[:500]}")
+
+            if action is None:
+                action = CodeAction(
+                    action_type="RUN_TESTS",
+                    thought="Fallback after repeated invalid JSON/schema responses.",
+                )

         result = await env.step(action)
@@ -242,7 +299,11 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> None:

         rewards.append(reward)
         steps_taken = step
-        history.append(f"step={step} action={action_str} reward={reward:.2f} error={error or 'null'}")
+        action_thought = (action.thought or "").strip()
+        history.append(
+            f"Action {action_str}; reward {reward:.2f}; error {error or 'null'}."
+            + (f" Thought: {action_thought}" if action_thought else "")
+        )
         log_step(step=step, action=action_str, reward=reward, done=done, error=error)

         if done:
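For reference, the rewritten history entry renders as a single sentence per step; the values below are invented for illustration:

action_str, reward, error = "RUN_TESTS", 0.40, None
action_thought = "Edits applied; re-run tests before submitting."
entry = (
    f"Action {action_str}; reward {reward:.2f}; error {error or 'null'}."
    + (f" Thought: {action_thought}" if action_thought else "")
)
print(entry)
# -> Action RUN_TESTS; reward 0.40; error null. Thought: Edits applied; re-run tests before submitting.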
@@ -255,7 +316,6 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> None:
         if not started:
             log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
             started = True
-        msg = str(exc).replace("\n", " ")
         score = 0.0
         success = False
     finally:
@@ -273,7 +333,7 @@ if __name__ == "__main__":
     group.add_argument("--easy", action="store_true", help="Run on easy curriculum tier.")
     group.add_argument("--medium", action="store_true", help="Run on medium curriculum tier.")
     group.add_argument("--hard", action="store_true", help="Run on hard curriculum tier.")
-    parser.add_argument("--thought", action="store_true", help="Include model thought traces in internal history.")
+    parser.add_argument("--thought", action="store_true", help="Print LLM thought trace to stderr only.")
     args = parser.parse_args()

     difficulty: Optional[str] = None
 
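With the updated --thought flag, thought traces go only to stderr, so stdout's [START]/[STEP] log stays machine-readable. An example invocation; the tier choice and log path are arbitrary:

python inference.py --easy --thought 2> thoughts.log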