databoysu committed on
Commit
b65a477
·
1 Parent(s): f814100

improving submit

Browse files
Files changed (1) hide show
  1. inference.py +85 -63
inference.py CHANGED
@@ -24,6 +24,7 @@ from pathlib import Path
24
  from typing import Any, Optional
25
 
26
  from openai import OpenAI
 
27
 
28
  try:
29
  from tracefix_rl import CodeAction, TraceFixRLEnv
@@ -36,7 +37,7 @@ except Exception:
36
 
37
 
38
  API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:1234/v1")
39
- MODEL_NAME = os.getenv("MODEL_NAME", "nvidia/nemotron-3-nano-4b")
40
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "lm-studio"
41
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
42
 
@@ -47,26 +48,52 @@ MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
47
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
48
 
49
  SYSTEM_PROMPT = """\
50
- You are a debugging policy agent. Output exactly one CodeAction JSON object per turn.
51
-
52
- Use Action Trajectory on every turn. If an action repeats without progress, change strategy.
53
- PARSE_ERROR means your previous output was invalid; fix formatting immediately.
54
-
55
- Mandatory thought format (exactly 3 sentences):
56
- Observation: what you see in localized_context or last_execution_output.
57
- Diagnosis: root cause and exact line(s) to change.
58
- Plan: the next action_type and why.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  Action policy:
61
- - VIEW_CODE to inspect full line mapping.
62
- - RUN_TESTS to get fresh traceback evidence.
63
- - REPLACE_LINES for focused fixes using exact code_dict keys.
64
- - UNDO_EDIT if the latest edit made things worse.
65
- - RESET_TO_ORIGINAL as last-resort recovery.
66
- - SUBMIT ONLY when last_execution_output explicitly contains the success signal that all tests passed.
67
-
68
- Return only JSON keys: thought, action_type, start_line, end_line, new_code_block.
69
- No markdown. No extra keys.
 
 
 
 
 
 
70
  """
71
 
72
 
@@ -88,46 +115,15 @@ def _decode_action_json(raw_text: str) -> dict[str, Any]:
88
  return json.loads(stripped)
89
 
90
 
91
- def _coerce_legacy_action_payload(payload: dict[str, Any]) -> dict[str, Any]:
92
- """
93
- Normalize common legacy output shapes into strict CodeAction fields.
94
-
95
- This keeps runtime resilient across weaker model backends while still
96
- validating the final payload with strict Pydantic rules.
97
- """
98
- normalized = dict(payload)
99
-
100
- if "action_type" not in normalized and isinstance(normalized.get("type"), str):
101
- normalized["action_type"] = normalized["type"]
102
-
103
- if "thought" not in normalized or normalized.get("thought") in (None, ""):
104
- normalized["thought"] = (
105
- "Recovered malformed action payload and mapped legacy fields "
106
- "to strict CodeAction schema."
107
- )
108
-
109
- if "lines" in normalized and isinstance(normalized["lines"], list):
110
- line_items = []
111
- for item in normalized["lines"]:
112
- if not isinstance(item, dict):
113
- continue
114
- line_no = item.get("line")
115
- code_text = item.get("code")
116
- if isinstance(line_no, int) and isinstance(code_text, str):
117
- line_items.append((line_no, code_text))
118
- if line_items:
119
- line_items.sort(key=lambda x: x[0])
120
- if "start_line" not in normalized:
121
- normalized["start_line"] = line_items[0][0]
122
- if "end_line" not in normalized:
123
- normalized["end_line"] = line_items[-1][0]
124
- if "new_code_block" not in normalized:
125
- normalized["new_code_block"] = "\n".join(code for _, code in line_items)
126
-
127
- normalized.pop("type", None)
128
- normalized.pop("lines", None)
129
- normalized.pop("source", None)
130
- return normalized
131
 
132
 
133
  def log_start(task: str, env: str, model: str) -> None:
@@ -199,7 +195,12 @@ def _get_model_action(
199
  raw_response=raw_response,
200
  )
201
 
202
- action = CodeAction.model_validate(parsed)
 
 
 
 
 
203
  assistant_json = action.model_dump_json(exclude_none=False)
204
  return action, assistant_json
205
  except Exception as parse_exc:
@@ -212,10 +213,14 @@ def _get_model_action(
212
  )
213
  raw_text = (completion.choices[0].message.content or "").strip()
214
  parsed_dict = _decode_action_json(raw_text)
215
- parsed_dict = _coerce_legacy_action_payload(parsed_dict)
216
- action = CodeAction.model_validate(parsed_dict)
 
 
217
  assistant_json = action.model_dump_json(exclude_none=False)
218
  return action, assistant_json
 
 
219
  except Exception as fallback_exc:
220
  raise ModelParseError(
221
  (
@@ -257,6 +262,7 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
257
  kill_switch_triggered = False
258
  last_action_type: Optional[str] = None
259
  consecutive_same_action_count = 0
 
260
 
261
  try:
262
  if LOCAL_IMAGE_NAME:
@@ -293,7 +299,8 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
293
  {
294
  "role": "user",
295
  "content": (
296
- "Pick the single best next action and return only a CodeAction JSON object.\n\n"
 
297
  f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n\n"
298
  f"{obs_text}"
299
  ),
@@ -301,12 +308,14 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
301
  )
302
  try:
303
  action, assistant_json = _get_model_action(client, history_messages)
 
304
  history_messages.append({"role": "assistant", "content": assistant_json})
305
  if show_thought:
306
  _print_thought(action, assistant_json)
307
  except ModelParseError as exc:
308
  cause = str(exc).replace("\n", " ")
309
  parse_error_note = cause
 
310
  raw_response = (exc.raw_response or "").strip()
311
  if raw_response:
312
  history_messages.append({"role": "assistant", "content": raw_response})
@@ -321,6 +330,16 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
321
  }
322
  )
323
  history.append(f"PARSE_ERROR: {cause}")
 
 
 
 
 
 
 
 
 
 
324
  action = CodeAction(
325
  action_type="RUN_TESTS",
326
  thought=(
@@ -329,6 +348,9 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
329
  ),
330
  )
331
 
 
 
 
332
  current_action_type = action.action_type
333
  if current_action_type == last_action_type:
334
  consecutive_same_action_count += 1
 
24
  from typing import Any, Optional
25
 
26
  from openai import OpenAI
27
+ from pydantic import ValidationError
28
 
29
  try:
30
  from tracefix_rl import CodeAction, TraceFixRLEnv
 
37
 
38
 
39
  API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:1234/v1")
40
+ MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
41
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "lm-studio"
42
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
43
 
 
48
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
49
 
50
  SYSTEM_PROMPT = """\
51
+ You are a deterministic debugging policy agent.
52
+ You must output exactly one valid CodeAction JSON object per turn and nothing else.
53
+
54
+ Primary failures to avoid:
55
+ 1) Invalid JSON or wrong field types.
56
+ 2) Misreading last_execution_output and submitting before tests are truly passing.
57
+
58
+ Output contract (strict):
59
+ - Return a single JSON object, not an array.
60
+ - Allowed keys only: thought, action_type, start_line, end_line, new_code_block.
61
+ - No markdown, no code fences, no commentary outside JSON, no extra keys.
62
+ - thought must be a plain string.
63
+ - action_type must be one of: VIEW_CODE, RUN_TESTS, REPLACE_LINES, UNDO_EDIT, RESET_TO_ORIGINAL, SUBMIT.
64
+ - start_line and end_line must be integer or null.
65
+ - new_code_block must be string or null.
66
+ - If action_type is not REPLACE_LINES, set start_line=null, end_line=null, new_code_block=null.
67
+ - If action_type is REPLACE_LINES, set start_line and end_line to exact integer keys from code_dict and provide new_code_block as replacement code only.
68
+
69
+ Mandatory thought format:
70
+ Observation: summarize concrete evidence from localized_context and/or last_execution_output.
71
+ Diagnosis: identify the most likely root cause and exact line numbers to edit when applicable.
72
+ Plan: choose the next action_type and justify it briefly.
73
+
74
+ How to read last_execution_output correctly:
75
+ - Prefer traceback and assertion text over assumptions.
76
+ - Extract failing test name, exception type, file path, and line number when present.
77
+ - If output is truncated or ambiguous, run RUN_TESTS before editing.
78
+ - Treat syntax errors as highest priority and fix them before semantic issues.
79
+ - Never claim success unless output clearly indicates complete pass status.
80
 
81
  Action policy:
82
+ - VIEW_CODE when line mapping or surrounding context is insufficient.
83
+ - RUN_TESTS to collect fresh evidence after edits or when uncertain.
84
+ - REPLACE_LINES for minimal, line-accurate fixes tied to observed failures.
85
+ - UNDO_EDIT if latest change worsened results or introduced new failures.
86
+ - RESET_TO_ORIGINAL only as last-resort recovery.
87
+ - SUBMIT only when last_execution_output explicitly and unambiguously indicates all tests passed.
88
+
89
+ Submit gate (hard rule):
90
+ - If any failure, error, traceback, xfailed/unfinished signal, or uncertainty remains, do not SUBMIT.
91
+
92
+ Self-check before finalizing response:
93
+ - Is this valid JSON?
94
+ - Are all values schema-valid primitive types?
95
+ - Are nulls set correctly for non-REPLACE_LINES actions?
96
+ - Does the thought have exactly 3 sentences in the required Observation/Diagnosis/Plan structure?
97
  """
98
 
99
 
 
115
  return json.loads(stripped)
116
 
117
 
118
+ def _clean_validation_error(exc: ValidationError) -> str:
119
+ """Return a concise, user-facing schema violation summary."""
120
+ first_error = exc.errors()[0] if exc.errors() else {}
121
+ loc = first_error.get("loc", ["Unknown"])
122
+ field_name = loc[0] if isinstance(loc, (list, tuple)) and loc else "Unknown"
123
+ return (
124
+ f"JSON Schema Violation on field '{field_name}': Must be a flat string/integer. "
125
+ "Do not use nested objects or arrays."
126
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
 
129
  def log_start(task: str, env: str, model: str) -> None:
 
195
  raw_response=raw_response,
196
  )
197
 
198
+ try:
199
+ action = CodeAction.model_validate(parsed)
200
+ except ValidationError as exc:
201
+ content = getattr(message, "content", "")
202
+ raw_response = content if isinstance(content, str) else json.dumps(content, ensure_ascii=True, default=str)
203
+ raise ModelParseError(_clean_validation_error(exc), raw_response=raw_response) from exc
204
  assistant_json = action.model_dump_json(exclude_none=False)
205
  return action, assistant_json
206
  except Exception as parse_exc:
 
213
  )
214
  raw_text = (completion.choices[0].message.content or "").strip()
215
  parsed_dict = _decode_action_json(raw_text)
216
+ try:
217
+ action = CodeAction.model_validate(parsed_dict)
218
+ except ValidationError as exc:
219
+ raise ModelParseError(_clean_validation_error(exc), raw_response=raw_text) from exc
220
  assistant_json = action.model_dump_json(exclude_none=False)
221
  return action, assistant_json
222
+ except ModelParseError:
223
+ raise
224
  except Exception as fallback_exc:
225
  raise ModelParseError(
226
  (
 
262
  kill_switch_triggered = False
263
  last_action_type: Optional[str] = None
264
  consecutive_same_action_count = 0
265
+ consecutive_parse_error_count = 0
266
 
267
  try:
268
  if LOCAL_IMAGE_NAME:
 
299
  {
300
  "role": "user",
301
  "content": (
302
+ "Pick the single best next action and return only one valid CodeAction JSON object. "
303
+ "Use localized_context/last_execution_output as evidence, and do not SUBMIT unless all tests explicitly pass.\n\n"
304
  f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n\n"
305
  f"{obs_text}"
306
  ),
 
308
  )
309
  try:
310
  action, assistant_json = _get_model_action(client, history_messages)
311
+ consecutive_parse_error_count = 0
312
  history_messages.append({"role": "assistant", "content": assistant_json})
313
  if show_thought:
314
  _print_thought(action, assistant_json)
315
  except ModelParseError as exc:
316
  cause = str(exc).replace("\n", " ")
317
  parse_error_note = cause
318
+ consecutive_parse_error_count += 1
319
  raw_response = (exc.raw_response or "").strip()
320
  if raw_response:
321
  history_messages.append({"role": "assistant", "content": raw_response})
 
330
  }
331
  )
332
  history.append(f"PARSE_ERROR: {cause}")
333
+ if consecutive_parse_error_count >= 3:
334
+ kill_switch_triggered = True
335
+ history.append(
336
+ "KILL_SWITCH: PARSE_ERROR occurred 3 times consecutively. "
337
+ "Terminating episode early to prevent token burn."
338
+ )
339
+ steps_taken = step
340
+ success = False
341
+ score = 0.0
342
+ break
343
  action = CodeAction(
344
  action_type="RUN_TESTS",
345
  thought=(
 
348
  ),
349
  )
350
 
351
+ if kill_switch_triggered:
352
+ break
353
+
354
  current_action_type = action.action_type
355
  if current_action_type == last_action_type:
356
  consecutive_same_action_count += 1