# metaa/tests/test_env.py
# Majen's picture
# Full project upload
# aae7b06 verified
"""Unit tests for RecallTrace."""
from __future__ import annotations
import unittest
from env.env import RecallTraceEnv
from grader.grader import evaluate_action_plan
class RecallTraceEnvTests(unittest.TestCase):
    """Scenario checks for the RecallTrace environment and its plan grader.

    Two tests hand a complete action plan to ``evaluate_action_plan`` and
    assert on the returned grade; the other two drive ``RecallTraceEnv`` one
    step at a time and assert on the reward and info payload of a single step.
    """

    def test_phase1_plan_scores_high(self) -> None:
        """A complete phase-1 plan is expected to score >= 0.95 and succeed."""
        # (node, quantity) pairs; each node is inspected first and then has
        # that quantity of LotA quarantined.
        quarantine_targets = (("warehouse", 100), ("store1", 50), ("store2", 20))

        plan = [{"type": "trace_lot", "lot_id": "LotA"}]
        plan.extend(
            {"type": "inspect_node", "node_id": node} for node, _ in quarantine_targets
        )
        plan.extend(
            {"type": "quarantine", "node_id": node, "lot_id": "LotA", "quantity": units}
            for node, units in quarantine_targets
        )
        plan.append({"type": "notify", "node_id": "all"})
        plan.append({"type": "finalize"})

        grade = evaluate_action_plan("phase1_direct_recall", plan)

        self.assertGreaterEqual(grade.score, 0.95)
        self.assertTrue(grade.success)

    def test_phase2_trace_reveals_relabels(self) -> None:
        """Tracing LotA in phase 2 is expected to report its relabelled lots."""
        env = RecallTraceEnv(task_id="phase2_relabel_recall")
        env.reset()

        trace_action = {"type": "trace_lot", "lot_id": "LotA"}
        observation, reward, done, info = env.step(trace_action)

        # A single trace must not end the episode and must earn positive reward.
        self.assertFalse(done)
        self.assertGreater(reward, 0)

        expected_lots = ["LotA", "LotA_R1", "LotA_R2"]
        self.assertEqual(info["matched_lots"], expected_lots)

        affected_nodes = observation.trace_results["LotA"]["affected_nodes"]
        self.assertIn("store3", affected_nodes)

    def test_phase3_mixed_inventory_requires_exact_quarantine(self) -> None:
        """Quarantining 15 units of LotBlend at crossdock is expected to be penalised."""
        env = RecallTraceEnv(task_id="phase3_mixed_shipments")
        env.reset()

        # Trace the lot and inspect the node before attempting the quarantine.
        setup_actions = (
            {"type": "trace_lot", "lot_id": "LotA"},
            {"type": "inspect_node", "node_id": "crossdock"},
        )
        for action in setup_actions:
            env.step(action)

        # Requests 15 units, while the info payload is expected to report a
        # target contaminated quantity of 12.
        quarantine_action = {
            "type": "quarantine",
            "node_id": "crossdock",
            "lot_id": "LotBlend",
            "quantity": 15,
        }
        _observation, reward, _done, info = env.step(quarantine_action)

        self.assertLess(reward, 0)
        self.assertEqual(info["target_contaminated_quantity"], 12)

    def test_phase3_full_plan_scores_high(self) -> None:
        """A complete phase-3 plan is expected to score >= 0.95 and quarantine all affected stock."""
        # (node, lot, quantity) triples; every node is inspected, then the
        # listed quantity of the listed lot is quarantined there.
        quarantine_targets = (
            ("warehouse", "LotA", 30),
            ("crossdock", "LotBlend", 12),
            ("store1", "LotA", 10),
            ("store2", "LotBlend", 8),
            ("store3", "LotBlend", 4),
        )

        inspections = [
            {"type": "inspect_node", "node_id": node} for node, _, _ in quarantine_targets
        ]
        quarantines = [
            {"type": "quarantine", "node_id": node, "lot_id": lot, "quantity": units}
            for node, lot, units in quarantine_targets
        ]
        plan = [
            {"type": "trace_lot", "lot_id": "LotA"},
            *inspections,
            *quarantines,
            {"type": "notify", "node_id": "all"},
            {"type": "finalize"},
        ]

        grade = evaluate_action_plan("phase3_mixed_shipments", plan)

        self.assertGreaterEqual(grade.score, 0.95)
        self.assertTrue(grade.final_info["all_affected_stock_quarantined"])
# Run this module's tests through unittest's command-line runner when the file
# is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()