| """Unit tests for RecallTrace.""" | |
| from __future__ import annotations | |
| import unittest | |
| from env.env import RecallTraceEnv | |
| from grader.grader import evaluate_action_plan | |
class RecallTraceEnvTests(unittest.TestCase):
    """Regression tests for the RecallTrace recall-simulation tasks.

    Two tests replay a complete action plan through ``evaluate_action_plan``
    and check the resulting grade. The other two drive ``RecallTraceEnv``
    one ``step`` at a time and inspect the per-step reward and info.
    """

    @staticmethod
    def _trace(lot_id: str) -> dict[str, object]:
        """Return a fresh ``trace_lot`` action for *lot_id*."""
        return {"type": "trace_lot", "lot_id": lot_id}

    @staticmethod
    def _inspect(*node_ids: str) -> list[dict[str, object]]:
        """Return one fresh ``inspect_node`` action per node id, in order."""
        return [{"type": "inspect_node", "node_id": node_id} for node_id in node_ids]

    @staticmethod
    def _quarantine(node_id: str, lot_id: str, quantity: int) -> dict[str, object]:
        """Return a fresh ``quarantine`` action for *quantity* units of *lot_id*."""
        return {
            "type": "quarantine",
            "node_id": node_id,
            "lot_id": lot_id,
            "quantity": quantity,
        }

    @staticmethod
    def _wrap_up() -> list[dict[str, object]]:
        """Return the closing actions: notify every node, then finalize."""
        return [{"type": "notify", "node_id": "all"}, {"type": "finalize"}]

    def test_phase1_plan_scores_high(self) -> None:
        plan = [
            self._trace("LotA"),
            *self._inspect("warehouse", "store1", "store2"),
            self._quarantine("warehouse", "LotA", 100),
            self._quarantine("store1", "LotA", 50),
            self._quarantine("store2", "LotA", 20),
            *self._wrap_up(),
        ]
        result = evaluate_action_plan("phase1_direct_recall", plan)
        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.success)

    def test_phase2_trace_reveals_relabels(self) -> None:
        recall_env = RecallTraceEnv(task_id="phase2_relabel_recall")
        recall_env.reset()
        obs, step_reward, finished, step_info = recall_env.step(self._trace("LotA"))
        self.assertFalse(finished)
        self.assertGreater(step_reward, 0)
        # Tracing the parent lot is expected to surface its relabelled children too.
        self.assertEqual(step_info["matched_lots"], ["LotA", "LotA_R1", "LotA_R2"])
        self.assertIn("store3", obs.trace_results["LotA"]["affected_nodes"])

    def test_phase3_mixed_inventory_requires_exact_quarantine(self) -> None:
        recall_env = RecallTraceEnv(task_id="phase3_mixed_shipments")
        recall_env.reset()
        # Setup steps only; their rewards and info are not under test here.
        for setup_action in (self._trace("LotA"), *self._inspect("crossdock")):
            recall_env.step(setup_action)
        # Request 15 units, more than the 12 the test expects as the target,
        # and expect the quantity mismatch to yield a negative reward.
        _, step_reward, _, step_info = recall_env.step(
            self._quarantine("crossdock", "LotBlend", 15)
        )
        self.assertLess(step_reward, 0)
        self.assertEqual(step_info["target_contaminated_quantity"], 12)

    def test_phase3_full_plan_scores_high(self) -> None:
        plan = [
            self._trace("LotA"),
            *self._inspect("warehouse", "crossdock", "store1", "store2", "store3"),
            self._quarantine("warehouse", "LotA", 30),
            self._quarantine("crossdock", "LotBlend", 12),
            self._quarantine("store1", "LotA", 10),
            self._quarantine("store2", "LotBlend", 8),
            self._quarantine("store3", "LotBlend", 4),
            *self._wrap_up(),
        ]
        result = evaluate_action_plan("phase3_mixed_shipments", plan)
        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.final_info["all_affected_stock_quarantined"])
if __name__ == "__main__":
    # Run this module's TestCases when executed directly as a script;
    # "__main__" is unittest.main's default module, spelled out here.
    unittest.main(module="__main__")