File size: 3,420 Bytes
6172160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b4eb04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6172160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Tests for the dispatch state machine."""

from __future__ import annotations

from src.city_schema import CitySchemaLoader
from src.models import Action, DispatchAction
from src.state_machine import DispatchStateMachine


def test_reset_sets_ids_and_has_entities() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="single_incident", episode_id="ep-1")
    assert state.task_id == "single_incident"
    assert state.episode_id == "ep-1"
    assert state.step_count == 0
    assert state.units
    assert state.incidents


def test_legal_actions_non_empty_initially() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="single_incident", episode_id="ep-1")
    legal = sm.get_legal_actions(state)
    assert legal
    assert any(a.action_type == DispatchAction.DISPATCH for a in legal)


def test_additional_actions_become_reachable() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)

    # Multi-incident is a better reachability surface (multiple incidents + P2 incident).
    state = sm.reset(task_id="multi_incident", episode_id="ep-1")
    legal = sm.get_legal_actions(state)

    assert any(a.action_type == DispatchAction.STAGE for a in legal)
    assert any(a.action_type == DispatchAction.UPGRADE for a in legal)
    assert any(a.action_type == DispatchAction.DOWNGRADE for a in legal)

    # After a dispatch, REASSIGN should be legal to the other active incident.
    dispatch = next(a for a in legal if a.action_type == DispatchAction.DISPATCH)
    state, _ = sm.step(state, dispatch)
    legal2 = sm.get_legal_actions(state)
    assert any(a.action_type == DispatchAction.CANCEL for a in legal2)
    assert any(a.action_type == DispatchAction.REASSIGN for a in legal2)


def test_mutual_aid_appears_when_type_exhausted() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="multi_incident", episode_id="ep-1")

    # Exhaust all MEDIC availability.
    from src.models import UnitStatus

    for unit in state.units.values():
        if unit.unit_type.value == "MEDIC":
            unit.status = UnitStatus.DISPATCHED
            unit.assigned_incident_id = "INC-001"

    legal = sm.get_legal_actions(state)
    assert any(a.action_type == DispatchAction.MUTUAL_AID for a in legal)


def test_invalid_action_yields_protocol_ok_false() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="single_incident", episode_id="ep-1")
    bad = Action(action_type=DispatchAction.DISPATCH, unit_id="NOPE", incident_id="INC-001")
    state2, obs = sm.step(state, bad)
    assert obs.protocol_ok is False
    assert obs.result == "invalid action"
    assert state2.step_count == 1


def test_dispatch_progresses_incident() -> None:
    schema = CitySchemaLoader.load("metro_city")
    sm = DispatchStateMachine(schema=schema, seed=42)
    state = sm.reset(task_id="single_incident", episode_id="ep-1")
    legal = sm.get_legal_actions(state)
    state, obs = sm.step(state, legal[0])
    assert obs.protocol_ok is True
    assert any(u.status.value != "AVAILABLE" for u in state.units.values())