Spaces:
Sleeping
Sleeping
databoysu commited on
Commit ·
b65a477
1
Parent(s): f814100
improving submit
Browse files- inference.py +85 -63
inference.py
CHANGED
|
@@ -24,6 +24,7 @@ from pathlib import Path
|
|
| 24 |
from typing import Any, Optional
|
| 25 |
|
| 26 |
from openai import OpenAI
|
|
|
|
| 27 |
|
| 28 |
try:
|
| 29 |
from tracefix_rl import CodeAction, TraceFixRLEnv
|
|
@@ -36,7 +37,7 @@ except Exception:
|
|
| 36 |
|
| 37 |
|
| 38 |
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:1234/v1")
|
| 39 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "
|
| 40 |
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "lm-studio"
|
| 41 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 42 |
|
|
@@ -47,26 +48,52 @@ MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
|
| 47 |
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
|
| 48 |
|
| 49 |
SYSTEM_PROMPT = """\
|
| 50 |
-
You are a debugging policy agent.
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
Action policy:
|
| 61 |
-
- VIEW_CODE
|
| 62 |
-
- RUN_TESTS to
|
| 63 |
-
- REPLACE_LINES for
|
| 64 |
-
- UNDO_EDIT if
|
| 65 |
-
- RESET_TO_ORIGINAL as last-resort recovery.
|
| 66 |
-
- SUBMIT
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
"""
|
| 71 |
|
| 72 |
|
|
@@ -88,46 +115,15 @@ def _decode_action_json(raw_text: str) -> dict[str, Any]:
|
|
| 88 |
return json.loads(stripped)
|
| 89 |
|
| 90 |
|
| 91 |
-
def
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
if "action_type" not in normalized and isinstance(normalized.get("type"), str):
|
| 101 |
-
normalized["action_type"] = normalized["type"]
|
| 102 |
-
|
| 103 |
-
if "thought" not in normalized or normalized.get("thought") in (None, ""):
|
| 104 |
-
normalized["thought"] = (
|
| 105 |
-
"Recovered malformed action payload and mapped legacy fields "
|
| 106 |
-
"to strict CodeAction schema."
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
if "lines" in normalized and isinstance(normalized["lines"], list):
|
| 110 |
-
line_items = []
|
| 111 |
-
for item in normalized["lines"]:
|
| 112 |
-
if not isinstance(item, dict):
|
| 113 |
-
continue
|
| 114 |
-
line_no = item.get("line")
|
| 115 |
-
code_text = item.get("code")
|
| 116 |
-
if isinstance(line_no, int) and isinstance(code_text, str):
|
| 117 |
-
line_items.append((line_no, code_text))
|
| 118 |
-
if line_items:
|
| 119 |
-
line_items.sort(key=lambda x: x[0])
|
| 120 |
-
if "start_line" not in normalized:
|
| 121 |
-
normalized["start_line"] = line_items[0][0]
|
| 122 |
-
if "end_line" not in normalized:
|
| 123 |
-
normalized["end_line"] = line_items[-1][0]
|
| 124 |
-
if "new_code_block" not in normalized:
|
| 125 |
-
normalized["new_code_block"] = "\n".join(code for _, code in line_items)
|
| 126 |
-
|
| 127 |
-
normalized.pop("type", None)
|
| 128 |
-
normalized.pop("lines", None)
|
| 129 |
-
normalized.pop("source", None)
|
| 130 |
-
return normalized
|
| 131 |
|
| 132 |
|
| 133 |
def log_start(task: str, env: str, model: str) -> None:
|
|
@@ -199,7 +195,12 @@ def _get_model_action(
|
|
| 199 |
raw_response=raw_response,
|
| 200 |
)
|
| 201 |
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
assistant_json = action.model_dump_json(exclude_none=False)
|
| 204 |
return action, assistant_json
|
| 205 |
except Exception as parse_exc:
|
|
@@ -212,10 +213,14 @@ def _get_model_action(
|
|
| 212 |
)
|
| 213 |
raw_text = (completion.choices[0].message.content or "").strip()
|
| 214 |
parsed_dict = _decode_action_json(raw_text)
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
| 217 |
assistant_json = action.model_dump_json(exclude_none=False)
|
| 218 |
return action, assistant_json
|
|
|
|
|
|
|
| 219 |
except Exception as fallback_exc:
|
| 220 |
raise ModelParseError(
|
| 221 |
(
|
|
@@ -257,6 +262,7 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 257 |
kill_switch_triggered = False
|
| 258 |
last_action_type: Optional[str] = None
|
| 259 |
consecutive_same_action_count = 0
|
|
|
|
| 260 |
|
| 261 |
try:
|
| 262 |
if LOCAL_IMAGE_NAME:
|
|
@@ -293,7 +299,8 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 293 |
{
|
| 294 |
"role": "user",
|
| 295 |
"content": (
|
| 296 |
-
"Pick the single best next action and return only
|
|
|
|
| 297 |
f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n\n"
|
| 298 |
f"{obs_text}"
|
| 299 |
),
|
|
@@ -301,12 +308,14 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 301 |
)
|
| 302 |
try:
|
| 303 |
action, assistant_json = _get_model_action(client, history_messages)
|
|
|
|
| 304 |
history_messages.append({"role": "assistant", "content": assistant_json})
|
| 305 |
if show_thought:
|
| 306 |
_print_thought(action, assistant_json)
|
| 307 |
except ModelParseError as exc:
|
| 308 |
cause = str(exc).replace("\n", " ")
|
| 309 |
parse_error_note = cause
|
|
|
|
| 310 |
raw_response = (exc.raw_response or "").strip()
|
| 311 |
if raw_response:
|
| 312 |
history_messages.append({"role": "assistant", "content": raw_response})
|
|
@@ -321,6 +330,16 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 321 |
}
|
| 322 |
)
|
| 323 |
history.append(f"PARSE_ERROR: {cause}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
action = CodeAction(
|
| 325 |
action_type="RUN_TESTS",
|
| 326 |
thought=(
|
|
@@ -329,6 +348,9 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 329 |
),
|
| 330 |
)
|
| 331 |
|
|
|
|
|
|
|
|
|
|
| 332 |
current_action_type = action.action_type
|
| 333 |
if current_action_type == last_action_type:
|
| 334 |
consecutive_same_action_count += 1
|
|
|
|
| 24 |
from typing import Any, Optional
|
| 25 |
|
| 26 |
from openai import OpenAI
|
| 27 |
+
from pydantic import ValidationError
|
| 28 |
|
| 29 |
try:
|
| 30 |
from tracefix_rl import CodeAction, TraceFixRLEnv
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:1234/v1")
|
| 40 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
|
| 41 |
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "lm-studio"
|
| 42 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 43 |
|
|
|
|
| 48 |
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
|
| 49 |
|
| 50 |
SYSTEM_PROMPT = """\
|
| 51 |
+
You are a deterministic debugging policy agent.
|
| 52 |
+
You must output exactly one valid CodeAction JSON object per turn and nothing else.
|
| 53 |
+
|
| 54 |
+
Primary failures to avoid:
|
| 55 |
+
1) Invalid JSON or wrong field types.
|
| 56 |
+
2) Misreading last_execution_output and submitting before tests are truly passing.
|
| 57 |
+
|
| 58 |
+
Output contract (strict):
|
| 59 |
+
- Return a single JSON object, not an array.
|
| 60 |
+
- Allowed keys only: thought, action_type, start_line, end_line, new_code_block.
|
| 61 |
+
- No markdown, no code fences, no commentary outside JSON, no extra keys.
|
| 62 |
+
- thought must be a plain string.
|
| 63 |
+
- action_type must be one of: VIEW_CODE, RUN_TESTS, REPLACE_LINES, UNDO_EDIT, RESET_TO_ORIGINAL, SUBMIT.
|
| 64 |
+
- start_line and end_line must be integer or null.
|
| 65 |
+
- new_code_block must be string or null.
|
| 66 |
+
- If action_type is not REPLACE_LINES, set start_line=null, end_line=null, new_code_block=null.
|
| 67 |
+
- If action_type is REPLACE_LINES, set start_line and end_line to exact integer keys from code_dict and provide new_code_block as replacement code only.
|
| 68 |
+
|
| 69 |
+
Mandatory thought format:
|
| 70 |
+
Observation: summarize concrete evidence from localized_context and/or last_execution_output.
|
| 71 |
+
Diagnosis: identify the most likely root cause and exact line numbers to edit when applicable.
|
| 72 |
+
Plan: choose the next action_type and justify it briefly.
|
| 73 |
+
|
| 74 |
+
How to read last_execution_output correctly:
|
| 75 |
+
- Prefer traceback and assertion text over assumptions.
|
| 76 |
+
- Extract failing test name, exception type, file path, and line number when present.
|
| 77 |
+
- If output is truncated or ambiguous, run RUN_TESTS before editing.
|
| 78 |
+
- Treat syntax errors as highest priority and fix them before semantic issues.
|
| 79 |
+
- Never claim success unless output clearly indicates complete pass status.
|
| 80 |
|
| 81 |
Action policy:
|
| 82 |
+
- VIEW_CODE when line mapping or surrounding context is insufficient.
|
| 83 |
+
- RUN_TESTS to collect fresh evidence after edits or when uncertain.
|
| 84 |
+
- REPLACE_LINES for minimal, line-accurate fixes tied to observed failures.
|
| 85 |
+
- UNDO_EDIT if latest change worsened results or introduced new failures.
|
| 86 |
+
- RESET_TO_ORIGINAL only as last-resort recovery.
|
| 87 |
+
- SUBMIT only when last_execution_output explicitly and unambiguously indicates all tests passed.
|
| 88 |
+
|
| 89 |
+
Submit gate (hard rule):
|
| 90 |
+
- If any failure, error, traceback, xfailed/unfinished signal, or uncertainty remains, do not SUBMIT.
|
| 91 |
+
|
| 92 |
+
Self-check before finalizing response:
|
| 93 |
+
- Is this valid JSON?
|
| 94 |
+
- Are all values schema-valid primitive types?
|
| 95 |
+
- Are nulls set correctly for non-REPLACE_LINES actions?
|
| 96 |
+
- Does the thought have exactly 3 sentences in the required Observation/Diagnosis/Plan structure?
|
| 97 |
"""
|
| 98 |
|
| 99 |
|
|
|
|
| 115 |
return json.loads(stripped)
|
| 116 |
|
| 117 |
|
| 118 |
+
def _clean_validation_error(exc: ValidationError) -> str:
|
| 119 |
+
"""Return a concise, user-facing schema violation summary."""
|
| 120 |
+
first_error = exc.errors()[0] if exc.errors() else {}
|
| 121 |
+
loc = first_error.get("loc", ["Unknown"])
|
| 122 |
+
field_name = loc[0] if isinstance(loc, (list, tuple)) and loc else "Unknown"
|
| 123 |
+
return (
|
| 124 |
+
f"JSON Schema Violation on field '{field_name}': Must be a flat string/integer. "
|
| 125 |
+
"Do not use nested objects or arrays."
|
| 126 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def log_start(task: str, env: str, model: str) -> None:
|
|
|
|
| 195 |
raw_response=raw_response,
|
| 196 |
)
|
| 197 |
|
| 198 |
+
try:
|
| 199 |
+
action = CodeAction.model_validate(parsed)
|
| 200 |
+
except ValidationError as exc:
|
| 201 |
+
content = getattr(message, "content", "")
|
| 202 |
+
raw_response = content if isinstance(content, str) else json.dumps(content, ensure_ascii=True, default=str)
|
| 203 |
+
raise ModelParseError(_clean_validation_error(exc), raw_response=raw_response) from exc
|
| 204 |
assistant_json = action.model_dump_json(exclude_none=False)
|
| 205 |
return action, assistant_json
|
| 206 |
except Exception as parse_exc:
|
|
|
|
| 213 |
)
|
| 214 |
raw_text = (completion.choices[0].message.content or "").strip()
|
| 215 |
parsed_dict = _decode_action_json(raw_text)
|
| 216 |
+
try:
|
| 217 |
+
action = CodeAction.model_validate(parsed_dict)
|
| 218 |
+
except ValidationError as exc:
|
| 219 |
+
raise ModelParseError(_clean_validation_error(exc), raw_response=raw_text) from exc
|
| 220 |
assistant_json = action.model_dump_json(exclude_none=False)
|
| 221 |
return action, assistant_json
|
| 222 |
+
except ModelParseError:
|
| 223 |
+
raise
|
| 224 |
except Exception as fallback_exc:
|
| 225 |
raise ModelParseError(
|
| 226 |
(
|
|
|
|
| 262 |
kill_switch_triggered = False
|
| 263 |
last_action_type: Optional[str] = None
|
| 264 |
consecutive_same_action_count = 0
|
| 265 |
+
consecutive_parse_error_count = 0
|
| 266 |
|
| 267 |
try:
|
| 268 |
if LOCAL_IMAGE_NAME:
|
|
|
|
| 299 |
{
|
| 300 |
"role": "user",
|
| 301 |
"content": (
|
| 302 |
+
"Pick the single best next action and return only one valid CodeAction JSON object. "
|
| 303 |
+
"Use localized_context/last_execution_output as evidence, and do not SUBMIT unless all tests explicitly pass.\n\n"
|
| 304 |
f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n\n"
|
| 305 |
f"{obs_text}"
|
| 306 |
),
|
|
|
|
| 308 |
)
|
| 309 |
try:
|
| 310 |
action, assistant_json = _get_model_action(client, history_messages)
|
| 311 |
+
consecutive_parse_error_count = 0
|
| 312 |
history_messages.append({"role": "assistant", "content": assistant_json})
|
| 313 |
if show_thought:
|
| 314 |
_print_thought(action, assistant_json)
|
| 315 |
except ModelParseError as exc:
|
| 316 |
cause = str(exc).replace("\n", " ")
|
| 317 |
parse_error_note = cause
|
| 318 |
+
consecutive_parse_error_count += 1
|
| 319 |
raw_response = (exc.raw_response or "").strip()
|
| 320 |
if raw_response:
|
| 321 |
history_messages.append({"role": "assistant", "content": raw_response})
|
|
|
|
| 330 |
}
|
| 331 |
)
|
| 332 |
history.append(f"PARSE_ERROR: {cause}")
|
| 333 |
+
if consecutive_parse_error_count >= 3:
|
| 334 |
+
kill_switch_triggered = True
|
| 335 |
+
history.append(
|
| 336 |
+
"KILL_SWITCH: PARSE_ERROR occurred 3 times consecutively. "
|
| 337 |
+
"Terminating episode early to prevent token burn."
|
| 338 |
+
)
|
| 339 |
+
steps_taken = step
|
| 340 |
+
success = False
|
| 341 |
+
score = 0.0
|
| 342 |
+
break
|
| 343 |
action = CodeAction(
|
| 344 |
action_type="RUN_TESTS",
|
| 345 |
thought=(
|
|
|
|
| 348 |
),
|
| 349 |
)
|
| 350 |
|
| 351 |
+
if kill_switch_triggered:
|
| 352 |
+
break
|
| 353 |
+
|
| 354 |
current_action_type = action.action_type
|
| 355 |
if current_action_type == last_action_type:
|
| 356 |
consecutive_same_action_count += 1
|