File size: 5,701 Bytes
5a2d63f
 
 
 
 
2b6814d
 
 
 
 
 
 
 
 
5a2d63f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b6814d
 
 
 
 
 
 
 
 
 
 
 
 
5a2d63f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from typing import Tuple, Dict, Any, Optional
import random

from .models import Action, Observation
from .tasks import ALL_TASKS
from .rewards import (
    compute_total_reward,
    reward_execution_success,
    reward_fix_correctness,
    reward_step_efficiency,
    reward_format_compliance,
    reward_robustness,
    check_anti_hacking_guards,
)
from .memory.failure_bank import FailureMemoryBank

try:
    from openenv import Environment
    _BaseEnv = Environment
except ImportError:
    _BaseEnv = object

class CICDDebugEnv(_BaseEnv):
    """CI/CD pipeline-debugging environment with a reset/step/state interface.

    Each episode selects a task from ``ALL_TASKS`` and exposes its pipeline
    YAML, error message and logs through an ``Observation``. ``edit_config``
    actions may replace the pipeline YAML; every step is scored with
    ``compute_total_reward`` and recorded in a ``FailureMemoryBank`` that is
    created once per environment and is not cleared by ``reset()``.
    """

    # Per-task blame scores for the build/test/deploy stages, keyed by task id.
    _BLAME_MAP = {
        "easy_001":   {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_002":   {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "easy_003":   {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "medium_001": {"build": 0.0, "test": 1.0, "deploy": 0.0},
        "medium_002": {"build": 1.0, "test": 0.0, "deploy": 0.0},
        "medium_003": {"build": 0.0, "test": 0.5, "deploy": 0.5},
        "hard_001":   {"build": 0.0, "test": 0.0, "deploy": 1.0},
        "hard_002":   {"build": 0.5, "test": 0.5, "deploy": 0.0},
    }
    # Blame used for task ids that are not in _BLAME_MAP (spread roughly evenly).
    _DEFAULT_BLAME = {"build": 0.33, "test": 0.33, "deploy": 0.34}
    # A step whose total reward is strictly above this is stored as "Success".
    SUCCESS_THRESHOLD = 0.7

    def __init__(self, max_steps: int = 10):
        """Create an environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Number of steps after which an episode ends even
                without a ``submit_solution`` action. Defaults to 10.

        Raises:
            ValueError: If ``max_steps`` is less than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self.memory = FailureMemoryBank(store="dict")
        self.current_task = None
        self.episode_history = []
        self.current_observation = None
        self.done = False
        self.step_count = 0
        self.max_steps = max_steps
        self._state_dict = {}

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Id of the task to load. When falsy, a task is chosen at
                random from ``ALL_TASKS``. An id that matches no task falls
                back to ``ALL_TASKS[0]``.

        Returns:
            The initial ``Observation`` for the selected task.
        """
        if task_id:
            # Unknown ids silently fall back to the first task (kept for
            # backward compatibility with existing callers).
            self.current_task = next(
                (t for t in ALL_TASKS if t["id"] == task_id), ALL_TASKS[0]
            )
        else:
            self.current_task = random.choice(ALL_TASKS)

        self.episode_history = []
        self.step_count = 0
        # Must be cleared before available_actions() is called below, or a
        # previously finished episode would report no available actions.
        self.done = False

        error_message = self.current_task.get("error_message", "")
        self.current_observation = Observation(
            pipeline_yaml=self.current_task["pipeline_yaml"],
            error_message=error_message,
            logs=self.current_task.get("logs", []),
            step_blame_scores=self._compute_blame(self.current_task),
            available_actions=self.available_actions(),
            episode_history=[],
            memory_hits=self.memory.query(error_message, top_k=2),
        )
        self._update_state()
        return self.current_observation

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one action, score it and record it.

        ``edit_config`` replaces the observation's pipeline YAML with
        ``parameters["new_yaml"]`` (falling back to ``parameters["new_value"]``)
        when that value is non-empty. The episode ends on ``submit_solution``
        or once ``max_steps`` steps have been taken.

        Args:
            action: The agent's action.

        Returns:
            ``(observation, reward, done, info)``. ``info`` holds the task id
            and a per-component reward breakdown. The returned observation
            is the environment's live object and is mutated by later steps.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_observation is None or self.current_task is None:
            raise RuntimeError("reset() must be called before step()")

        # NOTE(review): steps taken after the episode is done are still
        # processed and stored in memory; callers are expected to reset().
        self.step_count += 1

        if action.action_type == "edit_config":
            new_yaml = action.parameters.get("new_yaml", action.parameters.get("new_value", ""))
            if new_yaml:
                self.current_observation.pipeline_yaml = new_yaml

        if action.action_type == "submit_solution" or self.step_count >= self.max_steps:
            self.done = True

        reward = compute_total_reward(self.current_observation, action, self.current_task, max_steps=self.max_steps)
        outcome = "Success" if reward > self.SUCCESS_THRESHOLD else "Failure"

        self.memory.store(
            error_fingerprint=self.current_observation.error_message,
            action=action,
            outcome=outcome,
            reward=reward,
        )

        self.episode_history.append({
            "action": action,
            "reward": reward,
            "outcome": outcome,
        })
        self.current_observation.episode_history = self.episode_history
        self.current_observation.available_actions = self.available_actions()

        self._update_state()
        # Individual components are recomputed here purely for reporting;
        # presumably they are what compute_total_reward combines — confirm
        # against the rewards module.
        reward_components = {
            "execution_success": reward_execution_success(self.current_observation, self.current_task),
            "fix_correctness":   reward_fix_correctness(self.current_observation, action, self.current_task),
            "step_efficiency":   reward_step_efficiency(self.current_observation, self.max_steps),
            "format_compliance": reward_format_compliance(action),
            "robustness":        reward_robustness(self.current_observation, self.current_task),
            "anti_hacking":      check_anti_hacking_guards(self.current_observation, action),
            "total":             reward,
        }
        return self.current_observation, reward, self.done, {
            "task_id": self.current_task["id"],
            "reward_breakdown": reward_components,
        }

    def state(self) -> dict:
        """Return a shallow copy of the latest state snapshot.

        The snapshot is refreshed by ``reset()`` and ``step()``; it is an
        empty dict before the first ``reset()``. A copy is returned so
        callers cannot add, remove or replace keys of the internal snapshot.
        """
        return dict(self._state_dict)

    def available_actions(self) -> list[str]:
        """Return the action types accepted now (none once the episode is done)."""
        if self.done:
            return []
        return ["read_logs", "analyze_error", "edit_config", "run_tests", "validate_fix", "submit_solution"]

    def render(self) -> str:
        """Return a human-readable summary of the task, error and current YAML.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_task is None or self.current_observation is None:
            raise RuntimeError("reset() must be called before render()")
        s = f"--- Task: {self.current_task['id']} ---\n"
        s += f"Error: {self.current_observation.error_message}\n"
        s += f"YAML:\n{self.current_observation.pipeline_yaml}\n"
        return s

    def _compute_blame(self, task) -> dict:
        """Return build/test/deploy blame scores for ``task``'s id.

        Unknown or missing ids get ``_DEFAULT_BLAME``. A fresh dict is
        returned so mutating the result never alters the shared class table.
        """
        return dict(self._BLAME_MAP.get(task.get("id", ""), self._DEFAULT_BLAME))

    def _update_state(self):
        """Refresh the snapshot returned by ``state()`` from the current episode.

        History entries are flattened to their action type and reward.
        """
        self._state_dict = {
            "pipeline_yaml": self.current_observation.pipeline_yaml,
            "error_message": self.current_observation.error_message,
            "logs": self.current_observation.logs,
            "step_blame_scores": self.current_observation.step_blame_scores,
            "episode_history": [{"action_type": h["action"].action_type, "reward": h["reward"]} for h in self.episode_history],
            "done": self.done,
            "step_count": self.step_count,
            "task_id": self.current_task["id"] if self.current_task else None
        }