File size: 7,205 Bytes
5cf6185 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | """
Violation detection and graders for the API Contract Debugger environment.
detect_violations(current, golden) -> list of violation dicts
grade_episode(current, golden, initial_violations) -> float in [0.0, 1.0]
"""
from __future__ import annotations
import copy
from typing import Any, Dict, List
# ---------------------------------------------------------------------------
# Violation detection
# ---------------------------------------------------------------------------
def detect_violations(
    current_endpoints: List[Dict[str, Any]],
    golden_endpoints: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Compare the current spec against the golden spec and return all violations.

    Endpoints are paired positionally with ``zip``; if the two lists differ in
    length, the unpaired trailing endpoints are not inspected.

    A ``request_body`` / ``response_body`` that is absent or ``None`` is
    treated as an empty mapping, and a field spec that is not a dict is
    treated as having no declared type (``None``).

    Violation dict keys:
        endpoint_index  int      - index into the endpoint list
        location        str      - "request_body" | "response_body" | "status_code"
        field_name      str|None - None for status-code violations
        violation_type  str      - "missing_field" | "extra_field" |
                                   "wrong_type" | "wrong_status"
        description     str      - human-readable explanation
        severity        float    - weight used in scoring (0.0-1.0)

    Per endpoint, violations are emitted in this order: status code, then for
    each body location the golden fields (missing / wrong type) followed by
    the extra fields.
    """
    violations: List[Dict[str, Any]] = []

    for idx, (cur, gold) in enumerate(zip(current_endpoints, golden_endpoints)):
        # --- Status code ---
        if cur.get("status_code") != gold.get("status_code"):
            violations.append({
                "endpoint_index": idx,
                "location": "status_code",
                "field_name": None,
                "violation_type": "wrong_status",
                "description": (
                    f"{gold['method']} {gold['path']}: "
                    f"status_code is {cur.get('status_code')} "
                    f"but should be {gold.get('status_code')}"
                ),
                "severity": 0.8,
            })

        # --- Request body and response body ---
        for location in ("request_body", "response_body"):
            # ``dict.get(key, {})`` only defaults when the key is absent; an
            # explicit ``None`` body must also be treated as "no fields".
            cur_body: Dict[str, Any] = cur.get(location) or {}
            gold_body: Dict[str, Any] = gold.get(location) or {}

            # Fields required by the golden contract: missing or wrong type.
            for field, spec in gold_body.items():
                gold_type = _declared_type(spec)
                if field not in cur_body:
                    violations.append({
                        "endpoint_index": idx,
                        "location": location,
                        "field_name": field,
                        "violation_type": "missing_field",
                        "description": (
                            f"{gold['method']} {gold['path']} {location}: "
                            f"required field '{field}' ({gold_type}) is missing"
                        ),
                        "severity": 1.0,
                    })
                    continue

                cur_type = _declared_type(cur_body[field])
                if cur_type != gold_type:
                    violations.append({
                        "endpoint_index": idx,
                        "location": location,
                        "field_name": field,
                        "violation_type": "wrong_type",
                        "description": (
                            f"{gold['method']} {gold['path']} {location}: "
                            f"field '{field}' has type '{cur_type}' "
                            f"but should be '{gold_type}'"
                        ),
                        "severity": 0.9,
                    })

            # Extra (forbidden) fields: present in current but not in golden.
            for field in cur_body:
                if field not in gold_body:
                    violations.append({
                        "endpoint_index": idx,
                        "location": location,
                        "field_name": field,
                        "violation_type": "extra_field",
                        "description": (
                            f"{gold['method']} {gold['path']} {location}: "
                            f"field '{field}' is present but should not be in the contract"
                        ),
                        "severity": 0.7,
                    })

    return violations


def _declared_type(field_spec: Any) -> Any:
    """Return the ``"type"`` entry of a field spec, or None if the spec is not a dict."""
    if isinstance(field_spec, dict):
        return field_spec.get("type")
    return None
# ---------------------------------------------------------------------------
# Grader
# ---------------------------------------------------------------------------
def grade_episode(
    current_endpoints: List[Dict[str, Any]],
    golden_endpoints: List[Dict[str, Any]],
    initial_violations: List[Dict[str, Any]],
) -> float:
    """
    Score the agent's performance at the END of an episode.

    Returns a float in [0.0, 1.0]:
        1.0          - all violations fixed, no new ones introduced
        0.0          - no improvement at all
        intermediate - partial credit weighted by severity

    Formula:
        score = (weighted_fixed - weighted_introduced) / total_initial_weight
        clamped to [0.0, 1.0]

    Violations are matched between the initial and final state by their
    ``_vkey`` identity (endpoint index, location, field name, type).

    When the initial violations carry no weight the ratio is undefined. The
    episode then scores 1.0 only if nothing with positive severity was
    introduced, and 0.0 otherwise, which is what the clamped formula gives
    for any net-negative change.
    """
    remaining = detect_violations(current_endpoints, golden_endpoints)
    remaining_keys = _violation_keys(remaining)
    initial_keys = _violation_keys(initial_violations)

    # Present at the start and gone now: fixed.
    fixed = [v for v in initial_violations if _vkey(v) not in remaining_keys]
    # Present now but not at the start: newly introduced.
    introduced = [v for v in remaining if _vkey(v) not in initial_keys]

    total_initial_weight = sum(v["severity"] for v in initial_violations)
    weighted_introduced = sum(v["severity"] for v in introduced)

    if total_initial_weight == 0:
        # The spec started clean; it only keeps a perfect score if the agent
        # did not break it.
        return 0.0 if weighted_introduced > 0 else 1.0

    weighted_fixed = sum(v["severity"] for v in fixed)
    raw = (weighted_fixed - weighted_introduced) / total_initial_weight
    return float(max(0.0, min(1.0, raw)))
def step_reward(
    prev_violations: List[Dict[str, Any]],
    new_violations: List[Dict[str, Any]],
    initial_violations: List[Dict[str, Any]],
    action_error: bool,
) -> float:
    """
    Dense per-step reward signal.

    A malformed action (``action_error`` truthy) yields a flat -0.05 and the
    violation lists are not inspected. Otherwise:
        +0.2  * severity for every violation in ``prev_violations`` whose key
              is no longer present in ``new_violations``
        -0.15 * severity for every violation in ``new_violations`` whose key
              was not present in ``prev_violations``

    ``initial_violations`` is accepted but not read by this function.
    The result is rounded to 4 decimal places.
    """
    if action_error:
        return -0.05

    before = _violation_keys(prev_violations)
    after = _violation_keys(new_violations)

    total = 0.0
    # Credit for violations that disappeared during this step.
    for violation in prev_violations:
        if _vkey(violation) not in after:
            total += 0.2 * violation["severity"]
    # Penalty for violations that appeared during this step.
    for violation in new_violations:
        if _vkey(violation) not in before:
            total -= 0.15 * violation["severity"]

    return round(total, 4)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _vkey(v: Dict[str, Any]) -> tuple:
return (
v["endpoint_index"],
v["location"],
v.get("field_name"),
v["violation_type"],
)
def _violation_keys(violations: List[Dict[str, Any]]) -> set:
    """Return the set of ``_vkey`` identities for the given violations."""
    return set(map(_vkey, violations))
|