arjeet commited on
Commit
a8dd45c
·
1 Parent(s): 2f02c40

inference update v4

Browse files
Files changed (2) hide show
  1. inference.py +14 -7
  2. server/cust_env_environment.py +91 -50
inference.py CHANGED
@@ -8,6 +8,7 @@ from models import DocAction
8
  IMAGE_NAME = os.getenv("IMAGE_NAME")
9
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
 
 
11
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
12
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
13
  BENCHMARK_NAME = "doc_sweeper"
@@ -24,9 +25,12 @@ def run_inference(task_name: str):
24
  if not hf_token:
25
  raise ValueError("Missing hf_token")
26
 
 
27
  client = OpenAI(
28
  api_key=hf_token,
29
- base_url=api_base_url
 
 
30
  )
31
 
32
  env = DocSweeperEnvironment(task=task_name)
@@ -36,6 +40,7 @@ def run_inference(task_name: str):
36
  total_reward = 0.0
37
  step_count = 0
38
  rewards_history = []
 
39
 
40
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
41
 
@@ -44,11 +49,11 @@ def run_inference(task_name: str):
44
 
45
  YOUR CURRENT TASK: '{task_name}'
46
  - If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
47
- - If 'config_migration': Open docker-compose files. Update version to 3.8 and migrate 'links' to 'networks'.
48
- - If 'broken_links': Find broken relative links and edit them to point to correct paths.
49
 
50
  WORKFLOW RULES:
51
- 1. PLAN FIRST: Use the 'thought' field to track which files you have checked and which remain.
52
  2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
53
  3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
54
  4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
@@ -66,10 +71,9 @@ def run_inference(task_name: str):
66
  """
67
 
68
  messages = [{"role": "system", "content": system_prompt}]
69
-
70
  start_time = time.time()
71
 
72
- while not done:
73
  step_count += 1
74
  current_state_prompt = f"""
75
  [ENVIRONMENT OBSERVATION]
@@ -116,11 +120,14 @@ def run_inference(task_name: str):
116
  obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
117
  rewards_history.append(0.0)
118
 
 
 
 
119
  done_str = str(done).lower()
120
  print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
121
 
122
  final_score = max(0.0, min(1.0, total_reward))
123
- success = final_score > 0.0 # Define what success means for your environment
124
  success_str = str(success).lower()
125
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
126
 
 
8
  IMAGE_NAME = os.getenv("IMAGE_NAME")
9
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
 
11
+ # Swapped back to OpenAI defaults
12
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
13
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
14
  BENCHMARK_NAME = "doc_sweeper"
 
25
  if not hf_token:
26
  raise ValueError("Missing hf_token")
27
 
28
+ # Replaced Groq with OpenAI, keeping the timeout fixes!
29
  client = OpenAI(
30
  api_key=hf_token,
31
+ base_url=api_base_url,
32
+ timeout=15.0, # Max 15 seconds per request
33
+ max_retries=1 # Do not get stuck in infinite backoff loops
34
  )
35
 
36
  env = DocSweeperEnvironment(task=task_name)
 
40
  total_reward = 0.0
41
  step_count = 0
42
  rewards_history = []
43
+ MAX_STEPS = 20 # Hard step limit failsafe
44
 
45
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
46
 
 
49
 
50
  YOUR CURRENT TASK: '{task_name}'
51
  - If 'version_bump': Systematically OPEN EVERY SINGLE FILE in the directory tree. Check for 'v1.0.0' or 'v1.00'. If found, use 'edit' to update to 'v2.0.0'.
52
+ - If 'config_migration': Open docker-compose files. Update version to '3.8' and migrate 'links:' to 'networks:'.
53
+ - If 'broken_links': Find broken relative links containing '../old-docs/' and edit them to point strictly to './new-docs/'.
54
 
55
  WORKFLOW RULES:
56
+ 1. PLAN IN THOUGHT: Use the 'thought' field to reason. NEVER use a tool called "plan". Valid tools are strictly: 'open', 'edit', 'grep', 'done'.
57
  2. OPEN THEN EDIT: You MUST 'open' a file before you can 'edit' it.
58
  3. EDIT SAFELY: When editing, use 'old_str' (exact text to replace) and 'new_str'. Do NOT use 'path'.
59
  4. FINISH: Call 'done' ONLY when you have opened and verified EVERY file in the directory tree.
 
71
  """
72
 
73
  messages = [{"role": "system", "content": system_prompt}]
 
74
  start_time = time.time()
75
 
76
+ while not done and step_count < MAX_STEPS:
77
  step_count += 1
78
  current_state_prompt = f"""
79
  [ENVIRONMENT OBSERVATION]
 
120
  obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
121
  rewards_history.append(0.0)
122
 
123
+ if "timeout" in error_msg.lower() or "connection" in error_msg.lower():
124
+ done = True
125
+
126
  done_str = str(done).lower()
127
  print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
128
 
129
  final_score = max(0.0, min(1.0, total_reward))
130
+ success = final_score > 0.0
131
  success_str = str(success).lower()
132
  rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
133
 
server/cust_env_environment.py CHANGED
@@ -8,55 +8,52 @@ import uuid
8
  from typing import Dict, List
9
 
10
  from openenv.core.env_server import Environment
11
-
12
  from models import DocAction, DocObservation, DocState
13
 
14
-
15
  class DocSweeperEnvironment(Environment):
16
 
17
  def __init__(
18
  self,
19
  task: str = "version_bump",
20
- max_steps: int = 30,
21
  ):
22
- """
23
- Args:
24
- task: Task to run - "version_bump", "config_migration", or "broken_links".
25
- max_steps: Maximum allowed actions before forced termination.
26
- """
27
  super().__init__(rubric=None)
28
  self._task = task
29
  self._max_steps = max_steps
30
  self._state: DocState | None = None
31
  self._terminal_feedback = ""
 
 
32
  self.reset()
33
 
34
  def reset(self, **kwargs):
35
- """
36
- Returns:
37
- Initial observation of the virtual file system.
38
- """
39
  episode_id = str(uuid.uuid4())
40
  self._terminal_feedback = "Environment reset."
41
 
42
- initial_vfs = {}
43
  if self._task == "version_bump":
44
  initial_vfs = {
45
  "/docs/setup.md": "Welcome to our tool v1.0.0. To install v1.0.0, run the script.",
46
  "/docs/api.md": "API Reference for v1.0.0.",
47
  "/docs/troubleshoot.md": "If v1.00 fails, check logs."
48
  }
 
 
49
  elif self._task == "config_migration":
50
  initial_vfs = {
51
  "/docs/docker-compose.yml": "version: '2'\nservices:\n web:\n links:\n - db",
52
  "/docs/readme.md": "Use the docker-compose to start."
53
  }
54
- else:
 
 
55
  initial_vfs = {
56
- "/docs/index.md": "Please read [Setup](setup.md) before continuing.",
57
- "/docs/installation.md": "# Installation\nSteps go here.",
58
- "/docs/advanced.md": "Advanced config in [Setup](setup.md)."
59
  }
 
 
 
60
 
61
  self._state = DocState(
62
  episode_id=episode_id,
@@ -68,96 +65,140 @@ class DocSweeperEnvironment(Environment):
68
  return self._make_observation(reward=0.0, done=False)
69
 
70
  def step(self, action: DocAction):
71
- """
72
- Args:
73
- action: The tool action to execute (open, edit, grep, done).
74
- """
75
  if self._state is None:
76
  raise RuntimeError("Environment not initialized. Call reset() first.")
77
 
78
  self._state.step_count += 1
79
- reward = 0.0
80
  done = False
81
  self._terminal_feedback = ""
82
 
83
- # Action Routing
 
 
 
 
 
84
  if action.tool_name == "done":
85
  done = True
86
- reward += self._evaluate_final_grade()
87
- self._terminal_feedback = "Task submitted for final grading."
88
 
89
  elif action.tool_name == "open":
90
  if action.path in self._state.vfs:
91
  self._state.active_file = action.path
92
  self._terminal_feedback = f"Opened {action.path}"
93
  else:
94
- self._terminal_feedback = f"Error: File {action.path} not found."
95
- reward -= 0.1
96
 
97
  elif action.tool_name == "grep":
98
  if action.search_query:
99
  results = [p for p, c in self._state.vfs.items() if action.search_query in c]
100
  self._terminal_feedback = f"Found '{action.search_query}' in: {', '.join(results) or 'No files'}"
101
- if self._task == "broken_links":
102
- reward += 0.1
103
  else:
104
  self._terminal_feedback = "Error: search_query required for grep."
 
105
 
106
  elif action.tool_name == "edit":
107
- reward += self._handle_edit(action)
108
 
109
  else:
110
  self._terminal_feedback = f"Error: Unknown tool {action.tool_name}."
111
- reward -= 0.1
112
 
113
- # Check timeout
114
- if self._state.step_count >= self._max_steps:
115
  done = True
116
- self._terminal_feedback = "Max steps reached."
 
 
 
 
 
117
 
118
- return self._make_observation(reward=reward, done=done)
119
 
120
  def _handle_edit(self, action: DocAction) -> float:
 
121
  if not self._state.active_file:
122
  self._terminal_feedback = "Error: No file is currently open."
123
- return -0.1
124
 
 
 
 
 
125
  content = self._state.vfs[self._state.active_file]
126
 
127
  if action.old_str in ["```yaml", "# Title"] and not action.new_str:
128
  self._terminal_feedback = "Error: Destructive action prevented."
129
- return -1.0
130
 
131
- if action.old_str and action.old_str in content:
132
- self._state.vfs[self._state.active_file] = content.replace(action.old_str, action.new_str or "")
 
133
  self._terminal_feedback = "Edit successful."
134
- return 0.1
135
  else:
136
  self._terminal_feedback = f"Error: old_str '{action.old_str}' not found in file."
137
- return -0.1
138
 
139
- def _evaluate_final_grade(self) -> float:
140
- text = "".join(self._state.vfs.values())
 
 
 
 
 
141
  if self._task == "version_bump":
142
- target_count = text.count("v2.0.0")
143
- penalty = text.count("v1.0.0") + text.count("v1.00")
144
- return max(0.0, (target_count / 4.0) - (penalty * 0.5))
145
- return 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def _get_linter_issues(self) -> List[str]:
148
  if not self._state.active_file:
149
  return []
150
  issues = []
151
  content = self._state.vfs.get(self._state.active_file, "")
152
- if self._task == "version_bump" and "v1.0.0" in content:
153
- issues.append("Deprecated version 'v1.0.0' found.")
 
 
 
 
 
 
154
  return issues
155
 
156
  def _make_observation(self, reward: float = 0.0, done: bool = False):
 
157
  return DocObservation(
158
  active_file=self._state.active_file,
159
  file_content=self._state.vfs.get(self._state.active_file, ""),
160
- directory_tree={"/docs": list(self._state.vfs.keys())},
161
  issues_detected=self._get_linter_issues(),
162
  terminal_feedback=self._terminal_feedback,
163
  reward=reward,
 
8
  from typing import Dict, List
9
 
10
  from openenv.core.env_server import Environment
 
11
  from models import DocAction, DocObservation, DocState
12
 
 
13
  class DocSweeperEnvironment(Environment):
14
 
15
  def __init__(
16
  self,
17
  task: str = "version_bump",
18
+ max_steps: int = 20,
19
  ):
 
 
 
 
 
20
  super().__init__(rubric=None)
21
  self._task = task
22
  self._max_steps = max_steps
23
  self._state: DocState | None = None
24
  self._terminal_feedback = ""
25
+
26
+ self._baseline_denominators = {}
27
  self.reset()
28
 
29
  def reset(self, **kwargs):
 
 
 
 
30
  episode_id = str(uuid.uuid4())
31
  self._terminal_feedback = "Environment reset."
32
 
 
33
  if self._task == "version_bump":
34
  initial_vfs = {
35
  "/docs/setup.md": "Welcome to our tool v1.0.0. To install v1.0.0, run the script.",
36
  "/docs/api.md": "API Reference for v1.0.0.",
37
  "/docs/troubleshoot.md": "If v1.00 fails, check logs."
38
  }
39
+ self._baseline_denominators["total_files"] = 3
40
+
41
  elif self._task == "config_migration":
42
  initial_vfs = {
43
  "/docs/docker-compose.yml": "version: '2'\nservices:\n web:\n links:\n - db",
44
  "/docs/readme.md": "Use the docker-compose to start."
45
  }
46
+ self._baseline_denominators["total_files"] = 1 # Only one compose file matters
47
+
48
+ elif self._task == "broken_links":
49
  initial_vfs = {
50
+ "/docs/index.md": "Please read [Setup](../old-docs/setup.md) before continuing.",
51
+ "/docs/installation.md": "# Installation\nSee [API](../old-docs/api.md) for details.",
52
+ "/docs/advanced.md": "Advanced config in [Setup](../old-docs/setup.md)."
53
  }
54
+ self._baseline_denominators["total_links"] = 3
55
+ else:
56
+ initial_vfs = {"/docs/empty.md": "Unknown task."}
57
 
58
  self._state = DocState(
59
  episode_id=episode_id,
 
65
  return self._make_observation(reward=0.0, done=False)
66
 
67
  def step(self, action: DocAction):
 
 
 
 
68
  if self._state is None:
69
  raise RuntimeError("Environment not initialized. Call reset() first.")
70
 
71
  self._state.step_count += 1
 
72
  done = False
73
  self._terminal_feedback = ""
74
 
75
+ # 1. Calculate the score BEFORE the action
76
+ old_score = self._calculate_state_score()
77
+
78
+ # 2. Execute the action and track any direct penalties (syntax errors, bad paths)
79
+ step_penalty = 0.0
80
+
81
  if action.tool_name == "done":
82
  done = True
83
+ self._terminal_feedback = "Task submitted. Evaluating final state."
84
+ # No direct penalty or bonus here, the final delta will handle it
85
 
86
  elif action.tool_name == "open":
87
  if action.path in self._state.vfs:
88
  self._state.active_file = action.path
89
  self._terminal_feedback = f"Opened {action.path}"
90
  else:
91
+ self._terminal_feedback = f"Error: File '{action.path}' not found."
92
+ step_penalty -= 0.05 # Small penalty for hallucinating files
93
 
94
  elif action.tool_name == "grep":
95
  if action.search_query:
96
  results = [p for p, c in self._state.vfs.items() if action.search_query in c]
97
  self._terminal_feedback = f"Found '{action.search_query}' in: {', '.join(results) or 'No files'}"
 
 
98
  else:
99
  self._terminal_feedback = "Error: search_query required for grep."
100
+ step_penalty -= 0.05
101
 
102
  elif action.tool_name == "edit":
103
+ step_penalty += self._handle_edit(action)
104
 
105
  else:
106
  self._terminal_feedback = f"Error: Unknown tool {action.tool_name}."
107
+ step_penalty -= 0.05
108
 
109
+ if self._state.step_count >= self._max_steps and not done:
 
110
  done = True
111
+ self._terminal_feedback = "Max steps reached. Forced termination."
112
+
113
+ new_score = self._calculate_state_score()
114
+
115
+ delta_reward = (new_score - old_score)
116
+ total_step_reward = delta_reward + step_penalty
117
 
118
+ return self._make_observation(reward=total_step_reward, done=done)
119
 
120
  def _handle_edit(self, action: DocAction) -> float:
121
+ """Executes the edit and returns a penalty if it fails."""
122
  if not self._state.active_file:
123
  self._terminal_feedback = "Error: No file is currently open."
124
+ return -0.05
125
 
126
+ if not action.old_str:
127
+ self._terminal_feedback = "Error: 'old_str' is missing or empty."
128
+ return -0.05
129
+
130
  content = self._state.vfs[self._state.active_file]
131
 
132
  if action.old_str in ["```yaml", "# Title"] and not action.new_str:
133
  self._terminal_feedback = "Error: Destructive action prevented."
134
+ return -0.05
135
 
136
+ if action.old_str in content:
137
+ safe_new_str = action.new_str if action.new_str is not None else ""
138
+ self._state.vfs[self._state.active_file] = content.replace(action.old_str, safe_new_str)
139
  self._terminal_feedback = "Edit successful."
140
+ return 0.0
141
  else:
142
  self._terminal_feedback = f"Error: old_str '{action.old_str}' not found in file."
143
+ return -0.05
144
 
145
+ def _calculate_state_score(self) -> float:
146
+ """
147
+ Calculates the absolute progress of the environment [0.0 to 1.0].
148
+ This is called every step to calculate the delta reward.
149
+ """
150
+ vfs_items = self._state.vfs.items()
151
+
152
  if self._task == "version_bump":
153
+ correct_files = 0
154
+ for path, content in vfs_items:
155
+ if "v2.0.0" in content and not ("v1.0.0" in content or "v1.00" in content):
156
+ correct_files += 1
157
+
158
+ return min(1.0, correct_files / self._baseline_denominators["total_files"])
159
+
160
+ elif self._task == "config_migration":
161
+ compose_files = [content for path, content in vfs_items if "docker-compose" in path]
162
+ total_score = 0.0
163
+
164
+ for content in compose_files:
165
+ if "version: '3.8'" in content or 'version: "3.8"' in content:
166
+ total_score += 0.5
167
+ if "networks:" in content and "links:" not in content:
168
+ total_score += 0.5
169
+
170
+ return min(1.0, total_score / self._baseline_denominators["total_files"])
171
+
172
+ elif self._task == "broken_links":
173
+ good_link_count = 0
174
+ for path, content in vfs_items:
175
+ good_link_count += content.count("./new-docs/")
176
+
177
+ return min(1.0, good_link_count / self._baseline_denominators["total_links"])
178
+
179
+ return 0.0
180
 
181
  def _get_linter_issues(self) -> List[str]:
182
  if not self._state.active_file:
183
  return []
184
  issues = []
185
  content = self._state.vfs.get(self._state.active_file, "")
186
+
187
+ if self._task == "version_bump" and ("v1.0.0" in content or "v1.00" in content):
188
+ issues.append("LINTER WARNING: Deprecated version string found.")
189
+ elif self._task == "broken_links" and "../old-docs/" in content:
190
+ issues.append("LINTER WARNING: Broken relative link detected.")
191
+ elif self._task == "config_migration" and "links:" in content:
192
+ issues.append("LINTER WARNING: Docker 'links' is deprecated. Use 'networks'.")
193
+
194
  return issues
195
 
196
  def _make_observation(self, reward: float = 0.0, done: bool = False):
197
+ files_list = list(self._state.vfs.keys())
198
  return DocObservation(
199
  active_file=self._state.active_file,
200
  file_content=self._state.vfs.get(self._state.active_file, ""),
201
+ directory_tree={"/docs": files_list},
202
  issues_detected=self._get_linter_issues(),
203
  terminal_feedback=self._terminal_feedback,
204
  reward=reward,