File size: 11,387 Bytes
4ec75cf
 
 
 
 
 
 
 
 
6f95f2a
 
4ec75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f95f2a
 
 
 
4ec75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f95f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ec75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# server/validation/validator.py
# 3-stage pre-action validation: Schema → Domain → Consistency.
# IMPORTANT: Validator should HELP agents, not trap them.
# - Auto-coerce types where possible (string "8.5" → float 8.5)
# - Only hard-reject truly unrecoverable actions (wrong domain)
# - Silently truncate oversized fields instead of rejecting
# - Rich hints so agent can self-correct on next step

from typing import Dict, Tuple
from functools import lru_cache
import json

# Allowed values for the enum-like string fields checked during domain
# validation (vuln_type, severity and risk_level respectively).
VALID_VULN_TYPES = {
    'sql_injection', 'xss', 'idor', 'hardcoded_secret', 'missing_auth',
    'jwt_misuse', 'path_traversal', 'ssrf', 'rate_limit_missing', 'xxe'
}
VALID_SEVERITIES = {'critical', 'high', 'medium', 'low'}
VALID_RISK_LEVELS = {'critical', 'high', 'medium', 'low'}

# Which actions belong to which domain. Keys are the values compared against
# session.task_type in validate_action.
DOMAIN_ACTIONS = {
    'security':   {'identify_vulnerability', 'propose_fix', 'revise_fix'},
    'dependency': {'flag_outdated', 'resolve_conflict', 'migrate_api', 'validate_tree'},
    'clinical':   {'detect_gap', 'rank_issues', 'order_steps'},
}

# Required fields and their types for each action.
# A tuple of types means a value of any listed type is accepted as-is
# (the schema entry is passed straight to isinstance() in _coerce).
ACTION_SCHEMAS = {
    'identify_vulnerability': {
        'vuln_type': str,
        'cvss_score': (int, float),
        'severity': str,
    },
    'propose_fix': {
        'fix_code': str,
        'explanation': str,
    },
    'revise_fix': {
        'fix_code': str,
        'addressed_feedback': str,
    },
    'flag_outdated': {
        'packages': dict,
        # deprecated_api and replacement are optional — see OPTIONAL_FIELDS
    },
    'resolve_conflict': {
        'packages': dict,
        'reasoning': str,
    },
    'migrate_api': {
        'completed_items': list,
        'code_changes': dict,
    },
    'validate_tree': {
        'completed_items': list,
    },
    'detect_gap': {
        'missing_steps': list,
        'risk_level': str,
    },
    'rank_issues': {
        'priority_order': list,
    },
    'order_steps': {
        'recovery_steps': list,
    },
}

# Fields that are optional (won't cause hard rejection if missing).
# NOTE(review): none of the names below appear in ACTION_SCHEMAS, and
# validate_action derives its required fields from the schema only, so this
# table currently records intent rather than changing what is required.
OPTIONAL_FIELDS = {
    'flag_outdated': {'deprecated_api', 'replacement'},
    'identify_vulnerability': {'affected_line'},
}


def _split_list_string(text: str) -> list:
    """Turn a string into a list: JSON array if it parses as one, else comma-split."""
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        parsed = None
    if isinstance(parsed, list):
        return parsed
    # Not a JSON array (plain "a, b", or JSON that decoded to a scalar/object):
    # split on commas and strip surrounding brackets, spaces and quotes.
    return [x.strip(' "\'') for x in text.strip('[]').split(',') if x.strip()]


def _coerce(action: Dict, schema: Dict) -> Dict:
    """Best-effort coercion of field values to their schema types, in place.

    Different LLMs emit numbers as strings, lists as comma-separated strings,
    dicts as JSON text, and so on. For every schema field that is present in
    ``action`` but has the wrong type, one conversion per accepted type is
    attempted; the first that succeeds replaces the value. Values that cannot
    be converted are left untouched so later validation can report them.

    Args:
        action: The action payload. Mutated in place.
        schema: Mapping of field name to an accepted type, or a tuple of
            accepted types (e.g. ``(int, float)``).

    Returns:
        The same ``action`` object that was passed in.
    """
    for field, expected_type in schema.items():
        if field not in action:
            continue
        val = action[field]
        # Already correct type
        if isinstance(val, expected_type):
            continue

        # A tuple such as (int, float) accepts any member. Try each in order
        # so "8" stays an int while "8.5" falls through to float instead of
        # being abandoned after int("8.5") fails.
        if isinstance(expected_type, tuple):
            candidates = expected_type
        else:
            candidates = (expected_type,)

        for target in candidates:
            try:
                if target is float or target is int:
                    coerced = target(val)
                elif target is str:
                    coerced = str(val)
                elif target is list and isinstance(val, str):
                    coerced = _split_list_string(val)
                elif target is dict and isinstance(val, str):
                    coerced = json.loads(val)
                    if not isinstance(coerced, dict):
                        # Valid JSON but not an object: do not store a
                        # list/scalar into a field that must be a dict.
                        continue
                else:
                    continue  # no conversion rule for this type/value pair
            except (TypeError, ValueError, OverflowError, RecursionError):
                continue  # leave as-is; domain check will catch real problems
            action[field] = coerced
            break
    return action


def validate_action(action: Dict, session) -> Tuple[bool, Dict]:
    """Validate an agent action for the session's task. Returns (is_valid, feedback).

    Philosophy: be lenient on format (coerce types), strict on cross-domain actions.
    An action in the wrong domain = hard reject.
    An action with slightly wrong types = coerce and pass through.

    Checks run in this order and stop at the first failure:
      1. ``action_type`` names a known action.
      2. The action belongs to the domain in ``session.task_type``.
      3. After type coercion, all required schema fields are present.
      4. Field values pass the domain rules (enums, ranges, non-empty).
      5. The action is consistent with ``session.history``.

    Args:
        action: The action payload. It is mutated in place: field types may
            be coerced and an oversized ``fix_code`` may be truncated.
        session: Object providing ``task_type``, ``step_count`` and
            ``history`` (``task_id`` is read if present).

    Returns:
        ``(True, {})`` when the action is valid, otherwise ``(False, obs)``
        where ``obs`` is the feedback observation built by ``_fb``.
    """
    # A non-dict payload has no action_type at all; treat it as unknown
    # instead of raising AttributeError on .get().
    atype = action.get('action_type', '') if isinstance(action, dict) else ''

    # ── Stage 1: Is this a known action type? ──
    # The isinstance guard matters: a list/dict action_type is unhashable and
    # would raise TypeError on the membership test instead of being rejected.
    if not isinstance(atype, str) or atype not in ACTION_SCHEMAS:
        return False, _fb(
            'invalid_action_type',
            f'Unknown action_type: {repr(atype)}',
            session,
            hint=f'Valid actions for {session.task_type}: {sorted(DOMAIN_ACTIONS.get(session.task_type, []))}',
        )

    # ── Cross-domain check FIRST (before coercion) ──
    domain_valid = DOMAIN_ACTIONS.get(session.task_type, set())
    if atype not in domain_valid:
        return False, _fb(
            'wrong_domain_action',
            f'{repr(atype)} is not valid for task_type={repr(session.task_type)}',
            session,
            hint=f'Valid actions: {sorted(domain_valid)}',
        )

    # ── Coerce types before schema check (be helpful to all models) ──
    schema = ACTION_SCHEMAS.get(atype, {})
    action = _coerce(action, schema)

    # ── Stage 2: Check required fields are present ──
    optional = OPTIONAL_FIELDS.get(atype, set())
    required_fields = [f for f in schema if f not in optional]
    missing = [f for f in required_fields if f not in action]
    if missing:
        return False, _fb(
            'missing_fields',
            f'Missing required fields: {missing}',
            session,
            hint=f'Required for {atype}: {required_fields}',
        )

    # ── Stage 3: Domain value validation ──
    errs = _domain_check(action, atype)
    if errs:
        return False, _fb(
            'domain_error',
            f'Invalid field values: {errs}',
            session,
            hint=_domain_hint(atype, errs),
        )

    # ── Stage 4: Consistency check ──
    cons = _consistency_check(action, atype, session)
    if cons:
        return False, _fb('consistency_error', cons['message'], session, hint=cons['hint'])

    return True, {}


@lru_cache(maxsize=1024)
def _cached_domain_errors(action_json: str, atype: str) -> list:
    """Pure domain-value rules for one JSON-encoded action; safe to cache.

    Args:
        action_json: The action serialized as a JSON object.
        atype: The action type selecting which rules apply. Types without
            rules yield an empty list.

    Returns:
        A list of error dicts (empty when the values are acceptable). The
        list is stored in the LRU cache and shared between calls, so callers
        must not mutate it.
    """
    action = json.loads(action_json)
    if not isinstance(action, dict):
        return [{'field': 'action', 'issue': 'must be a JSON object'}]
    errors = []

    if atype == 'identify_vulnerability':
        vt = action.get('vuln_type', '')
        # isinstance first: a list/dict value is unhashable and would raise
        # TypeError on the set lookup instead of being reported.
        if not isinstance(vt, str) or vt not in VALID_VULN_TYPES:
            errors.append({'field': 'vuln_type', 'value': vt, 'allowed': sorted(VALID_VULN_TYPES)})
        try:
            cvss = float(action.get('cvss_score', -1))
            # Written as a chained comparison so NaN is rejected as well.
            if not (0.0 <= cvss <= 10.0):
                errors.append({'field': 'cvss_score', 'value': cvss, 'allowed': '0.0 to 10.0'})
        except (TypeError, ValueError, OverflowError):
            # OverflowError: an integer too large to convert to float.
            errors.append({'field': 'cvss_score', 'value': action.get('cvss_score'), 'allowed': '0.0 to 10.0'})
        sev = action.get('severity', '')
        if not isinstance(sev, str) or sev not in VALID_SEVERITIES:
            errors.append({'field': 'severity', 'value': sev, 'allowed': sorted(VALID_SEVERITIES)})

    elif atype == 'detect_gap':
        rl = action.get('risk_level', '')
        if not isinstance(rl, str) or rl not in VALID_RISK_LEVELS:
            errors.append({'field': 'risk_level', 'value': rl, 'allowed': sorted(VALID_RISK_LEVELS)})

    elif atype == 'resolve_conflict':
        pkgs = action.get('packages', {})
        if not isinstance(pkgs, dict) or len(pkgs) == 0:
            errors.append({'field': 'packages', 'issue': 'must be a non-empty dict of {package: version}'})

    elif atype == 'migrate_api':
        items = action.get('completed_items', [])
        changes = action.get('code_changes', {})
        if not isinstance(items, list) or len(items) == 0:
            errors.append({'field': 'completed_items', 'issue': 'must be a non-empty list of break IDs'})
        if not isinstance(changes, dict):
            errors.append({'field': 'code_changes', 'issue': 'must be a dict of {break_id: fix_description}'})

    return errors


def _domain_check(action: Dict, atype: str) -> list:
    """Check values are within allowed ranges/enums. Returns list of error dicts.

    For ``propose_fix`` / ``revise_fix`` an over-long ``fix_code`` is
    truncated to 2000 characters in place rather than rejected. The value
    rules themselves live in ``_cached_domain_errors`` and are keyed by the
    action's JSON form.

    Args:
        action: The action payload. May be mutated (``fix_code`` truncation).
        atype: The action type.

    Returns:
        A fresh list of error dicts; empty when the action passes.
    """
    # Handle mutations first (cannot be purely cached)
    if atype in ('propose_fix', 'revise_fix'):
        fix = action.get('fix_code', '')
        if len(fix) > 2000:
            # Silently truncate instead of rejecting, so verbose agents are
            # not penalized.
            action['fix_code'] = fix[:2000]

    try:
        action_json = json.dumps(action, sort_keys=True)
    except (TypeError, ValueError, RecursionError):
        # Something in the payload is not JSON-serializable. Still validate
        # the real values by stringifying whatever json cannot encode;
        # checking a placeholder payload would reject valid actions with
        # errors about fields the agent actually filled in correctly.
        try:
            action_json = json.dumps(
                {str(key): value for key, value in action.items()},
                sort_keys=True,
                default=str,
            )
        except (TypeError, ValueError, RecursionError):
            return [{'field': 'action', 'issue': 'could not be serialized for validation'}]

    # Copy so callers cannot alter the list object held by the cache.
    return list(_cached_domain_errors(action_json, atype))


def _domain_hint(atype: str, errors: list) -> str:
    """Pick a single corrective hint for a list of domain errors.

    Only the highest-priority offending field gets a dedicated hint
    (vuln_type, then severity, risk_level, cvss_score). When none of those
    is involved, a generic hint naming every reported field is returned.
    ``atype`` is accepted for interface symmetry and is not consulted.
    """
    flagged = [err.get('field', '') for err in errors]

    # Ordered by priority: the first entry whose field was flagged wins.
    hints_by_priority = (
        ('vuln_type', "vuln_type must be one of: sql_injection, xss, idor, hardcoded_secret, missing_auth, jwt_misuse, path_traversal, ssrf, rate_limit_missing, xxe"),
        ('severity', "severity must be one of: critical, high, medium, low"),
        ('risk_level', "risk_level must be one of: critical, high, medium, low"),
        ('cvss_score', "cvss_score must be a float between 0.0 and 10.0"),
    )
    for field_name, hint_text in hints_by_priority:
        if field_name in flagged:
            return hint_text

    return f"Check field values for: {flagged}"


def _consistency_check(action: Dict, atype: str, session) -> dict:
    """Check that action makes sense given session history.

    Rules enforced:
      * ``revise_fix`` requires an earlier ``propose_fix``.
      * ``rank_issues`` and ``order_steps`` require an earlier ``detect_gap``.
      * ``resolve_conflict`` may not repeat a ``packages`` proposal that is
        already in the history (prevents an agent looping on one answer).

    Args:
        action: The action payload being validated.
        atype: The action type.
        session: Object whose ``history`` is an iterable of past action dicts.

    Returns:
        ``{}`` when consistent, otherwise a dict with ``message`` and ``hint``.
    """
    hist_types = [h.get('action_type') for h in session.history]

    # action type -> (action that must appear earlier, hint shown when it has not)
    prerequisites = {
        'revise_fix': (
            'propose_fix',
            'Call propose_fix first, then revise_fix if you get reviewer feedback',
        ),
        'rank_issues': (
            'detect_gap',
            'Call detect_gap first, then rank_issues',
        ),
        'order_steps': (
            'detect_gap',
            'Call detect_gap first, then rank_issues, then order_steps',
        ),
    }
    if atype in prerequisites:
        needed, hint = prerequisites[atype]
        if needed not in hist_types:
            return {
                'message': f'Cannot call {atype} before {needed}',
                'hint': hint,
            }

    # Reject identical resolve_conflict proposals (infinite loop prevention)
    if atype == 'resolve_conflict':
        proposed = action.get('packages', {})
        for prev in session.history:
            if prev.get('action_type') == 'resolve_conflict' and prev.get('packages') == proposed:
                return {
                    # The separator is a real em dash; the previous literal
                    # held mis-decoded bytes that showed up garbled to agents.
                    'message': 'Identical version proposal already submitted — this combination was rejected',
                    'hint': 'Try different package versions. Check the compatibility_matrix in the observation.'
                }

    return {}


def _fb(error_type: str, message: str, session, **kwargs) -> Dict:
    """Assemble the observation returned to the agent when validation fails.

    The observation carries the error classification and message, the
    session's current step count, task type and task id (empty string when
    the session has no ``task_id``), and the sorted actions valid for the
    task type. Extra keyword arguments (e.g. ``hint``) are merged in last and
    therefore override a core key of the same name.
    """
    core = dict(
        validation_failed=True,
        error_type=error_type,
        message=message,
        turn=session.step_count,
        task_type=session.task_type,
        task_id=getattr(session, 'task_id', ''),
        available_actions=sorted(DOMAIN_ACTIONS.get(session.task_type, [])),
    )
    return {**core, **kwargs}