File size: 3,391 Bytes
d19137b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Unit tests for RecallTrace."""

from __future__ import annotations

import unittest

from env.env import RecallTraceEnv
from grader.grader import evaluate_action_plan


class RecallTraceEnvTests(unittest.TestCase):
    """Scenario tests for the RecallTrace environment and the action-plan grader."""

    def test_phase1_plan_scores_high(self) -> None:
        # A full trace / inspect / quarantine / notify / finalize plan for the
        # phase-1 task must grade at 0.95 or above and be flagged successful.
        inspections = [
            {"type": "inspect_node", "node_id": node}
            for node in ("warehouse", "store1", "store2")
        ]
        quarantines = [
            {"type": "quarantine", "node_id": node, "lot_id": "LotA", "quantity": units}
            for node, units in (("warehouse", 100), ("store1", 50), ("store2", 20))
        ]
        plan = [
            {"type": "trace_lot", "lot_id": "LotA"},
            *inspections,
            *quarantines,
            {"type": "notify", "node_id": "all"},
            {"type": "finalize"},
        ]

        result = evaluate_action_plan("phase1_direct_recall", plan)

        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.success)

    def test_phase2_trace_reveals_relabels(self) -> None:
        # Tracing LotA in the phase-2 task must not end the episode, must earn
        # a positive reward, and must report the relabelled lots and store3.
        recall_env = RecallTraceEnv(task_id="phase2_relabel_recall")
        recall_env.reset()

        trace_action = {"type": "trace_lot", "lot_id": "LotA"}
        obs, step_reward, finished, details = recall_env.step(trace_action)

        self.assertFalse(finished)
        self.assertGreater(step_reward, 0)
        expected_lots = ["LotA", "LotA_R1", "LotA_R2"]
        self.assertEqual(details["matched_lots"], expected_lots)
        affected_nodes = obs.trace_results["LotA"]["affected_nodes"]
        self.assertIn("store3", affected_nodes)

    def test_phase3_mixed_inventory_requires_exact_quarantine(self) -> None:
        # Quarantining 15 units of LotBlend at the crossdock must be penalised,
        # and the step info must report 12 as the contaminated quantity.
        recall_env = RecallTraceEnv(task_id="phase3_mixed_shipments")
        recall_env.reset()

        setup_actions = (
            {"type": "trace_lot", "lot_id": "LotA"},
            {"type": "inspect_node", "node_id": "crossdock"},
        )
        for action in setup_actions:
            recall_env.step(action)

        oversized_quarantine = {
            "type": "quarantine",
            "node_id": "crossdock",
            "lot_id": "LotBlend",
            "quantity": 15,
        }
        _obs, step_reward, _finished, details = recall_env.step(oversized_quarantine)

        self.assertLess(step_reward, 0)
        self.assertEqual(details["target_contaminated_quantity"], 12)

    def test_phase3_full_plan_scores_high(self) -> None:
        # A complete phase-3 plan with per-node quarantine quantities must grade
        # at 0.95 or above and report all affected stock as quarantined.
        inspections = [
            {"type": "inspect_node", "node_id": node}
            for node in ("warehouse", "crossdock", "store1", "store2", "store3")
        ]
        quarantine_targets = (
            ("warehouse", "LotA", 30),
            ("crossdock", "LotBlend", 12),
            ("store1", "LotA", 10),
            ("store2", "LotBlend", 8),
            ("store3", "LotBlend", 4),
        )
        quarantines = [
            {"type": "quarantine", "node_id": node, "lot_id": lot, "quantity": units}
            for node, lot, units in quarantine_targets
        ]
        plan = [
            {"type": "trace_lot", "lot_id": "LotA"},
            *inspections,
            *quarantines,
            {"type": "notify", "node_id": "all"},
            {"type": "finalize"},
        ]

        result = evaluate_action_plan("phase3_mixed_shipments", plan)

        self.assertGreaterEqual(result.score, 0.95)
        self.assertTrue(result.final_info["all_affected_stock_quarantined"])


# When this file is executed directly (rather than imported), hand control to
# unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()