# Source: commit 2b6814d ("updated code") by Nikitasoni22.
from typing import Tuple, Dict, Any, Optional
import random
from .models import Action, Observation
from .tasks import ALL_TASKS
from .rewards import (
compute_total_reward,
reward_execution_success,
reward_fix_correctness,
reward_step_efficiency,
reward_format_compliance,
reward_robustness,
check_anti_hacking_guards,
)
from .memory.failure_bank import FailureMemoryBank
try:
from openenv import Environment
_BaseEnv = Environment
except ImportError:
_BaseEnv = object
class CICDDebugEnv(_BaseEnv):
    """CI/CD pipeline-debugging environment with a reset/step/state interface.

    Each episode loads one task from ``ALL_TASKS``. The agent acts on the
    task's pipeline YAML (``edit_config`` replaces it) until it sends
    ``submit_solution`` or ``max_steps`` steps have been taken. Every step is
    scored with ``compute_total_reward`` and recorded in a
    ``FailureMemoryBank``; the bank is queried on ``reset()`` to fill the
    observation's ``memory_hits``.
    """

    # A total reward strictly above this value is stored as "Success".
    SUCCESS_THRESHOLD = 0.7

    # Per-stage blame weights for known task ids. Callers always receive a
    # copy (see _compute_blame) so these shared tables are never mutated.
    _BLAME_MAP = {
        "easy_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_002": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_003": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "medium_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "medium_002": {"build": 1.0, "test": 0.0, "deploy": 0.0},
        "medium_003": {"build": 0.0, "test": 0.5, "deploy": 0.5},
        "hard_001": {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "hard_002": {"build": 0.5, "test": 0.5, "deploy": 0.0},
    }
    # Fallback for ids not in _BLAME_MAP; the weights sum to 1.0.
    _DEFAULT_BLAME = {"build": 0.33, "test": 0.33, "deploy": 0.34}

    def __init__(self, max_steps: int = 10):
        """Create an idle environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Step count at which an episode is forced to end even
                without ``submit_solution``. Must be at least 1.

        Raises:
            ValueError: If ``max_steps`` is less than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self.memory = FailureMemoryBank(store="dict")
        self.current_task: Optional[Dict[str, Any]] = None
        self.episode_history: list = []
        self.current_observation: Optional[Observation] = None
        self.done = False
        self.step_count = 0
        self.max_steps = max_steps
        self._state_dict: Dict[str, Any] = {}

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Id of the task to load. If falsy, a task is chosen at
                random from ``ALL_TASKS``. An id that matches no task falls
                back to the first entry of ``ALL_TASKS``.

        Returns:
            The initial ``Observation``. The same object is kept as
            ``current_observation`` and is mutated in place by ``step()``.
        """
        if task_id:
            # NOTE(review): unknown ids silently fall back to ALL_TASKS[0]
            # instead of raising; kept for backward compatibility.
            self.current_task = next(
                (t for t in ALL_TASKS if t["id"] == task_id), ALL_TASKS[0]
            )
        else:
            self.current_task = random.choice(ALL_TASKS)
        self.episode_history = []
        self.step_count = 0
        self.done = False

        error_message = self.current_task.get("error_message", "")
        self.current_observation = Observation(
            pipeline_yaml=self.current_task["pipeline_yaml"],
            error_message=error_message,
            # Copy so the observation never aliases the shared task definition.
            logs=list(self.current_task.get("logs", [])),
            step_blame_scores=self._compute_blame(self.current_task),
            available_actions=self.available_actions(),
            episode_history=[],
            # Up to two entries the memory bank returns for this error message.
            memory_hits=self.memory.query(error_message, top_k=2),
        )
        self._update_state()
        return self.current_observation

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply ``action``, score it, and advance the episode by one step.

        ``edit_config`` replaces the pipeline YAML with
        ``parameters["new_yaml"]``, or with ``parameters["new_value"]`` when
        ``new_yaml`` is absent. Empty values leave the YAML unchanged. The
        episode ends on ``submit_solution`` or once ``max_steps`` steps have
        been taken.

        Args:
            action: The agent's action.

        Returns:
            ``(observation, reward, done, info)``. ``observation`` is
            ``current_observation`` itself, updated in place. ``info`` holds
            the ``task_id`` and a per-component ``reward_breakdown``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_observation is None or self.current_task is None:
            raise RuntimeError("reset() must be called before step()")
        # NOTE(review): stepping after ``done`` is not rejected. It keeps
        # incrementing step_count and writing to the memory bank; confirm
        # that callers always stop on done.
        self.step_count += 1

        if action.action_type == "edit_config":
            params = action.parameters
            new_yaml = params.get("new_yaml", params.get("new_value", ""))
            if new_yaml:
                self.current_observation.pipeline_yaml = new_yaml

        if (
            action.action_type == "submit_solution"
            or self.step_count >= self.max_steps
        ):
            self.done = True

        reward = compute_total_reward(
            self.current_observation, action, self.current_task, max_steps=self.max_steps
        )
        outcome = "Success" if reward > self.SUCCESS_THRESHOLD else "Failure"
        self.memory.store(
            error_fingerprint=self.current_observation.error_message,
            action=action,
            outcome=outcome,
            reward=reward,
        )

        self.episode_history.append(
            {"action": action, "reward": reward, "outcome": outcome}
        )
        self.current_observation.episode_history = self.episode_history
        # Becomes empty once the episode is done (see available_actions).
        self.current_observation.available_actions = self.available_actions()
        self._update_state()

        info = {
            "task_id": self.current_task["id"],
            "reward_breakdown": self._reward_breakdown(action, reward),
        }
        return self.current_observation, reward, self.done, info

    def _reward_breakdown(self, action: Action, total: float) -> Dict[str, Any]:
        """Return each reward component for the current state plus ``total``."""
        obs, task = self.current_observation, self.current_task
        return {
            "execution_success": reward_execution_success(obs, task),
            "fix_correctness": reward_fix_correctness(obs, action, task),
            "step_efficiency": reward_step_efficiency(obs, self.max_steps),
            "format_compliance": reward_format_compliance(action),
            "robustness": reward_robustness(obs, task),
            "anti_hacking": check_anti_hacking_guards(obs, action),
            "total": total,
        }

    def state(self) -> dict:
        """Return a snapshot of the environment state (empty before ``reset()``).

        A shallow copy is returned so callers cannot alter the internal
        snapshot's entries.
        """
        return dict(self._state_dict)

    def available_actions(self) -> list[str]:
        """Return the action types the agent may take (none once done)."""
        if self.done:
            return []
        return ["read_logs", "analyze_error", "edit_config", "run_tests", "validate_fix", "submit_solution"]

    def render(self) -> str:
        """Return a human-readable summary of the current task and pipeline."""
        if self.current_task is None or self.current_observation is None:
            return "--- No active task ---\n"
        s = f"--- Task: {self.current_task['id']} ---\n"
        s += f"Error: {self.current_observation.error_message}\n"
        s += f"YAML:\n{self.current_observation.pipeline_yaml}\n"
        return s

    def _compute_blame(self, task) -> dict:
        """Return the per-stage blame weights for ``task``, keyed by its ``id``.

        Ids missing from the table, or absent entirely, get a near-uniform
        default. A fresh dict is returned on every call.
        """
        weights = self._BLAME_MAP.get(task.get("id", ""), self._DEFAULT_BLAME)
        return dict(weights)

    def _update_state(self):
        """Rebuild the snapshot returned by ``state()``.

        History entries are reduced to ``action_type`` and ``reward``, so the
        snapshot does not hold ``Action`` objects.
        """
        self._state_dict = {
            "pipeline_yaml": self.current_observation.pipeline_yaml,
            "error_message": self.current_observation.error_message,
            "logs": self.current_observation.logs,
            "step_blame_scores": self.current_observation.step_blame_scores,
            "episode_history": [
                {"action_type": h["action"].action_type, "reward": h["reward"]}
                for h in self.episode_history
            ],
            "done": self.done,
            "step_count": self.step_count,
            "task_id": self.current_task["id"] if self.current_task else None,
        }