File size: 10,120 Bytes
47bba68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""Tests for models (Grid, entities, state)."""

import pytest
from app.models.grid import Grid, Segment
from app.models.entities import Store, Destination, Tunnel
from app.models.state import SearchState, PathResult, SearchStep


class TestSegment:
    """Unit tests for the ``Segment`` model (one edge between two grid cells)."""

    def test_segment_creation(self):
        """A segment exposes the endpoints and traffic level it was built with."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=2)
        assert (seg.src, seg.dst) == ((0, 0), (1, 0))
        assert seg.traffic == 2

    def test_segment_normalization(self):
        """Endpoints passed in reverse order are swapped so that src < dst."""
        seg = Segment(src=(1, 0), dst=(0, 0), traffic=1)
        assert (seg.src, seg.dst) == ((0, 0), (1, 0))

    def test_segment_is_blocked(self):
        """Traffic 0 makes a segment blocked; traffic 1 does not."""
        closed_road = Segment(src=(0, 0), dst=(1, 0), traffic=0)
        open_road = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        assert closed_road.is_blocked is True
        assert open_road.is_blocked is False

    def test_segment_get_key(self):
        """get_key() returns the (src, dst) endpoint pair as a tuple."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        expected_key = ((0, 0), (1, 0))
        assert seg.get_key() == expected_key


class TestGrid:
    """Tests for the Grid model: segment storage, lookup, validation and neighbors."""

    def test_grid_creation(self):
        """A new grid records its dimensions and starts with no segments."""
        grid = Grid(width=5, height=5)
        assert grid.width == 5
        assert grid.height == 5
        assert len(grid.segments) == 0

    def test_add_segment(self):
        """Adding a segment stores it and makes it retrievable with its traffic."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 2)
        assert len(grid.segments) == 1
        segment = grid.get_segment((0, 0), (1, 0))
        assert segment is not None
        assert segment.traffic == 2

    def test_add_segment_reversed(self):
        """A segment added with reversed endpoints is stored once and found both ways."""
        grid = Grid(width=3, height=3)
        grid.add_segment((1, 0), (0, 0), 2)
        # Argument order must not create a second, differently-keyed segment.
        assert len(grid.segments) == 1
        # Should be accessible in both directions, not only the normalized one.
        forward = grid.get_segment((0, 0), (1, 0))
        backward = grid.get_segment((1, 0), (0, 0))
        assert forward is not None
        assert backward is not None
        assert forward.traffic == 2
        assert backward.traffic == 2

    def test_get_traffic(self):
        """Traffic lookup is symmetric and returns 0 for a missing segment."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 3)
        assert grid.get_traffic((0, 0), (1, 0)) == 3
        assert grid.get_traffic((1, 0), (0, 0)) == 3  # Reversed
        assert grid.get_traffic((0, 0), (0, 1)) == 0  # Non-existent

    def test_is_blocked(self):
        """Zero-traffic and missing segments are blocked; positive traffic is not."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 0)
        grid.add_segment((0, 0), (0, 1), 1)
        assert grid.is_blocked((0, 0), (1, 0)) is True
        assert grid.is_blocked((0, 0), (0, 1)) is False
        assert grid.is_blocked((1, 1), (2, 1)) is True  # Non-existent = blocked

    def test_is_valid_position(self):
        """Positions are valid only within [0, width) x [0, height)."""
        grid = Grid(width=3, height=3)
        assert grid.is_valid_position((0, 0)) is True
        assert grid.is_valid_position((2, 2)) is True
        assert grid.is_valid_position((3, 0)) is False
        assert grid.is_valid_position((0, 3)) is False
        assert grid.is_valid_position((-1, 0)) is False

    def test_get_neighbors(self, simple_grid):
        """Neighbors of a corner and of the centre cell in the fixture grid."""
        # Corner (0,0) has 2 neighbors
        neighbors = simple_grid.get_neighbors((0, 0))
        assert len(neighbors) == 2
        assert (1, 0) in neighbors
        assert (0, 1) in neighbors

        # Center (1,1) has 4 neighbors: the four orthogonally adjacent cells.
        neighbors = simple_grid.get_neighbors((1, 1))
        assert len(neighbors) == 4
        assert set(neighbors) == {(0, 1), (2, 1), (1, 0), (1, 2)}

    def test_get_neighbors_with_blocked(self, grid_with_blocked):
        """Neighbors exclude cells reachable only through a blocked segment."""
        # (1,1) normally has 4 neighbors but one path is blocked
        neighbors = grid_with_blocked.get_neighbors((1, 1))
        assert (2, 1) not in neighbors  # Blocked
        assert (0, 1) in neighbors
        assert (1, 0) in neighbors
        assert (1, 2) in neighbors
        # Pin the count so no extra/duplicate neighbor slips through.
        assert len(neighbors) == 3

    def test_to_dict(self, simple_grid):
        """Serialization includes dimensions and a non-empty segment list."""
        result = simple_grid.to_dict()
        assert result["width"] == 3
        assert result["height"] == 3
        assert "segments" in result
        assert len(result["segments"]) > 0


class TestStore:
    """Unit tests for the ``Store`` entity."""

    def test_store_creation(self):
        """A store exposes the id and position it was constructed with."""
        shop = Store(id=1, position=(5, 3))
        assert shop.id == 1
        assert shop.position == (5, 3)

    def test_store_to_dict(self):
        """to_dict() emits the id and nests the position under "x"/"y" keys."""
        payload = Store(id=2, position=(1, 2)).to_dict()
        assert payload["id"] == 2
        assert payload["position"]["x"] == 1
        assert payload["position"]["y"] == 2


class TestDestination:
    """Unit tests for the ``Destination`` entity."""

    def test_destination_creation(self):
        """A destination exposes the id and position it was constructed with."""
        target = Destination(id=1, position=(3, 4))
        assert target.id == 1
        assert target.position == (3, 4)

    def test_destination_to_dict(self):
        """to_dict() emits the id and nests the position under "x"/"y" keys."""
        payload = Destination(id=3, position=(2, 5)).to_dict()
        assert payload["id"] == 3
        assert payload["position"]["x"] == 2
        assert payload["position"]["y"] == 5


class TestTunnel:
    """Unit tests for the ``Tunnel`` entity linking two grid positions."""

    def test_tunnel_creation(self):
        """A tunnel exposes both entrances as given."""
        link = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        assert link.entrance1 == (0, 0)
        assert link.entrance2 == (5, 5)

    def test_tunnel_cost(self):
        """Cost equals the Manhattan distance between the entrances."""
        link = Tunnel(entrance1=(0, 0), entrance2=(3, 4))
        # |3 - 0| + |4 - 0| = 7
        assert link.cost == 7

    def test_tunnel_cost_same_row(self):
        """Entrances on one row: cost is the horizontal gap."""
        link = Tunnel(entrance1=(0, 5), entrance2=(10, 5))
        assert link.cost == 10

    def test_tunnel_cost_same_column(self):
        """Entrances in one column: cost is the vertical gap."""
        link = Tunnel(entrance1=(3, 0), entrance2=(3, 7))
        assert link.cost == 7

    def test_get_other_entrance(self):
        """Each entrance maps to the opposite one."""
        link = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        for here, there in (((0, 0), (5, 5)), ((5, 5), (0, 0))):
            assert link.get_other_entrance(here) == there

    def test_get_other_entrance_invalid(self):
        """A position that is not an entrance raises ValueError."""
        link = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        with pytest.raises(ValueError):
            link.get_other_entrance((1, 1))

    def test_has_entrance_at(self):
        """has_entrance_at() is True only at the two entrances."""
        link = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        expectations = (((0, 0), True), ((5, 5), True), ((1, 1), False))
        for spot, expected in expectations:
            assert link.has_entrance_at(spot) is expected

    def test_tunnel_to_dict(self):
        """to_dict() nests each entrance under "x"/"y" keys and includes the cost."""
        data = Tunnel(entrance1=(1, 2), entrance2=(4, 6)).to_dict()
        first, second = data["entrance1"], data["entrance2"]
        assert first["x"] == 1
        assert first["y"] == 2
        assert second["x"] == 4
        assert second["y"] == 6
        assert data["cost"] == 7


class TestSearchState:
    """Unit tests for the ``SearchState`` aggregate."""

    def test_search_state_creation(self, simple_grid, sample_stores, sample_destinations, sample_tunnels):
        """SearchState keeps the supplied grid and the fixture entity collections."""
        state = SearchState(
            grid=simple_grid,
            stores=sample_stores,
            destinations=sample_destinations,
            tunnels=sample_tunnels,
        )
        assert state.grid == simple_grid
        # Expected sizes come from the conftest fixtures.
        assert len(state.stores) == 2
        assert len(state.destinations) == 3
        assert len(state.tunnels) == 2


class TestPathResult:
    """Tests for the PathResult model (outcome of a single search)."""

    def test_path_result_creation(self):
        """A path result keeps its plan, cost, expansion count and path."""
        result = PathResult(
            plan="right,right,up",
            cost=5.0,
            nodes_expanded=10,
            path=[(0, 0), (1, 0), (2, 0), (2, 1)],
        )
        assert result.plan == "right,right,up"
        assert result.cost == 5.0
        assert result.nodes_expanded == 10
        assert len(result.path) == 4
        # Endpoints must survive construction, not just the length.
        assert result.path[0] == (0, 0)
        assert result.path[-1] == (2, 1)

    def test_path_result_no_solution(self):
        """A failed search has an empty plan/path, infinite cost, and still reports expansions."""
        result = PathResult(
            plan="",
            cost=float("inf"),
            nodes_expanded=50,
            path=[],
        )
        assert result.plan == ""
        assert result.cost == float("inf")
        # The work done before giving up must still be reported.
        assert result.nodes_expanded == 50
        assert len(result.path) == 0


class TestSearchStep:
    """Tests for the SearchStep model (one snapshot of a search's progress)."""

    def test_search_step_creation(self):
        """A search step keeps every field it was constructed with."""
        step = SearchStep(
            step_number=5,
            current_node=(2, 3),
            action="right",
            frontier=[(3, 3), (2, 4)],
            explored=[(0, 0), (1, 0), (2, 0)],
            current_path=[(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (2, 3)],
            path_cost=6.0,
        )
        assert step.step_number == 5
        assert step.current_node == (2, 3)
        assert step.action == "right"
        assert len(step.frontier) == 2
        assert len(step.explored) == 3
        # These two fields were passed in but previously never checked.
        assert len(step.current_path) == 6
        assert step.path_cost == 6.0

    def test_search_step_to_dict(self):
        """Serialization uses camelCase keys and nests the node under "x"/"y"."""
        step = SearchStep(
            step_number=0,
            current_node=(0, 0),
            action=None,
            frontier=[(1, 0)],
            explored=[],
            current_path=[(0, 0)],
            path_cost=0.0,
        )
        result = step.to_dict()
        assert result["stepNumber"] == 0
        assert result["currentNode"]["x"] == 0
        assert result["currentNode"]["y"] == 0
        assert result["action"] is None