| """Unit tests for RecallTrace.""" |
|
|
| from __future__ import annotations |
|
|
| import unittest |
|
|
| from env.env import RecallTraceEnv |
| from grader.grader import evaluate_action_plan |
|
|
|
|
class RecallTraceEnvTests(unittest.TestCase):
    """Scenario tests for the RecallTrace environment and its plan grader."""

    @staticmethod
    def _inspect(node_id: str) -> dict[str, object]:
        """Build an ``inspect_node`` action targeting *node_id*."""
        return {"type": "inspect_node", "node_id": node_id}

    @staticmethod
    def _quarantine(node_id: str, lot_id: str, quantity: int) -> dict[str, object]:
        """Build a ``quarantine`` action for *quantity* units of *lot_id* at *node_id*."""
        return {
            "type": "quarantine",
            "node_id": node_id,
            "lot_id": lot_id,
            "quantity": quantity,
        }

    def test_phase1_plan_scores_high(self) -> None:
        """The phase-1 plan must grade at 0.95 or above and be flagged successful."""
        # (node, quantity of LotA quarantined there), in plan order.
        stock = (("warehouse", 100), ("store1", 50), ("store2", 20))

        plan = [{"type": "trace_lot", "lot_id": "LotA"}]
        plan.extend(self._inspect(node) for node, _ in stock)
        plan.extend(self._quarantine(node, "LotA", units) for node, units in stock)
        plan.append({"type": "notify", "node_id": "all"})
        plan.append({"type": "finalize"})

        result = evaluate_action_plan("phase1_direct_recall", plan)

        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.success)

    def test_phase2_trace_reveals_relabels(self) -> None:
        """Tracing LotA in phase 2 must report the relabelled lots and reach store3."""
        environment = RecallTraceEnv(task_id="phase2_relabel_recall")
        environment.reset()

        trace_action = {"type": "trace_lot", "lot_id": "LotA"}
        obs, step_reward, finished, details = environment.step(trace_action)

        self.assertFalse(finished)
        self.assertGreater(step_reward, 0)
        expected_lots = ["LotA", "LotA_R1", "LotA_R2"]
        self.assertEqual(details["matched_lots"], expected_lots)
        affected = obs.trace_results["LotA"]["affected_nodes"]
        self.assertIn("store3", affected)

    def test_phase3_mixed_inventory_requires_exact_quarantine(self) -> None:
        """Quarantining 15 units of LotBlend at crossdock must be penalised.

        The step info is expected to report 12 as the target contaminated quantity.
        """
        environment = RecallTraceEnv(task_id="phase3_mixed_shipments")
        environment.reset()

        # Trace and inspect first; only the final quarantine step is asserted on.
        for setup_action in (
            {"type": "trace_lot", "lot_id": "LotA"},
            self._inspect("crossdock"),
        ):
            environment.step(setup_action)

        over_quarantine = self._quarantine("crossdock", "LotBlend", 15)
        _obs, step_reward, _finished, details = environment.step(over_quarantine)

        self.assertLess(step_reward, 0)
        self.assertEqual(details["target_contaminated_quantity"], 12)

    def test_phase3_full_plan_scores_high(self) -> None:
        """The phase-3 plan must grade at 0.95 or above with all affected stock quarantined."""
        # (node, lot, quantity) for each quarantine action; inspections follow the same node order.
        holdings = (
            ("warehouse", "LotA", 30),
            ("crossdock", "LotBlend", 12),
            ("store1", "LotA", 10),
            ("store2", "LotBlend", 8),
            ("store3", "LotBlend", 4),
        )

        plan = [{"type": "trace_lot", "lot_id": "LotA"}]
        plan.extend(self._inspect(node) for node, _, _ in holdings)
        plan.extend(self._quarantine(node, lot, units) for node, lot, units in holdings)
        plan.append({"type": "notify", "node_id": "all"})
        plan.append({"type": "finalize"})

        result = evaluate_action_plan("phase3_mixed_shipments", plan)

        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.final_info["all_affected_stock_quarantined"])
|
|
|
|
# Run the test cases above when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|