File size: 2,208 Bytes
e4f3d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import annotations

from commitguard_env.models import CommitGuardAction
from commitguard_env.reward import compute_reward


def test_reward_true_positive_correct_cwe_and_exploit_match() -> None:
    """A correct verdict, the right CWE and a keyword-rich sketch score 2.0.

    Assumed breakdown (confirm against compute_reward): +1.0 verdict,
    +0.5 CWE match, +0.5 exploit-keyword match.
    """
    # The sketch contains every CWE-89 keyword listed below (ignoring case).
    sketch = "This is classic SQL injection: SELECT ... WHERE ... concat"
    action = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch=sketch,
    )
    keywords = {"CWE-89": ["sql", "select", "where", "concat", "injection"]}
    reward = compute_reward(
        action=action,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords,
        context_requests=0,
    )
    assert reward == 2.0


def test_reward_true_positive_wrong_cwe() -> None:
    """A correct 'vulnerable' verdict with the wrong CWE scores 1.5."""
    action = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-79",  # ground truth passed below is CWE-89
        exploit_sketch="sql injection",
    )
    reward = compute_reward(
        action=action,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    )
    # +1.0 verdict, +0.5 exploit match, no CWE bonus
    assert reward == 1.5


def test_reward_false_positive() -> None:
    """Flagging a commit whose ground truth is 'not vulnerable' scores -1.0."""
    flagged = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch="sql",
    )
    # Ground truth for a clean commit: no CWE and no target file.
    ground_truth = dict(is_vulnerable=False, cwe=None, target_file=None)
    reward = compute_reward(
        action=flagged,
        cwe_keywords={},
        context_requests=0,
        **ground_truth,
    )
    assert reward == -1.0


def test_reward_false_negative() -> None:
    """Answering 'not vulnerable' on a vulnerable commit scores -0.5."""
    missed = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=False,
        vuln_type="NONE",
        exploit_sketch="",
    )
    expected = -0.5
    assert compute_reward(
        action=missed,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    ) == expected


def test_reward_malformed_action_penalty_no_crash() -> None:
    """An action carrying a parse error is scored -0.5 rather than raising."""
    malformed = CommitGuardAction(
        action_type="analyze",
        raw_action="<<<",
        parse_error="bad_xml",
    )
    # Part of the check is simply that this call returns instead of raising.
    reward = compute_reward(
        action=malformed,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    )
    assert reward == -0.5