Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import uuid | |
| from typing import Dict, List | |
| from openenv.core.env_server import Environment | |
| from models import DocAction, DocObservation, DocState | |
| class DocSweeperEnvironment(Environment): | |
| def __init__( | |
| self, | |
| task: str = "version_bump", | |
| max_steps: int = 20, | |
| ): | |
| super().__init__(rubric=None) | |
| self._task = task | |
| self._max_steps = max_steps | |
| self._state: DocState | None = None | |
| self._terminal_feedback = "" | |
| self._baseline_denominators = {} | |
| self.reset() | |
| def reset(self, **kwargs): | |
| episode_id = str(uuid.uuid4()) | |
| self._terminal_feedback = "Environment reset." | |
| if self._task == "version_bump": | |
| initial_vfs = { | |
| "/docs/setup.md": "Welcome to our tool v1.0.0. To install v1.0.0, run the script.", | |
| "/docs/api.md": "API Reference for v1.0.0.", | |
| "/docs/troubleshoot.md": "If v1.00 fails, check logs." | |
| } | |
| self._baseline_denominators["total_files"] = 3 | |
| elif self._task == "config_migration": | |
| initial_vfs = { | |
| "/docs/docker-compose.yml": "version: '2'\nservices:\n web:\n links:\n - db", | |
| "/docs/readme.md": "Use the docker-compose to start." | |
| } | |
| self._baseline_denominators["total_files"] = 1 # Only one compose file matters | |
| elif self._task == "broken_links": | |
| initial_vfs = { | |
| "/docs/index.md": "Please read [Setup](../old-docs/setup.md) before continuing.", | |
| "/docs/installation.md": "# Installation\nSee [API](../old-docs/api.md) for details.", | |
| "/docs/advanced.md": "Advanced config in [Setup](../old-docs/setup.md)." | |
| } | |
| self._baseline_denominators["total_links"] = 3 | |
| else: | |
| initial_vfs = {"/docs/empty.md": "Unknown task."} | |
| self._state = DocState( | |
| episode_id=episode_id, | |
| step_count=0, | |
| vfs=initial_vfs, | |
| active_file="" | |
| ) | |
| # Baseline reward is exactly 0.5 (neutral) | |
| return self._make_observation(reward=0.5, done=False) | |
| def step(self, action: DocAction): | |
| if self._state is None: | |
| raise RuntimeError("Environment not initialized. Call reset() first.") | |
| self._state.step_count += 1 | |
| done = False | |
| self._terminal_feedback = "" | |
| old_score = self._calculate_state_score() | |
| step_penalty = 0.0 | |
| if action.tool_name == "done": | |
| done = True | |
| self._terminal_feedback = "Task submitted. Evaluating final state." | |
| elif action.tool_name == "open": | |
| if action.path in self._state.vfs: | |
| self._state.active_file = action.path | |
| self._terminal_feedback = f"Opened {action.path}" | |
| else: | |
| self._terminal_feedback = f"Error: File '{action.path}' not found." | |
| step_penalty -= 0.05 | |
| elif action.tool_name == "grep": | |
| if action.search_query: | |
| results = [p for p, c in self._state.vfs.items() if action.search_query in c] | |
| self._terminal_feedback = f"Found '{action.search_query}' in: {', '.join(results) or 'No files'}" | |
| else: | |
| self._terminal_feedback = "Error: search_query required for grep." | |
| step_penalty -= 0.05 | |
| elif action.tool_name == "edit": | |
| step_penalty += self._handle_edit(action) | |
| else: | |
| self._terminal_feedback = f"Error: Unknown tool {action.tool_name}." | |
| step_penalty -= 0.05 | |
| if self._state.step_count >= self._max_steps and not done: | |
| done = True | |
| self._terminal_feedback = "Max steps reached. Forced termination." | |
| new_score = self._calculate_state_score() | |
| # Calculate raw delta reward | |
| delta_reward = (new_score - old_score) | |
| raw_step_reward = delta_reward + step_penalty | |
| # Map reward to be strictly within (0.0, 1.0) | |
| # raw_step_reward ranges roughly from -1.0 to 1.0. We map it so 0.0 raw = 0.5 mapped. | |
| mapped_reward = (raw_step_reward + 1.0) / 2.0 | |
| # Clamp strictly to (0.0, 1.0) boundaries using a 0.01 epsilon | |
| EPSILON = 0.01 | |
| final_reward = max(EPSILON, min(1.0 - EPSILON, mapped_reward)) | |
| return self._make_observation(reward=final_reward, done=done) | |
| def _handle_edit(self, action: DocAction) -> float: | |
| """Executes the edit and returns a penalty if it fails.""" | |
| if not self._state.active_file: | |
| self._terminal_feedback = "Error: No file is currently open." | |
| return -0.05 | |
| if not action.old_str: | |
| self._terminal_feedback = "Error: 'old_str' is missing or empty." | |
| return -0.05 | |
| content = self._state.vfs[self._state.active_file] | |
| if action.old_str in ["```yaml", "# Title"] and not action.new_str: | |
| self._terminal_feedback = "Error: Destructive action prevented." | |
| return -0.05 | |
| if action.old_str in content: | |
| safe_new_str = action.new_str if action.new_str is not None else "" | |
| self._state.vfs[self._state.active_file] = content.replace(action.old_str, safe_new_str) | |
| self._terminal_feedback = "Edit successful." | |
| return 0.0 | |
| else: | |
| self._terminal_feedback = f"Error: old_str '{action.old_str}' not found in file." | |
| return -0.05 | |
| def _calculate_state_score(self) -> float: | |
| """ | |
| Calculates the absolute progress of the environment [0.0 to 1.0]. | |
| This is called every step to calculate the delta reward. | |
| """ | |
| vfs_items = self._state.vfs.items() | |
| if self._task == "version_bump": | |
| correct_files = 0 | |
| for path, content in vfs_items: | |
| if "v2.0.0" in content and not ("v1.0.0" in content or "v1.00" in content): | |
| correct_files += 1 | |
| return min(1.0, correct_files / self._baseline_denominators["total_files"]) | |
| elif self._task == "config_migration": | |
| compose_files = [content for path, content in vfs_items if "docker-compose" in path] | |
| total_score = 0.0 | |
| for content in compose_files: | |
| if "version: '3.8'" in content or 'version: "3.8"' in content: | |
| total_score += 0.5 | |
| if "networks:" in content and "links:" not in content: | |
| total_score += 0.5 | |
| return min(1.0, total_score / self._baseline_denominators["total_files"]) | |
| elif self._task == "broken_links": | |
| good_link_count = 0 | |
| for path, content in vfs_items: | |
| good_link_count += content.count("./new-docs/") | |
| return min(1.0, good_link_count / self._baseline_denominators["total_links"]) | |
| return 0.0 | |
| def _get_linter_issues(self) -> List[str]: | |
| if not self._state.active_file: | |
| return [] | |
| issues = [] | |
| content = self._state.vfs.get(self._state.active_file, "") | |
| if self._task == "version_bump" and ("v1.0.0" in content or "v1.00" in content): | |
| issues.append("LINTER WARNING: Deprecated version string found.") | |
| elif self._task == "broken_links" and "../old-docs/" in content: | |
| issues.append("LINTER WARNING: Broken relative link detected.") | |
| elif self._task == "config_migration" and "links:" in content: | |
| issues.append("LINTER WARNING: Docker 'links' is deprecated. Use 'networks'.") | |
| return issues | |
| def _make_observation(self, reward: float = 0.0, done: bool = False): | |
| files_list = list(self._state.vfs.keys()) | |
| return DocObservation( | |
| active_file=self._state.active_file, | |
| file_content=self._state.vfs.get(self._state.active_file, ""), | |
| directory_tree={"/docs": files_list}, | |
| issues_detected=self._get_linter_issues(), | |
| terminal_feedback=self._terminal_feedback, | |
| reward=reward, | |
| done=done, | |
| ) | |
| def state(self): | |
| return self._state |