"""Episodic environment in which an agent debugs a failing CI/CD pipeline."""

import random
from typing import Any, Dict, Optional, Tuple

from .models import Action, Observation
from .tasks import ALL_TASKS
from .rewards import (
    compute_total_reward,
    reward_execution_success,
    reward_fix_correctness,
    reward_step_efficiency,
    reward_format_compliance,
    reward_robustness,
    check_anti_hacking_guards,
)
from .memory.failure_bank import FailureMemoryBank

# Fall back to a plain ``object`` base when the optional ``openenv`` package
# is not installed, so the class can still be imported and used standalone.
try:
    from openenv import Environment

    _BaseEnv = Environment
except ImportError:
    _BaseEnv = object

# A step whose total reward is strictly above this value is recorded as
# "Success" in the failure memory bank and in the episode history.
_SUCCESS_THRESHOLD = 0.7

# Action types advertised to the agent while an episode is still running.
_ACTION_TYPES = (
    "read_logs",
    "analyze_error",
    "edit_config",
    "run_tests",
    "validate_fix",
    "submit_solution",
)

# Per-task blame scores for each pipeline stage, keyed by task id.
_BLAME_BY_TASK = {
    "easy_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
    "easy_002": {"build": 0.0, "test": 1.0, "deploy": 0.0},
    "easy_003": {"build": 0.0, "test": 0.0, "deploy": 1.0},
    "medium_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
    "medium_002": {"build": 1.0, "test": 0.0, "deploy": 0.0},
    "medium_003": {"build": 0.0, "test": 0.5, "deploy": 0.5},
    "hard_001": {"build": 0.0, "test": 0.0, "deploy": 1.0},
    "hard_002": {"build": 0.5, "test": 0.5, "deploy": 0.0},
}

# Blame scores used for any task id that is not listed in ``_BLAME_BY_TASK``.
_FALLBACK_BLAME = {"build": 0.33, "test": 0.33, "deploy": 0.34}


class CICDDebugEnv(_BaseEnv):
    """Environment that presents a broken pipeline and scores repair actions.

    Typical use is ``reset()`` to start an episode followed by repeated
    ``step(action)`` calls until the returned ``done`` flag is true.
    """

    def __init__(self):
        """Create an environment with an empty episode and a fresh memory bank."""
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm that openenv's ``Environment`` does not require it.
        self.memory = FailureMemoryBank(store="dict")
        self.current_task = None
        self.episode_history = []
        self.current_observation = None
        self.done = False
        self.step_count = 0
        self.max_steps = 10
        self._state_dict = {}

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Id of the task to load. When falsy, a task is drawn at
                random from ``ALL_TASKS``. When no task has this id, the
                first entry of ``ALL_TASKS`` is used instead.

        Returns:
            The observation describing the freshly loaded task.
        """
        if task_id:
            self.current_task = next(
                (task for task in ALL_TASKS if task["id"] == task_id),
                ALL_TASKS[0],
            )
        else:
            self.current_task = random.choice(ALL_TASKS)

        self.episode_history = []
        self.step_count = 0
        self.done = False

        error_message = self.current_task.get("error_message", "")
        self.current_observation = Observation(
            pipeline_yaml=self.current_task["pipeline_yaml"],
            error_message=error_message,
            logs=self.current_task.get("logs", []),
            step_blame_scores=self._compute_blame(self.current_task),
            available_actions=self.available_actions(),
            episode_history=[],
            # Surface up to two stored memory entries for this error message.
            memory_hits=self.memory.query(error_message, top_k=2),
        )
        self._update_state()
        return self.current_observation

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one agent action and score it.

        Args:
            action: The action to apply. An ``"edit_config"`` action replaces
                the pipeline YAML with its ``new_yaml`` parameter (falling
                back to ``new_value`` when ``new_yaml`` is absent), provided
                the value is non-empty. A ``"submit_solution"`` action ends
                the episode.

        Returns:
            A ``(observation, reward, done, info)`` tuple. ``info`` carries
            the task id and a per-component reward breakdown.

        Raises:
            RuntimeError: If called before ``reset()`` has loaded a task.
        """
        # Fail early with a clear message, before any counters are touched,
        # instead of an attribute error on ``None`` further down.
        if self.current_task is None or self.current_observation is None:
            raise RuntimeError(
                "step() called before reset(); call reset() to start an episode"
            )

        self.step_count += 1

        if action.action_type == "edit_config":
            # Tolerate an action that carries no parameters at all.
            params = action.parameters or {}
            new_yaml = params.get("new_yaml", params.get("new_value", ""))
            if new_yaml:
                self.current_observation.pipeline_yaml = new_yaml

        # The episode ends on explicit submission or when the step budget
        # is exhausted.
        if action.action_type == "submit_solution" or self.step_count >= self.max_steps:
            self.done = True

        reward = compute_total_reward(
            self.current_observation,
            action,
            self.current_task,
            max_steps=self.max_steps,
        )
        outcome = "Success" if reward > _SUCCESS_THRESHOLD else "Failure"

        # Record the attempt so later episodes can query it by error message.
        self.memory.store(
            error_fingerprint=self.current_observation.error_message,
            action=action,
            outcome=outcome,
            reward=reward,
        )

        self.episode_history.append(
            {
                "action": action,
                "reward": reward,
                "outcome": outcome,
            }
        )
        self.current_observation.episode_history = self.episode_history
        self.current_observation.available_actions = self.available_actions()
        self._update_state()

        reward_components = {
            "execution_success": reward_execution_success(
                self.current_observation, self.current_task
            ),
            "fix_correctness": reward_fix_correctness(
                self.current_observation, action, self.current_task
            ),
            "step_efficiency": reward_step_efficiency(
                self.current_observation, self.max_steps
            ),
            "format_compliance": reward_format_compliance(action),
            "robustness": reward_robustness(
                self.current_observation, self.current_task
            ),
            "anti_hacking": check_anti_hacking_guards(
                self.current_observation, action
            ),
            "total": reward,
        }
        info = {
            "task_id": self.current_task["id"],
            "reward_breakdown": reward_components,
        }
        return self.current_observation, reward, self.done, info

    def state(self) -> dict:
        """Return the most recent state snapshot (empty before ``reset()``)."""
        return self._state_dict

    def available_actions(self) -> list[str]:
        """Return the action types currently allowed; empty once done."""
        if self.done:
            return []
        return list(_ACTION_TYPES)

    def render(self) -> str:
        """Return a short text view of the task id, error and pipeline YAML."""
        if self.current_task is None or self.current_observation is None:
            return "--- No active task (call reset() first) ---\n"
        return "".join(
            (
                f"--- Task: {self.current_task['id']} ---\n",
                f"Error: {self.current_observation.error_message}\n",
                f"YAML:\n{self.current_observation.pipeline_yaml}\n",
            )
        )

    def _compute_blame(self, task) -> dict:
        """Return per-stage blame scores for ``task``.

        Looks the task's ``"id"`` up in the fixed table and falls back to a
        near-uniform split for unknown ids. A copy is returned so callers can
        mutate the result without affecting the shared table.
        """
        scores = _BLAME_BY_TASK.get(task.get("id", ""), _FALLBACK_BLAME)
        return dict(scores)

    def _update_state(self):
        """Rebuild the state snapshot from the current observation."""
        observation = self.current_observation
        # Only the action type and reward of each history entry are exposed.
        history_summary = [
            {"action_type": entry["action"].action_type, "reward": entry["reward"]}
            for entry in self.episode_history
        ]
        self._state_dict = {
            "pipeline_yaml": observation.pipeline_yaml,
            "error_message": observation.error_message,
            "logs": observation.logs,
            "step_blame_scores": observation.step_blame_scores,
            "episode_history": history_summary,
            "done": self.done,
            "step_count": self.step_count,
            "task_id": self.current_task["id"] if self.current_task else None,
        }