databoysu committed on
Commit
9a026d7
·
1 Parent(s): f341e79

Thought enabling

Browse files
Files changed (3) hide show
  1. __pycache__/models.cpython-312.pyc +0 -0
  2. inference.py +7 -21
  3. models.py +34 -56
__pycache__/models.cpython-312.pyc CHANGED
Binary files a/__pycache__/models.cpython-312.pyc and b/__pycache__/models.cpython-312.pyc differ
 
inference.py CHANGED
@@ -19,7 +19,6 @@ import argparse
19
  import asyncio
20
  import json
21
  import os
22
- import re
23
  import sys
24
  from pathlib import Path
25
  from typing import Any, Optional
@@ -46,7 +45,6 @@ TASK_NAME = os.getenv("TASK_NAME", "tracefix_rl")
46
  BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
47
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
48
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
49
- THINKING_TOKEN_LIMIT = int(os.getenv("THINKING_TOKEN_LIMIT", "1000"))
50
  MAX_PARSE_RETRIES = 3
51
 
52
  SYSTEM_PROMPT = (
@@ -111,26 +109,15 @@ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> No
111
 
112
  def _extract_json(text: str) -> dict[str, Any]:
113
  stripped = text.strip()
 
 
 
 
 
114
  try:
115
  return json.loads(stripped)
116
- except json.JSONDecodeError:
117
- pass
118
-
119
- fence = re.search(r"```(?:json)?\s*({.*?})\s*```", stripped, re.DOTALL)
120
- if fence:
121
- try:
122
- return json.loads(fence.group(1))
123
- except json.JSONDecodeError:
124
- pass
125
-
126
- block = re.search(r"({.*?})", stripped, re.DOTALL)
127
- if block:
128
- try:
129
- return json.loads(block.group(1))
130
- except json.JSONDecodeError:
131
- pass
132
-
133
- raise ValueError("Invalid JSON response.")
134
 
135
 
136
  def _build_observation_text(observation: Any) -> str:
@@ -169,7 +156,6 @@ def _get_model_response(
169
  {"role": "user", "content": user_prompt},
170
  ],
171
  "temperature": 0.0,
172
- "max_tokens": THINKING_TOKEN_LIMIT,
173
  "stream": False,
174
  }
175
  try:
 
19
  import asyncio
20
  import json
21
  import os
 
22
  import sys
23
  from pathlib import Path
24
  from typing import Any, Optional
 
45
  BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
46
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
47
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
 
48
  MAX_PARSE_RETRIES = 3
49
 
50
  SYSTEM_PROMPT = (
 
109
 
110
  def _extract_json(text: str) -> dict[str, Any]:
111
  stripped = text.strip()
112
+ if stripped.startswith("```") and stripped.endswith("```"):
113
+ first_newline = stripped.find("\n")
114
+ if first_newline == -1:
115
+ raise ValueError("Invalid JSON response.")
116
+ stripped = stripped[first_newline + 1 : -3].strip()
117
  try:
118
  return json.loads(stripped)
119
+ except json.JSONDecodeError as exc:
120
+ raise ValueError("Invalid JSON response.") from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
 
123
  def _build_observation_text(observation: Any) -> str:
 
156
  {"role": "user", "content": user_prompt},
157
  ],
158
  "temperature": 0.0,
 
159
  "stream": False,
160
  }
161
  try:
models.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
  from typing import Any, Dict, List, Literal, Optional
6
 
7
  from openenv.core.env_server.types import Action, Observation
8
- from pydantic import BaseModel, Field, model_validator
9
 
10
 
11
  ActionType = Literal[
@@ -21,67 +21,45 @@ ActionType = Literal[
21
  class CodeAction(Action):
22
  """Structured action consumed by the environment."""
23
 
 
 
24
  thought: str = Field(
25
  ...,
26
- description="Mandatory reasoning string before selecting an action.",
 
 
 
 
27
  )
28
  action_type: ActionType = Field(
29
  ...,
30
- description="One of VIEW_CODE, RUN_TESTS, REPLACE_LINES, UNDO_EDIT, RESET_TO_ORIGINAL, SUBMIT.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
- start_line: Optional[int] = Field(default=None)
33
- end_line: Optional[int] = Field(default=None)
34
- new_code_block: Optional[str] = Field(default=None)
35
-
36
- @model_validator(mode="before")
37
- @classmethod
38
- def validate_and_normalize(cls, data: Any) -> Any:
39
- if not isinstance(data, dict):
40
- return data
41
-
42
- action_type = data.get("action_type")
43
-
44
- def _coerce_optional_int(value: Any) -> Optional[int]:
45
- if value is None:
46
- return None
47
- if isinstance(value, int):
48
- return value
49
- if isinstance(value, str):
50
- raw = value.strip()
51
- if raw == "":
52
- return None
53
- try:
54
- return int(raw)
55
- except ValueError:
56
- return None
57
- return None
58
-
59
- data = dict(data)
60
- data["start_line"] = _coerce_optional_int(data.get("start_line"))
61
- data["end_line"] = _coerce_optional_int(data.get("end_line"))
62
-
63
- if action_type == "REPLACE_LINES":
64
- start_line = data.get("start_line")
65
- end_line = data.get("end_line")
66
- new_code_block = data.get("new_code_block")
67
-
68
- if start_line is None:
69
- raise ValueError("REPLACE_LINES requires start_line.")
70
- if end_line is None:
71
- raise ValueError("REPLACE_LINES requires end_line.")
72
- if new_code_block is None:
73
- raise ValueError("REPLACE_LINES requires new_code_block.")
74
- if start_line < 1 or end_line < 1:
75
- raise ValueError("REPLACE_LINES requires start_line and end_line >= 1.")
76
- if start_line > end_line:
77
- raise ValueError("REPLACE_LINES requires start_line <= end_line.")
78
- else:
79
- # Web UI often sends default line fields for non-edit actions.
80
- data["start_line"] = None
81
- data["end_line"] = None
82
- data["new_code_block"] = None
83
-
84
- return data
85
 
86
 
87
  class TestResult(BaseModel):
 
5
  from typing import Any, Dict, List, Literal, Optional
6
 
7
  from openenv.core.env_server.types import Action, Observation
8
+ from pydantic import BaseModel, ConfigDict, Field
9
 
10
 
11
  ActionType = Literal[
 
21
  class CodeAction(Action):
22
  """Structured action consumed by the environment."""
23
 
24
+ model_config = ConfigDict(strict=True)
25
+
26
  thought: str = Field(
27
  ...,
28
+ description=(
29
+ "MANDATORY. Analyze the localized_context and last_execution_output. "
30
+ "If tests failed, identify the error line and root cause. Explicitly plan "
31
+ "your next action before executing it."
32
+ ),
33
  )
34
  action_type: ActionType = Field(
35
  ...,
36
+ description=(
37
+ "The specific tool to use. VIEW_CODE to read. RUN_TESTS to execute and get "
38
+ "tracebacks. REPLACE_LINES to apply a fix. UNDO_EDIT to revert your last "
39
+ "change if it failed. SUBMIT only when all tests pass."
40
+ ),
41
+ )
42
+ start_line: Optional[int] = Field(
43
+ default=None,
44
+ description=(
45
+ "The inclusive start line number for REPLACE_LINES. You MUST use the exact "
46
+ "integer keys provided in the code_dict observation."
47
+ ),
48
+ )
49
+ end_line: Optional[int] = Field(
50
+ default=None,
51
+ description=(
52
+ "The inclusive end line number for REPLACE_LINES. You MUST use the exact "
53
+ "integer keys provided in the code_dict observation."
54
+ ),
55
+ )
56
+ new_code_block: Optional[str] = Field(
57
+ default=None,
58
+ description=(
59
+ "The exact replacement Python code. Must be properly indented to match the "
60
+ "surrounding code. Do not include markdown formatting or backticks."
61
+ ),
62
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  class TestResult(BaseModel):