File size: 1,595 Bytes
30cf758
 
 
 
 
 
 
 
 
 
9b71d1b
30cf758
 
9b71d1b
30cf758
 
 
9b71d1b
30cf758
 
 
 
 
9b71d1b
30cf758
 
9b71d1b
30cf758
 
 
9b71d1b
30cf758
 
 
 
 
 
9b71d1b
30cf758
 
9b71d1b
30cf758
 
 
9b71d1b
 
30cf758
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import unittest

from server.reward import compute_reward


class TestReward(unittest.TestCase):
    """Checks the ``value`` of the reward object returned by ``compute_reward``.

    Each case calls ``compute_reward`` with keyword arguments only and compares
    ``reward.value`` against an expected number to four decimal places.
    """

    def test_submit_query_perfect_reward(self):
        # A successful submit_query graded at 0.999 is expected to yield 0.999.
        call_args = {
            "action_type": "submit_query",
            "query_result": {"success": True},
            "grade_score": 0.999,
            "steps_taken": 1,
            "max_steps": 10,
            "previous_best_score": 0.001,
            "schema_tables": ["t1", "t2"],
            "submitted_query": "SELECT * FROM t1 JOIN t2",
        }
        outcome = compute_reward(**call_args)
        expected_value = 0.999
        self.assertAlmostEqual(outcome.value, expected_value, places=4)

    def test_reset_query_penalty(self):
        # A reset_query with no query result and a 0.001 grade is expected to
        # produce a reward value of 0.001.
        call_args = {
            "action_type": "reset_query",
            "query_result": None,
            "grade_score": 0.001,
            "steps_taken": 1,
            "max_steps": 10,
            "previous_best_score": 0.001,
            "schema_tables": [],
            "submitted_query": None,
        }
        outcome = compute_reward(**call_args)
        expected_value = 0.001
        self.assertAlmostEqual(outcome.value, expected_value, places=4)

    def test_inspect_schema_urgency_penalty(self):
        # Inputs chosen so that steps_remaining <= 2 and grade_score < 0.5,
        # which is the condition intended to trigger the urgency penalty.
        call_args = {
            "action_type": "inspect_schema",
            "query_result": None,
            "grade_score": 0.001,
            "steps_taken": 8,
            "max_steps": 9,
            "previous_best_score": 0.001,
            "schema_tables": [],
            "submitted_query": None,
        }
        outcome = compute_reward(**call_args)
        # Expected arithmetic per the original test note: syntax_progress=0.01,
        # penalty=0.03 => total_raw=-0.02, clamped to the strict minimum.
        expected_value = 0.001
        self.assertAlmostEqual(outcome.value, expected_value, places=4)


# Allow running this file directly as a script: unittest.main() loads the
# TestCase classes defined in this module and runs them.
if __name__ == "__main__":
    unittest.main()