import random
from typing import Any, ClassVar, Dict, Optional, Tuple

from .models import Action, Observation
from .tasks import ALL_TASKS
from .rewards import (
    compute_total_reward,
    reward_execution_success,
    reward_fix_correctness,
    reward_step_efficiency,
    reward_format_compliance,
    reward_robustness,
    check_anti_hacking_guards,
)
from .memory.failure_bank import FailureMemoryBank


try:
    # Optional dependency: subclass the OpenEnv base class when it is available.
    from openenv import Environment

    _BaseEnv = Environment
except ImportError:
    # Without openenv installed the environment is a plain Python object.
    _BaseEnv = object


class CICDDebugEnv(_BaseEnv):
    """Episodic environment built around a failing pipeline configuration.

    Each episode is bound to one entry of ``ALL_TASKS``. ``reset()`` selects
    the task and builds the initial observation; ``step()`` applies an action,
    scores it with the reward functions and records the result both in the
    episode history and in the failure memory bank.
    """

    # Per-task blame scores for the build/test/deploy stages, keyed by task id.
    _BLAME_BY_TASK: ClassVar[Dict[str, Dict[str, float]]] = {
        "easy_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_002": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_003": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "medium_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "medium_002": {"build": 1.0, "test": 0.0, "deploy": 0.0},
        "medium_003": {"build": 0.0, "test": 0.5, "deploy": 0.5},
        "hard_001": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "hard_002": {"build": 0.5, "test": 0.5, "deploy": 0.0},
    }
    # Near-uniform scores used when a task id has no entry in the table above.
    _DEFAULT_BLAME: ClassVar[Dict[str, float]] = {
        "build": 0.33,
        "test": 0.33,
        "deploy": 0.34,
    }
    # Action types offered while an episode is still running.
    _ACTIONS: ClassVar[Tuple[str, ...]] = (
        "read_logs",
        "analyze_error",
        "edit_config",
        "run_tests",
        "validate_fix",
        "submit_solution",
    )
    # A step is labelled "Success" only when its reward is strictly above this.
    _SUCCESS_THRESHOLD: ClassVar[float] = 0.7

    def __init__(self):
        """Create an idle environment; call ``reset()`` before ``step()``."""
        # NOTE(review): the base-class initializer is not invoked here; confirm
        # that openenv's Environment does not require it.
        self.memory = FailureMemoryBank(store="dict")
        self.current_task = None
        self.episode_history = []
        self.current_observation = None
        self.done = False
        self.step_count = 0
        self.max_steps = 10
        self._state_dict = {}

    def reset(self, task_id: Optional[str] = None) -> "Observation":
        """Start a new episode and return its initial observation.

        Args:
            task_id: Id of the task to load. When falsy, a task is drawn at
                random from ``ALL_TASKS``. An id that matches no task falls
                back to the first entry of ``ALL_TASKS``.

        Returns:
            The initial observation of the new episode.
        """
        if task_id:
            self.current_task = next(
                (task for task in ALL_TASKS if task["id"] == task_id),
                ALL_TASKS[0],
            )
        else:
            self.current_task = random.choice(ALL_TASKS)

        self.episode_history = []
        self.step_count = 0
        self.done = False

        error_message = self.current_task.get("error_message", "")
        self.current_observation = Observation(
            pipeline_yaml=self.current_task["pipeline_yaml"],
            error_message=error_message,
            logs=self.current_task.get("logs", []),
            step_blame_scores=self._compute_blame(self.current_task),
            available_actions=self.available_actions(),
            episode_history=[],
            # Surface remembered attempts recorded for the same error message.
            memory_hits=self.memory.query(error_message, top_k=2),
        )
        self._update_state()
        return self.current_observation

    def step(self, action: "Action") -> Tuple["Observation", float, bool, Dict[str, Any]]:
        """Apply ``action`` to the current episode and score it.

        The episode is marked done when the action type is
        ``"submit_solution"`` or when ``max_steps`` steps have been taken.

        Args:
            action: The action to apply. An ``"edit_config"`` action replaces
                the observed pipeline YAML with its ``new_yaml`` parameter
                (or ``new_value`` when ``new_yaml`` is absent) if non-empty.

        Returns:
            ``(observation, reward, done, info)`` where ``info`` holds the
            ``"task_id"`` and a per-component ``"reward_breakdown"``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_task is None or self.current_observation is None:
            raise RuntimeError(
                "step() called before reset(); call reset() to start an episode."
            )

        self.step_count += 1

        if action.action_type == "edit_config":
            # Tolerate actions whose parameters were left unset.
            params = action.parameters or {}
            new_yaml = params.get("new_yaml", params.get("new_value", ""))
            if new_yaml:
                self.current_observation.pipeline_yaml = new_yaml

        if action.action_type == "submit_solution" or self.step_count >= self.max_steps:
            self.done = True

        reward = compute_total_reward(
            self.current_observation,
            action,
            self.current_task,
            max_steps=self.max_steps,
        )
        outcome = "Success" if reward > self._SUCCESS_THRESHOLD else "Failure"

        # Record the attempt so later episodes can retrieve it via memory.query.
        self.memory.store(
            error_fingerprint=self.current_observation.error_message,
            action=action,
            outcome=outcome,
            reward=reward,
        )

        self.episode_history.append(
            {"action": action, "reward": reward, "outcome": outcome}
        )
        self.current_observation.episode_history = self.episode_history
        self.current_observation.available_actions = self.available_actions()

        self._update_state()
        reward_components = {
            "execution_success": reward_execution_success(self.current_observation, self.current_task),
            "fix_correctness": reward_fix_correctness(self.current_observation, action, self.current_task),
            "step_efficiency": reward_step_efficiency(self.current_observation, self.max_steps),
            "format_compliance": reward_format_compliance(action),
            "robustness": reward_robustness(self.current_observation, self.current_task),
            "anti_hacking": check_anti_hacking_guards(self.current_observation, action),
            "total": reward,
        }
        info = {
            "task_id": self.current_task["id"],
            "reward_breakdown": reward_components,
        }
        return self.current_observation, reward, self.done, info

    def state(self) -> dict:
        """Return the state snapshot taken at the last ``reset()``/``step()``."""
        return self._state_dict

    def available_actions(self) -> list[str]:
        """Return the action types currently allowed (none once done)."""
        if self.done:
            return []
        # Hand out a fresh list so callers cannot alter the class constant.
        return list(self._ACTIONS)

    def render(self) -> str:
        """Return a text summary of the task id, error message and YAML.

        Before ``reset()`` has been called a placeholder line is returned
        instead of failing on the missing task.
        """
        if self.current_task is None or self.current_observation is None:
            return "--- No active task (call reset() first) ---\n"
        return (
            f"--- Task: {self.current_task['id']} ---\n"
            f"Error: {self.current_observation.error_message}\n"
            f"YAML:\n{self.current_observation.pipeline_yaml}\n"
        )

    def _compute_blame(self, task) -> dict:
        """Return the stage blame scores registered for ``task``'s id.

        Tasks without an ``"id"`` key, or with an unlisted id, receive the
        default near-uniform scores. A new dict is returned on every call.
        """
        scores = self._BLAME_BY_TASK.get(task.get("id", ""), self._DEFAULT_BLAME)
        return dict(scores)

    def _update_state(self):
        """Rebuild the snapshot returned by ``state()`` from current fields."""
        self._state_dict = {
            "pipeline_yaml": self.current_observation.pipeline_yaml,
            "error_message": self.current_observation.error_message,
            "logs": self.current_observation.logs,
            "step_blame_scores": self.current_observation.step_blame_scores,
            # Keep only the action type and reward of each history entry.
            "episode_history": [
                {"action_type": entry["action"].action_type, "reward": entry["reward"]}
                for entry in self.episode_history
            ],
            "done": self.done,
            "step_count": self.step_count,
            "task_id": self.current_task["id"] if self.current_task else None,
        }

