Spaces:
Running
Running
File size: 1,595 Bytes
import unittest
from server.reward import compute_reward
class TestReward(unittest.TestCase):
    """Checks of ``compute_reward`` at both ends of its expected output range.

    Every case calls ``compute_reward`` with keyword arguments only and
    compares ``value`` on the returned object against 0.999 or 0.001,
    to four decimal places.
    """

    def test_submit_query_perfect_reward(self):
        # A successful submit_query graded 0.999 on step 1 of 10 is expected
        # to yield a reward equal to that grade.
        call_kwargs = {
            "action_type": "submit_query",
            "query_result": {"success": True},
            "grade_score": 0.999,
            "steps_taken": 1,
            "max_steps": 10,
            "previous_best_score": 0.001,
            "schema_tables": ["t1", "t2"],
            "submitted_query": "SELECT * FROM t1 JOIN t2",
        }
        outcome = compute_reward(**call_kwargs)
        self.assertAlmostEqual(outcome.value, 0.999, places=4)

    def test_reset_query_penalty(self):
        # reset_query with no query result, no schema and no submitted query
        # is expected to land on the 0.001 floor.
        call_kwargs = {
            "action_type": "reset_query",
            "query_result": None,
            "grade_score": 0.001,
            "steps_taken": 1,
            "max_steps": 10,
            "previous_best_score": 0.001,
            "schema_tables": [],
            "submitted_query": None,
        }
        outcome = compute_reward(**call_kwargs)
        self.assertAlmostEqual(outcome.value, 0.001, places=4)

    def test_inspect_schema_urgency_penalty(self):
        # Original author's setup note: steps_remaining <= 2 together with
        # grade_score < 0.5 triggers the urgency penalty. Here 8 of 9 steps
        # are used and the grade is 0.001.
        call_kwargs = {
            "action_type": "inspect_schema",
            "query_result": None,
            "grade_score": 0.001,
            "steps_taken": 8,
            "max_steps": 9,
            "previous_best_score": 0.001,
            "schema_tables": [],
            "submitted_query": None,
        }
        outcome = compute_reward(**call_kwargs)
        # Original author's arithmetic: syntax_progress=0.01 minus
        # penalty=0.03 gives total_raw=-0.02, which is clamped to the
        # strict minimum.
        self.assertAlmostEqual(outcome.value, 0.001, places=4)
# Run the tests in this module when the file is executed directly as a
# script; importing the module does not start the runner.
if __name__ == "__main__":
    unittest.main()
|