"""
Reward functions for GRPO training (v2 — plan-based).

The model outputs a FULL TEST PLAN (JSON array of actions).
Each reward function creates a FRESH environment, executes ALL actions,
and scores the result.

Three reward signals:
1. format_reward — Valid JSON array with 3+ diverse actions? (+2 / -2)
2. plan_reward — Execute plan, score on bugs + coverage + efficiency (0 to ~8)
3. diversity_reward — Variety of methods, endpoints, and request patterns (+0 to +2)
"""
import re
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from models import APITestAction, HTTPMethod
from server.environment import APITestEnvironment
from .prompts import parse_test_plan
def format_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for valid JSON test plan format.

    Scoring per completion:
      +2.0 -- 5+ actions with at least 50% unique (method, endpoint) pairs
      +1.0 -- 3-4 actions, or 5+ actions that are mostly repetitive
      +0.0 -- 1-2 actions (barely a plan)
      -1.0 -- override: more than one action but all effectively identical
      -2.0 -- output cannot be parsed into any actions

    Args:
        completions: Raw model outputs, each expected to contain a JSON
            array of test actions (parsed by ``parse_test_plan``).

    Returns:
        One float reward per completion, in the same order.
    """
    rewards: list[float] = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(-2.0)
            continue

        n = len(actions)
        # Diversity check: normalize numeric path segments (e.g. /tasks/17
        # -> /tasks/{id}) so hitting the same endpoint with different ids
        # does not masquerade as variety.
        unique_pairs = set()
        for action in actions:
            method = action.method.value if hasattr(action.method, "value") else str(action.method)
            endpoint = re.sub(r'/\d+', '/{id}', action.endpoint)
            unique_pairs.add((method, endpoint))
        # n >= 1 is guaranteed by the guard above, so divide directly
        # (the original's max(n, 1) and its unreachable n == 0 branch
        # were dead code).
        diversity_ratio = len(unique_pairs) / n

        if n >= 5 and diversity_ratio >= 0.5:
            rewards.append(2.0)
        elif n >= 3:
            rewards.append(1.0)
        else:
            rewards.append(0.0)

        # Override: a plan that merely repeats one action is worse than a
        # short-but-varied plan.
        if len(unique_pairs) <= 1 and n > 1:
            rewards[-1] = -1.0
    return rewards
def plan_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Execute the full test plan in a FRESH environment and return a balanced score.

    Score components:
      - Bug discovery:  min(bugs_found, 5) * 1.0 (capped at 5.0 so it cannot dominate)
      - Coverage:       (coverage_pct / 100) * 2.0 (up to 2.0)
      - Efficiency:     +0.3 for each of the first 10 steps whose step reward
                        exceeds 0.2 (a high-reward step likely found a bug)
      - Crash penalty:  -0.1 per action that returned a 5xx or raised
      - Step sum:       sum(step rewards) * 0.2 (small gradient signal)

    Total range: roughly -2 to +8. Unparseable completions score -1.0.

    Each completion gets its OWN fresh environment -- no state pollution
    between GRPO group members.

    Args:
        completions: Raw model outputs containing JSON test plans.
        **kwargs: May carry ``prompts_meta``, a list of dicts with
            per-episode ``seed`` and ``task_id`` for the environment reset.

    Returns:
        One float reward per completion, rounded to 4 decimal places.
    """
    prompts_meta = kwargs.get("prompts_meta", [])
    rewards: list[float] = []
    for i, text in enumerate(completions):
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(-1.0)
            continue

        # Pair the completion with its episode metadata; fall back to
        # defaults when metadata is absent (modulo guards against a
        # prompts_meta list shorter than completions).
        meta = prompts_meta[i % len(prompts_meta)] if prompts_meta else {}
        seed = meta.get("seed", 42)
        task_id = meta.get("task_id", "basic_validation")

        # FRESH environment per completion so earlier plans cannot leak
        # discovered bugs or coverage into later scores.
        env = APITestEnvironment()
        env.reset(seed=seed, task_id=task_id)

        crashes = 0
        step_rewards: list[float] = []
        for action in actions:
            try:
                obs = env.step(action)
                step_rewards.append(obs.reward or 0.0)
                if obs.status_code >= 500:
                    crashes += 1
            except Exception:
                # A raising action counts as a crash but the plan keeps going.
                step_rewards.append(0.0)
                crashes += 1

        state = env.state
        # Component 1: bug discovery, capped to prevent domination.
        bug_score = min(state.bugs_found, 5) * 1.0
        # Component 2: proportional coverage, up to 2.0.
        coverage_score = (state.coverage_pct / 100) * 2.0
        # Component 3: efficiency -- finding bugs early is better. A step
        # reward above 0.2 is treated as a likely bug discovery.
        early_bug_bonus = sum(0.3 for r in step_rewards[:10] if r > 0.2)
        # Component 4: crash penalty.
        crash_penalty = crashes * -0.1
        # Component 5: small-weight step reward sum, mainly gradient signal.
        step_sum = sum(step_rewards) * 0.2

        total = bug_score + coverage_score + early_bug_bonus + crash_penalty + step_sum
        rewards.append(round(total, 4))
    return rewards
def diversity_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for diverse test plans -- varied methods, endpoints, and strategies.

    Components:
      - Method variety:   up to +0.5 (using GET/POST/PUT/DELETE)
      - Endpoint variety: up to +0.5 (testing different endpoints)
      - Strategy variety: up to +0.5 (+0.1 each: auth, invalid input,
        boundary values, injection payloads, nonexistent ids)
      - Repetition penalty: up to -0.5

    Args:
        completions: Raw model outputs containing JSON test plans.

    Returns:
        One float reward per completion (rounded to 3 decimal places);
        unparseable completions score 0.0.
    """
    rewards: list[float] = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(0.0)
            continue

        methods = set()
        endpoints = set()
        unique_pairs = set()
        has_auth = False
        has_invalid_input = False
        has_boundary = False
        has_injection = False
        has_nonexistent_id = False

        for a in actions:
            # Normalize the method once; safe whether `method` is an enum
            # member or a plain string.
            m = a.method.value if hasattr(a.method, "value") else str(a.method)
            methods.add(m)
            norm_ep = re.sub(r'/\d+', '/{id}', a.endpoint)
            endpoints.add(norm_ep)
            unique_pairs.add((m, norm_ep))

            # --- strategy detection ---
            if a.endpoint == "/auth/login":
                has_auth = True
            # BUGFIX: compare against the normalized `m` -- the original
            # `a.method.value` raised AttributeError when method was a
            # plain string (the hasattr guard above exists for that case).
            if a.body and not a.body.get("title") and m == "POST":
                has_invalid_input = True
            qp = a.query_params or {}
            for v in qp.values():
                # Negative or absurdly large numeric params = boundary probing.
                if isinstance(v, (int, float)) and (v < 0 or v > 10000):
                    has_boundary = True
            if a.body and any(
                "DROP" in str(v).upper() or "script" in str(v).lower()
                for v in a.body.values()
            ):
                has_injection = True
            if re.search(r'/\d{4,}', a.endpoint):
                has_nonexistent_id = True

        # Method variety (4 distinct methods = full +0.5).
        method_score = min(len(methods) / 4, 1.0) * 0.5
        # Endpoint variety (7 distinct endpoints = full +0.5).
        endpoint_score = min(len(endpoints) / 7, 1.0) * 0.5
        # Strategy variety (+0.1 per strategy, capped at +0.5).
        strategies = sum([has_auth, has_invalid_input, has_boundary,
                          has_injection, has_nonexistent_id])
        strategy_score = min(strategies * 0.1, 0.5)
        # Repetition penalty; `actions` is non-empty here, so the division
        # is always safe (the original's len(actions) > 0 check was dead).
        repeat_count = len(actions) - len(unique_pairs)
        repetition_penalty = min(repeat_count / len(actions), 1.0) * -0.5

        total = method_score + endpoint_score + strategy_score + repetition_penalty
        rewards.append(round(total, 3))
    return rewards