File size: 5,701 Bytes
from typing import Tuple, Dict, Any, Optional
import random
from .models import Action, Observation
from .tasks import ALL_TASKS
from .rewards import (
compute_total_reward,
reward_execution_success,
reward_fix_correctness,
reward_step_efficiency,
reward_format_compliance,
reward_robustness,
check_anti_hacking_guards,
)
from .memory.failure_bank import FailureMemoryBank
try:
    # Use the OpenEnv base class when the package is importable.
    from openenv import Environment

    _BaseEnv = Environment
except ImportError:
    # Otherwise fall back to a plain object so the module still imports.
    _BaseEnv = object


class CICDDebugEnv(_BaseEnv):
    """Episodic environment in which an agent debugs a failing CI/CD pipeline.

    ``reset()`` selects a task from ``ALL_TASKS`` and builds the initial
    ``Observation``; ``step()`` applies one ``Action``, scores it with the
    functions imported from ``.rewards`` and records the result both in the
    episode history and in the ``FailureMemoryBank``.
    """

    # Actions offered to the agent while the episode is still running.
    _ACTIONS = (
        "read_logs",
        "analyze_error",
        "edit_config",
        "run_tests",
        "validate_fix",
        "submit_solution",
    )

    # A step whose total reward is strictly above this is labelled "Success".
    _SUCCESS_THRESHOLD = 0.7

    # Per-task blame scores for the pipeline stages, keyed by task id.
    _BLAME_MAP = {
        "easy_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_002": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_003": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "medium_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "medium_002": {"build": 1.0, "test": 0.0, "deploy": 0.0},
        "medium_003": {"build": 0.0, "test": 0.5, "deploy": 0.5},
        "hard_001": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "hard_002": {"build": 0.5, "test": 0.5, "deploy": 0.0},
    }

    # Blame scores used for task ids missing from ``_BLAME_MAP``.
    _DEFAULT_BLAME = {"build": 0.33, "test": 0.33, "deploy": 0.34}

    def __init__(self, max_steps: int = 10):
        """Create an idle environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Number of steps after which the episode is forced to
                end. Defaults to 10.

        Raises:
            ValueError: If ``max_steps`` is smaller than 1.
        """
        if max_steps < 1:
            raise ValueError("max_steps must be at least 1")
        # Let the base class (OpenEnv ``Environment`` or ``object``) set up.
        # NOTE(review): assumes the base constructor needs no required
        # arguments -- confirm against the installed openenv version.
        super().__init__()
        self.memory = FailureMemoryBank(store="dict")
        self.current_task = None
        self.episode_history = []
        self.current_observation = None
        self.done = False
        self.step_count = 0
        self.max_steps = max_steps
        self._state_dict = {}

    def reset(self, task_id: Optional[str] = None) -> "Observation":
        """Start a new episode and return its initial observation.

        Args:
            task_id: Id of the task to load. A falsy value picks a random
                task; an id that matches no task falls back to the first
                entry of ``ALL_TASKS``.

        Returns:
            The observation built from the selected task.
        """
        if task_id:
            self.current_task = next(
                (t for t in ALL_TASKS if t["id"] == task_id),
                ALL_TASKS[0],
            )
        else:
            self.current_task = random.choice(ALL_TASKS)

        self.episode_history = []
        self.step_count = 0
        self.done = False

        error_message = self.current_task.get("error_message", "")
        self.current_observation = Observation(
            pipeline_yaml=self.current_task["pipeline_yaml"],
            error_message=error_message,
            logs=self.current_task.get("logs", []),
            step_blame_scores=self._compute_blame(self.current_task),
            available_actions=self.available_actions(),
            episode_history=[],
            # Look up earlier attempts stored under the same error message.
            memory_hits=self.memory.query(error_message, top_k=2),
        )
        self._update_state()
        return self.current_observation

    def step(self, action: "Action") -> Tuple["Observation", float, bool, Dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        The episode ends when the action is ``submit_solution`` or when
        ``max_steps`` steps have been taken. Every step, not only the final
        one, is written to the failure memory bank.

        Args:
            action: The action to apply; ``action_type`` and ``parameters``
                are read from it.

        Returns:
            The (mutated in place) current observation, the total reward, the
            done flag and an info dict with ``task_id`` and
            ``reward_breakdown``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_task is None or self.current_observation is None:
            raise RuntimeError("step() called before reset()")

        self.step_count += 1

        if action.action_type == "edit_config":
            # Tolerate an action whose parameters are unset (None).
            params = action.parameters or {}
            new_yaml = params.get("new_yaml", params.get("new_value", ""))
            if new_yaml:
                self.current_observation.pipeline_yaml = new_yaml

        if action.action_type == "submit_solution" or self.step_count >= self.max_steps:
            self.done = True

        reward = compute_total_reward(
            self.current_observation,
            action,
            self.current_task,
            max_steps=self.max_steps,
        )
        outcome = "Success" if reward > self._SUCCESS_THRESHOLD else "Failure"

        self.memory.store(
            error_fingerprint=self.current_observation.error_message,
            action=action,
            outcome=outcome,
            reward=reward,
        )

        self.episode_history.append(
            {"action": action, "reward": reward, "outcome": outcome}
        )
        self.current_observation.episode_history = self.episode_history
        self.current_observation.available_actions = self.available_actions()
        self._update_state()

        # NOTE(review): the components below are evaluated after the history
        # entry was appended, whereas "total" was computed before it. If any
        # reward function reads ``episode_history`` the breakdown may not add
        # up to the total -- confirm against the rewards module.
        reward_components = {
            "execution_success": reward_execution_success(self.current_observation, self.current_task),
            "fix_correctness": reward_fix_correctness(self.current_observation, action, self.current_task),
            "step_efficiency": reward_step_efficiency(self.current_observation, self.max_steps),
            "format_compliance": reward_format_compliance(action),
            "robustness": reward_robustness(self.current_observation, self.current_task),
            "anti_hacking": check_anti_hacking_guards(self.current_observation, action),
            "total": reward,
        }
        info = {
            "task_id": self.current_task["id"],
            "reward_breakdown": reward_components,
        }
        return self.current_observation, reward, self.done, info

    def state(self) -> dict:
        """Return the snapshot dict built by the last ``_update_state()`` call."""
        return self._state_dict

    def available_actions(self) -> list[str]:
        """Return the action names the agent may use; empty once done."""
        if self.done:
            return []
        return list(self._ACTIONS)

    def render(self) -> str:
        """Return a short text view of the task id, error and pipeline YAML.

        Before the first ``reset()`` a placeholder line is returned instead
        of failing on the missing task.
        """
        if self.current_task is None or self.current_observation is None:
            return "--- No active task (call reset() first) ---\n"
        s = f"--- Task: {self.current_task['id']} ---\n"
        s += f"Error: {self.current_observation.error_message}\n"
        s += f"YAML:\n{self.current_observation.pipeline_yaml}\n"
        return s

    def _compute_blame(self, task) -> dict:
        """Return the build/test/deploy blame scores for ``task``.

        The scores are looked up by ``task["id"]``; unknown or missing ids
        get ``_DEFAULT_BLAME``. A fresh dict is returned on every call so
        callers cannot mutate the shared tables.
        """
        scores = self._BLAME_MAP.get(task.get("id", ""), self._DEFAULT_BLAME)
        return dict(scores)

    def _update_state(self):
        """Rebuild ``_state_dict`` from the current observation and counters."""
        self._state_dict = {
            "pipeline_yaml": self.current_observation.pipeline_yaml,
            "error_message": self.current_observation.error_message,
            "logs": self.current_observation.logs,
            "step_blame_scores": self.current_observation.step_blame_scores,
            # Only the action type and reward of each step are exposed here.
            "episode_history": [
                {"action_type": h["action"].action_type, "reward": h["reward"]}
                for h in self.episode_history
            ],
            "done": self.done,
            "step_count": self.step_count,
            "task_id": self.current_task["id"] if self.current_task else None,
        }
|