File size: 4,752 Bytes
bf2775e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b83c8ad
 
bf2775e
 
 
b83c8ad
 
bf2775e
 
 
b83c8ad
 
bf2775e
 
 
b83c8ad
 
bf2775e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b83c8ad
bf2775e
b83c8ad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

import pytest

from sql_query_reviewer.models import GroundTruthIssue, SQLReviewAction
from server.reward import compute_reward


def _action(action_type: str, confidence: float = 0.5) -> SQLReviewAction:
    """Build a minimal ``SQLReviewAction`` of the requested type for tests.

    ``identify_issue`` and ``suggest_fix`` actions receive fixed placeholder
    payload fields; any other action type is built from ``action_type`` and
    ``confidence`` alone.
    """
    # Placeholder payload fields required by specific action types.
    payload_by_type = {
        "identify_issue": {
            "issue_category": "syntax",
            "issue_description": "some issue",
        },
        "suggest_fix": {"suggested_fix": "SELECT 1;"},
    }
    extra_fields = payload_by_type.get(action_type, {})
    return SQLReviewAction(
        action_type=action_type,
        **extra_fields,
        confidence=confidence,
    )


def _issue(severity: float = 0.35) -> GroundTruthIssue:
    """Return a fixed "syntax" ground-truth issue with the given severity."""
    issue_fields = {
        "id": "test_issue_001",
        "category": "syntax",
        "description": "A test issue.",
        "severity": severity,
        "fix": "SELECT 1;",
        "keywords": ["test"],
    }
    return GroundTruthIssue(**issue_fields)


# ── identify_issue ────────────────────────────────────────────────────────────

def test_identify_issue_duplicate_returns_small_penalty() -> None:
    """Identifying an issue flagged as a duplicate yields -0.02."""
    action = _action("identify_issue")
    reward = compute_reward(action, _issue(), duplicate_issue=True)
    assert reward == pytest.approx(-0.02)


def test_identify_issue_no_match_returns_penalty() -> None:
    """Identifying an issue with no matched ground truth yields -0.1."""
    action = _action("identify_issue")
    reward = compute_reward(action, None)
    assert reward == pytest.approx(-0.1)


def test_identify_issue_match_no_fix_zero_confidence() -> None:
    """A matched issue with zero confidence and no fix yields 0.39."""
    # Expected breakdown:
    #   base_reward      = min(0.35, 0.35)    = 0.35
    #   fix_bonus        = 0
    #   confidence_bonus = 0
    #   order_bonus      = 0.04 * (1/(0+1))   = 0.04
    #   total                                 = 0.39
    action = _action("identify_issue", confidence=0.0)
    reward = compute_reward(action, _issue(0.35))
    assert reward == pytest.approx(0.39)


def test_identify_issue_match_no_fix_full_confidence() -> None:
    """A matched issue with full confidence and no fix yields 0.418."""
    # Expected breakdown:
    #   base             = 0.35
    #   confidence_bonus = min(0.05, 1.0 * 0.35 * 0.08) = 0.028
    #   order_bonus      = 0.04
    #   total            = 0.418
    action = _action("identify_issue", confidence=1.0)
    reward = compute_reward(action, _issue(0.35))
    assert reward == pytest.approx(0.418)


def test_identify_issue_match_with_fix_zero_confidence() -> None:
    """A matched issue with a valid fix is capped at 0.45."""
    # base 0.35 + fix_bonus 0.08 + order_bonus 0.04 = 0.47, capped at 0.45.
    action = _action("identify_issue", confidence=0.0)
    reward = compute_reward(action, _issue(0.35), fix_valid=True)
    assert reward == pytest.approx(0.45)


def test_identify_issue_high_severity_capped_at_035_base() -> None:
    """Severity above 0.35 still contributes only 0.35 to the reward."""
    # min(0.9, 0.35) = 0.35, plus order_bonus 0.04, gives 0.39.
    action = _action("identify_issue", confidence=0.0)
    high_severity_issue = _issue(severity=0.9)
    reward = compute_reward(action, high_severity_issue)
    assert reward == pytest.approx(0.39)


# ── suggest_fix ───────────────────────────────────────────────────────────────

def test_suggest_fix_without_previous_issue_is_penalized() -> None:
    """Suggesting a fix with no previously identified issue yields -0.05."""
    action = _action("suggest_fix")
    reward = compute_reward(action, None, has_previous_issue=False)
    assert reward == pytest.approx(-0.05)


def test_suggest_fix_with_previous_issue_invalid_fix() -> None:
    """An invalid fix after a previous issue yields a reward of 0.0."""
    action = _action("suggest_fix")
    reward = compute_reward(
        action,
        _issue(),
        has_previous_issue=True,
        fix_valid=False,
    )
    assert reward == pytest.approx(0.0)


def test_suggest_fix_with_previous_issue_valid_fix() -> None:
    """A valid fix after a previous issue yields a reward of 0.1."""
    action = _action("suggest_fix")
    reward = compute_reward(
        action,
        _issue(),
        has_previous_issue=True,
        fix_valid=True,
    )
    assert reward == pytest.approx(0.1)


# ── approve ───────────────────────────────────────────────────────────────────

def test_approve_all_issues_found_gives_positive_reward() -> None:
    """Approving with no issues left unfound yields 0.2."""
    action = _action("approve")
    reward = compute_reward(action, None, remaining_unfound=0)
    assert reward == pytest.approx(0.2)


def test_approve_one_issue_missed_gives_penalty() -> None:
    """Approving with one issue left unfound yields -0.15."""
    action = _action("approve")
    reward = compute_reward(action, None, remaining_unfound=1)
    assert reward == pytest.approx(-0.15)


def test_approve_many_issues_missed_floors_at_negative_one() -> None:
    """Approving with seven issues left unfound is floored at -1.0."""
    # -0.15 * 7 = -1.05, which is floored at -1.0.
    action = _action("approve")
    reward = compute_reward(action, None, remaining_unfound=7)
    assert reward == pytest.approx(-1.0)


# ── request_more_context ──────────────────────────────────────────────────────

def test_request_more_context_returns_zero() -> None:
    """Requesting context without passing schema_available yields 0.0."""
    # schema_available is not passed here -> expected reward is 0.0.
    action = _action("request_more_context")
    reward = compute_reward(action, None)
    assert reward == pytest.approx(0.0)


def test_request_more_context_with_schema_returns_penalty() -> None:
    """Requesting context with schema_available=True yields -0.03."""
    # schema_available=True -> expected reward is -0.03.
    action = _action("request_more_context")
    reward = compute_reward(action, None, schema_available=True)
    assert reward == pytest.approx(-0.03)