File size: 7,205 Bytes
a4f74f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Reward functions for GRPO training (v2 β€” plan-based).

The model outputs a FULL TEST PLAN (JSON array of actions).
Each reward function creates a FRESH environment, executes ALL actions,
and scores the result.

Three reward signals:
1. format_reward    β€” Valid JSON array with 3+ diverse actions? (+2 / -2)
2. plan_reward      β€” Execute plan, score on bugs + coverage + efficiency (0 to ~8)
3. diversity_reward β€” Variety of methods, endpoints, and request patterns (+0 to +2)
"""

import re
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from models import APITestAction, HTTPMethod
from server.environment import APITestEnvironment
from .prompts import parse_test_plan


def format_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for valid JSON test plan format.

    Per-completion scores:
    +2.0 if output has 5+ diverse actions (a real plan)
    +1.0 if output has 3-4 actions, or 5+ with low diversity
    +0.0 if output has 1-2 actions (barely valid)
    -1.0 if all actions are identical (degenerate plan, n > 1)
    -2.0 if it can't be parsed at all
    """
    rewards = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(-2.0)
            continue

        n = len(actions)

        # Diversity check — count distinct (method, normalized-endpoint) pairs.
        # Numeric path segments are collapsed so /todos/1 and /todos/2 count once.
        unique_pairs = set()
        for a in actions:
            m = a.method.value if hasattr(a.method, "value") else str(a.method)
            ep = re.sub(r'/\d+', '/{id}', a.endpoint)
            unique_pairs.add((m, ep))

        # Degenerate plan: every action hits the same (method, endpoint).
        # Checked first so it overrides the size-based score below.
        if len(unique_pairs) <= 1 and n > 1:
            rewards.append(-1.0)
            continue

        diversity_ratio = len(unique_pairs) / n  # n >= 1 is guaranteed here

        if n >= 5 and diversity_ratio >= 0.5:
            rewards.append(2.0)
        elif n >= 3:
            rewards.append(1.0)
        else:
            # n is 1 or 2 — parsed, but not enough actions to be a plan.
            rewards.append(0.0)

    return rewards


def plan_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Execute the full test plan in a FRESH environment and return a balanced score.

    Score components:
    - Bug discovery:  min(bugs_found, 5) * 1.0  (capped at 5.0 to not dominate)
    - Coverage:       (coverage_pct / 100) * 2.0 (up to 2.0)
    - Efficiency:     +0.3 per step in the first 10 whose per-step reward
                      exceeds 0.2 (a high-reward step is treated as a likely
                      bug find)
    - Crash penalty:  -0.1 per action that caused a 500 error or raised
    - Step rewards:   sum(per-step rewards) * 0.2 (small gradient signal)

    Unparseable completions score a flat -1.0.

    Each completion gets its OWN fresh environment — no state pollution.
    """
    # Per-completion metadata (seed/task) supplied by the trainer, if any.
    prompts_meta = kwargs.get("prompts_meta", [])
    rewards = []

    for i, text in enumerate(completions):
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(-1.0)
            continue

        # Get episode seed and task
        # NOTE(review): modulo indexing assumes completions cycle over
        # prompts_meta (num_generations per prompt) — confirm this matches
        # the trainer's grouping.
        meta = prompts_meta[i % len(prompts_meta)] if prompts_meta else {}
        seed = meta.get("seed", 42)
        task_id = meta.get("task_id", "basic_validation")

        # Create a FRESH environment
        env = APITestEnvironment()
        env.reset(seed=seed, task_id=task_id)

        # Execute all actions, track results
        crashes = 0
        step_rewards = []
        for action in actions:
            try:
                obs = env.step(action)
                step_rewards.append(obs.reward or 0.0)
                if obs.status_code >= 500:
                    crashes += 1
            except Exception:
                # A raised exception counts as a crash; the step contributes 0.
                step_rewards.append(0.0)
                crashes += 1

        state = env.state
        coverage = state.coverage_pct

        # Component 1: Bug discovery (capped to prevent domination)
        bug_score = min(state.bugs_found, 5) * 1.0

        # Component 2: Coverage (proportional, up to 2.0)
        coverage_score = (coverage / 100) * 2.0

        # Component 3: Efficiency — finding bugs early is better
        early_bug_bonus = 0.0
        early_steps = step_rewards[:10]
        for r in early_steps:
            if r > 0.2:  # High reward step = likely found a bug
                early_bug_bonus += 0.3

        # Component 4: Crash penalty
        crash_penalty = crashes * -0.1

        # Component 5: Step reward sum (small weight — mainly for gradient signal)
        step_sum = sum(step_rewards) * 0.2

        total = bug_score + coverage_score + early_bug_bonus + crash_penalty + step_sum
        rewards.append(round(total, 4))

    return rewards


def diversity_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for diverse test plans — varied methods, endpoints, and strategies.

    Components:
    - Method variety:     up to +0.5 (using GET/POST/PUT/DELETE)
    - Endpoint variety:   up to +0.5 (testing different endpoints)
    - Strategy variety:   +0.1 per detected strategy, up to +0.5
                          (auth, invalid input, boundary values, injection,
                          nonexistent id)
    - Repetition penalty: up to -0.5, proportional to the fraction of
                          duplicated (method, endpoint) pairs

    Unparseable completions score 0.0 (format_reward_fn already punishes them).
    """
    rewards = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(0.0)
            continue

        methods = set()
        endpoints = set()
        unique_pairs = set()
        has_auth = False
        has_invalid_input = False
        has_boundary = False
        has_injection = False
        has_nonexistent_id = False

        for a in actions:
            # Normalize method to a plain string (enum or str both accepted).
            m = a.method.value if hasattr(a.method, "value") else str(a.method)
            methods.add(m)
            # Collapse numeric path segments so /todos/1 and /todos/2 count once.
            norm_ep = re.sub(r'/\d+', '/{id}', a.endpoint)
            endpoints.add(norm_ep)
            unique_pairs.add((m, norm_ep))

            # Detect testing strategies
            if a.endpoint == "/auth/login":
                has_auth = True
            # BUGFIX: compare against the normalized string `m` instead of
            # `a.method.value`, which raised AttributeError when method was a
            # plain string (the hasattr guard above exists for that case).
            if a.body and not a.body.get("title") and m == "POST":
                has_invalid_input = True
            qp = a.query_params or {}
            # Boundary probing: any numeric query param out of sane range.
            if any(isinstance(v, (int, float)) and (v < 0 or v > 10000)
                   for v in qp.values()):
                has_boundary = True
            if a.body and any("DROP" in str(v).upper() or "script" in str(v).lower()
                              for v in a.body.values()):
                has_injection = True
            # A long numeric id (4+ digits) suggests probing nonexistent records.
            if re.search(r'/\d{4,}', a.endpoint):
                has_nonexistent_id = True

        # Method variety (max 4 methods = +0.5)
        method_score = min(len(methods) / 4, 1.0) * 0.5

        # Endpoint variety (max 7 endpoints = +0.5)
        endpoint_score = min(len(endpoints) / 7, 1.0) * 0.5

        # Strategy variety (each strategy = +0.1, max +0.5)
        strategies = sum([has_auth, has_invalid_input, has_boundary, has_injection, has_nonexistent_id])
        strategy_score = min(strategies * 0.1, 0.5)

        # Repetition penalty — `actions` is non-empty here, and the ratio is
        # always < 1 (unique_pairs >= 1), so no extra guards are needed.
        repeat_count = len(actions) - len(unique_pairs)
        repetition_penalty = (repeat_count / len(actions)) * -0.5

        total = method_score + endpoint_score + strategy_score + repetition_penalty
        rewards.append(round(total, 3))

    return rewards