"""
API Contract Debugger — OpenEnv Environment
An AI agent receives a broken OpenAPI-style spec and must fix all contract
violations by proposing targeted field-level corrections step-by-step.
"""
from __future__ import annotations
import copy
import uuid
from typing import Any, Dict, List, Optional
from openenv.core.env_server.interfaces import Environment
from .fixtures import TASKS
from .graders import detect_violations, grade_episode, step_reward
from .models import (
ActionKind,
DebugAction,
DebugObservation,
DebugState,
)
class APIContractDebuggerEnv(Environment[DebugAction, DebugObservation, DebugState]):
    """
    Environment where an agent debugs broken API contract specifications.

    Tasks (difficulty):
        easy   — 1 endpoint, 1 missing field
        medium — 3 endpoints, 3 violations (type errors + wrong status)
        hard   — 4 endpoints, 6 violations (missing fields, wrong types,
                 wrong status, forbidden extra field)

    Action space:
        DebugAction with kind in {add_field, remove_field, change_type,
        change_status, no_op}

    Observation space:
        DebugObservation — current endpoints + violation list + reward signals

    Reward:
        Dense per-step: +0.2×severity per violation fixed, -0.15×severity per
        violation introduced, -0.05 for malformed action (computed by
        ``step_reward`` from ``.graders``). A flat +0.5 bonus is added on the
        step that leaves no violations.

    Episode terminates when all violations are resolved or max_steps reached.
    ``reset()`` must be called before the first ``step()``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self, task_name: str = "easy") -> None:
        """Create an environment bound to ``task_name``.

        Raises:
            ValueError: if ``task_name`` is not a key of ``TASKS``.
        """
        super().__init__()
        if task_name not in TASKS:
            raise ValueError(
                f"Unknown task '{task_name}'. Choose from: {list(TASKS.keys())}"
            )
        self._task_name = task_name
        self._task_cfg = TASKS[task_name]
        # Internal state (populated on reset)
        self._current_endpoints: List[Dict[str, Any]] = []
        self._golden_endpoints: List[Dict[str, Any]] = []
        self._original_endpoints: List[Dict[str, Any]] = []
        self._violations: List[Dict[str, Any]] = []
        self._initial_violations: List[Dict[str, Any]] = []
        self._step_count: int = 0
        self._episode_id: Optional[str] = None
        # No episode is active until reset() runs. Starting in the "done"
        # state makes a premature step() return the explicit "Call reset()"
        # error instead of operating on the empty placeholder lists above
        # (which could otherwise be scored as "all violations fixed").
        self._done: bool = True

    # ------------------------------------------------------------------
    # OpenEnv API
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> DebugObservation:
        """Reset the environment and return the initial observation.

        Args:
            seed: Accepted for interface compatibility; not used here (the
                task fixtures are copied as-is).
            episode_id: Identifier for the new episode; a random UUID4 string
                is generated when omitted.
            task_name: Optionally switch to another task. Names that are not
                keys of ``TASKS`` are ignored and the current task is kept.
            **kwargs: Ignored.
        """
        if task_name and task_name in TASKS:
            self._task_name = task_name
            self._task_cfg = TASKS[task_name]
        self._episode_id = episode_id or str(uuid.uuid4())
        self._step_count = 0
        self._done = False
        # Deep-copy fixtures so mutations don't bleed across episodes
        self._current_endpoints = copy.deepcopy(self._task_cfg["broken_endpoints"])
        self._golden_endpoints = copy.deepcopy(self._task_cfg["golden_endpoints"])
        self._original_endpoints = copy.deepcopy(self._task_cfg["broken_endpoints"])
        self._violations = detect_violations(
            self._current_endpoints, self._golden_endpoints
        )
        self._initial_violations = copy.deepcopy(self._violations)
        return self._make_observation(
            reward=0.0,
            done=False,
            fixed_this_step=0,
            introduced_this_step=0,
            action_error=None,
        )

    def step(
        self,
        action: DebugAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DebugObservation:
        """Apply one fix action and return the updated observation.

        If the episode has ended (or ``reset()`` has not been called yet),
        the action is not applied. Instead, an observation with reward 0.0,
        ``done=True`` and an error message is returned. ``timeout_s`` and
        extra keyword arguments are accepted but not used.
        """
        if self._done:
            return self._make_observation(
                reward=0.0,
                done=True,
                fixed_this_step=0,
                introduced_this_step=0,
                action_error="Episode is already done. Call reset().",
            )
        self._step_count += 1
        prev_violations = copy.deepcopy(self._violations)
        action_error: Optional[str] = None

        # --- Apply the action ---
        if action.kind == ActionKind.NO_OP:
            pass  # agent explicitly passes — small implicit penalty via no reward
        else:
            action_error = self._apply_action(action)

        # --- Recompute violations ---
        self._violations = detect_violations(
            self._current_endpoints, self._golden_endpoints
        )

        # --- Compute reward ---
        reward = step_reward(
            prev_violations=prev_violations,
            new_violations=self._violations,
            initial_violations=self._initial_violations,
            action_error=(action_error is not None),
        )
        # Violations are compared by equality, so a violation counts as
        # "fixed" when no equal entry remains, and as "introduced" when it is new.
        fixed_this_step = sum(
            1 for v in prev_violations
            if v not in self._violations
        )
        introduced_this_step = sum(
            1 for v in self._violations
            if v not in prev_violations
        )

        # --- Termination ---
        max_steps = self._task_cfg["max_steps"]
        all_fixed = len(self._violations) == 0
        out_of_steps = self._step_count >= max_steps
        self._done = all_fixed or out_of_steps

        # Bonus reward for solving all violations
        if all_fixed:
            reward += 0.5

        return self._make_observation(
            reward=reward,
            done=self._done,
            fixed_this_step=fixed_this_step,
            introduced_this_step=introduced_this_step,
            action_error=action_error,
        )

    @property
    def state(self) -> DebugState:
        """Return a snapshot of the full internal environment state.

        Endpoint and violation lists are deep-copied, as in observations, so
        that callers cannot mutate the live episode through the result.
        """
        return DebugState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            task_name=self._task_name,
            original_endpoints=copy.deepcopy(self._original_endpoints),
            current_endpoints=copy.deepcopy(self._current_endpoints),
            golden_endpoints=copy.deepcopy(self._golden_endpoints),
            violations=copy.deepcopy(self._violations),
            total_violations_at_start=len(self._initial_violations),
            max_steps=self._task_cfg["max_steps"],
        )

    def get_metadata(self):
        """Return static name/description/version metadata for this env."""
        from openenv.core.env_server.types import EnvironmentMetadata

        return EnvironmentMetadata(
            name="APIContractDebugger",
            description=(
                "An environment where an AI agent debugs broken OpenAPI-style "
                "contract specifications by proposing targeted field-level fixes."
            ),
            version="1.0.0",
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _apply_action(self, action: DebugAction) -> Optional[str]:
        """
        Mutate ``self._current_endpoints`` according to ``action``.

        Returns an error string if the action is invalid, else None. Rejected
        actions leave the endpoints untouched.
        """
        idx = action.endpoint_index
        if idx < 0 or idx >= len(self._current_endpoints):
            return (
                f"endpoint_index {idx} is out of range "
                f"(0–{len(self._current_endpoints) - 1})"
            )
        endpoint = self._current_endpoints[idx]

        if action.kind == ActionKind.CHANGE_STATUS:
            status = action.new_value
            # bool is a subclass of int, so True/False must be rejected explicitly.
            if isinstance(status, bool) or not isinstance(status, int):
                return "CHANGE_STATUS requires new_value to be an integer HTTP status code"
            endpoint["status_code"] = status
            return None

        # For field-level actions, validate location
        if action.location not in ("request_body", "response_body"):
            return (
                f"location must be 'request_body' or 'response_body', "
                f"got '{action.location}'"
            )
        # Look the body up without creating it. A rejected action must not
        # insert an empty body into the endpoint as a side effect.
        body: Optional[Dict[str, Any]] = endpoint.get(action.location)
        field = action.field_name

        if action.kind == ActionKind.ADD_FIELD:
            if not field:
                return "ADD_FIELD requires a non-empty field_name"
            if not isinstance(action.new_value, dict) or "type" not in action.new_value:
                return "ADD_FIELD requires new_value to be a dict with a 'type' key"
            if body is None:
                # Missing (or explicitly None) body: start an empty one.
                body = endpoint[action.location] = {}
            # Copy so the stored schema is not aliased to the caller's action
            # object. Otherwise a later CHANGE_TYPE would mutate the caller's
            # dict, and caller-side edits would leak into episode state.
            body[field] = copy.deepcopy(action.new_value)
            return None

        if action.kind == ActionKind.REMOVE_FIELD:
            if not field:
                return "REMOVE_FIELD requires a non-empty field_name"
            if body is None or field not in body:
                return f"field '{field}' does not exist in {action.location}"
            del body[field]
            return None

        if action.kind == ActionKind.CHANGE_TYPE:
            if not field:
                return "CHANGE_TYPE requires a non-empty field_name"
            if body is None or field not in body:
                return f"field '{field}' does not exist in {action.location}"
            if not isinstance(action.new_value, str):
                return "CHANGE_TYPE requires new_value to be a type string"
            body[field]["type"] = action.new_value
            return None

        return f"Unknown action kind: {action.kind}"

    def _make_observation(
        self,
        reward: float,
        done: bool,
        fixed_this_step: int,
        introduced_this_step: int,
        action_error: Optional[str],
    ) -> DebugObservation:
        """Build an observation from the current state.

        Endpoints and violations are deep-copied, so the observation is an
        independent snapshot.
        """
        return DebugObservation(
            task_name=self._task_name,
            task_description=self._task_cfg["description"],
            endpoints=copy.deepcopy(self._current_endpoints),
            violations=copy.deepcopy(self._violations),
            violations_fixed_this_step=fixed_this_step,
            violations_introduced_this_step=introduced_this_step,
            total_violations_at_start=len(self._initial_violations),
            step_count=self._step_count,
            max_steps=self._task_cfg["max_steps"],
            last_action_error=action_error,
            reward=reward,
            done=done,
        )

    def score(self) -> float:
        """Final episode score in [0.0, 1.0]. Call after episode ends.

        Delegates to ``grade_episode`` with the current and golden endpoints
        and the violations detected at reset.
        """
        return grade_episode(
            self._current_endpoints,
            self._golden_endpoints,
            self._initial_violations,
        )