File size: 2,334 Bytes
95cbc5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

from commitguard_env.models import CommitGuardAction
from commitguard_env.reward import compute_reward


def test_reward_true_positive_correct_cwe_and_exploit_match() -> None:
    """Expect a reward of 2.0 for a vulnerable verdict naming the ground-truth CWE.

    The exploit sketch contains every keyword configured for CWE-89
    (ignoring case), and no context requests are charged.
    """
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch="This is classic SQL injection: SELECT ... WHERE ... concat",
    )
    keywords_by_cwe = {"CWE-89": ["sql", "select", "where", "concat", "injection"]}

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords=keywords_by_cwe,
        context_requests=0,
    )

    assert reward == 2.0


def test_reward_true_positive_wrong_cwe_same_family() -> None:
    """Expect 1.75 when the verdict is vulnerable but names CWE-79 instead of CWE-89.

    CWE-79 and CWE-89 are both in the "injection" family, so the CWE part of
    the reward is the family bonus (0.5 * 0.5) rather than the full bonus.
    """
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-79",
        exploit_sketch="sql injection",
    )

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    )

    # Breakdown of the expected total:
    #   1.0  true positive
    # + 0.25 family match (0.5 * 0.5)
    # + 0.5  keyword score (1 of 1 keywords present)
    # = 1.75
    assert reward == 1.75


def test_reward_false_positive() -> None:
    """Expect -1.0 when the verdict says vulnerable but the sample is not.

    The ground truth carries no CWE, no target file and no keyword table.
    """
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=True,
        vuln_type="CWE-89",
        exploit_sketch="sql",
    )

    reward = compute_reward(
        action=verdict,
        is_vulnerable=False,
        cwe=None,
        target_file=None,
        cwe_keywords={},
        context_requests=0,
    )

    assert reward == -1.0


def test_reward_false_negative() -> None:
    """Expect -0.5 when the verdict says not vulnerable but the sample is (CWE-89)."""
    verdict = CommitGuardAction(
        action_type="verdict",
        is_vulnerable=False,
        vuln_type="NONE",
        exploit_sketch="",
    )

    reward = compute_reward(
        action=verdict,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    )

    assert reward == -0.5


def test_reward_malformed_action_penalty_no_crash() -> None:
    """Expect -0.5, and no exception, for an action that carries a parse error.

    The action is built with only ``action_type``, ``raw_action`` and
    ``parse_error`` set; no verdict fields are supplied.
    """
    malformed = CommitGuardAction(
        action_type="analyze",
        raw_action="<<<",
        parse_error="bad_xml",
    )

    reward = compute_reward(
        action=malformed,
        is_vulnerable=True,
        cwe="CWE-89",
        target_file="db.py",
        cwe_keywords={"CWE-89": ["sql"]},
        context_requests=0,
    )

    assert reward == -0.5