Spaces:
Sleeping
Sleeping
File size: 1,503 Bytes
999c3ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import unittest
from server.environment import SummarizationEnvironment
from models import SummarizationAction
class EnvironmentSmokeTest(unittest.TestCase):
    """Smoke tests that drive ``SummarizationEnvironment`` through episodes.

    Each test builds a fresh environment, resets it with a fixed seed, and
    steps it with placeholder responses, checking the ``step_type`` sequence
    and ``done`` flags the environment reports along the way.
    """

    def test_easy_episode_reaches_terminal_state_with_bounded_reward(self) -> None:
        """An "easy" episode goes summarize -> answer -> done, reward in [0, 1]."""
        env = SummarizationEnvironment()

        obs = env.reset(task_name="easy", seed=0)
        self.assertEqual(obs.step_type, "summarize")
        self.assertFalse(obs.done)
        self.assertIsNotNone(obs.category)
        self.assertIsNotNone(obs.source_type)

        obs = env.step(SummarizationAction(response="Compact factual summary."))
        self.assertEqual(obs.step_type, "answer")
        self.assertFalse(obs.done)

        # The smoke test only needs *some* response for the answer stage; it
        # echoes the question text held in env.state (or "" if it is unset).
        response = env.state.question or ""
        obs = env.step(SummarizationAction(response=response))
        self.assertTrue(obs.done)
        # Fail with a clear message instead of a TypeError from the range
        # comparisons below if the terminal step reports no reward.
        self.assertIsNotNone(obs.reward, "terminal step must report a reward")
        self.assertGreaterEqual(obs.reward, 0.0)
        self.assertLessEqual(obs.reward, 1.0)

    def test_hard_episode_uses_update_summary_stage(self) -> None:
        """A "hard" episode inserts an update_summary stage before answering."""
        env = SummarizationEnvironment()

        obs = env.reset(task_name="hard", seed=0)
        self.assertEqual(obs.step_type, "summarize")
        self.assertFalse(obs.done)
        self.assertEqual(obs.source_type, "scientific_paper")

        obs = env.step(SummarizationAction(response="Initial summary."))
        self.assertEqual(obs.step_type, "update_summary")
        self.assertFalse(obs.done)

        obs = env.step(SummarizationAction(response="Updated combined summary."))
        self.assertEqual(obs.step_type, "answer")
        # An answer is still pending, so the episode must not be over yet.
        self.assertFalse(obs.done)
if __name__ == "__main__":
    # Executed as a script: collect and run every TestCase in this module.
    unittest.main(module="__main__")
|