Spaces:
Sleeping
Sleeping
| """Tests for models (Grid, entities, state).""" | |
| import pytest | |
| from app.models.grid import Grid, Segment | |
| from app.models.entities import Store, Destination, Tunnel | |
| from app.models.state import SearchState, PathResult, SearchStep | |
class TestSegment:
    """Unit tests for the ``Segment`` model."""

    def test_segment_creation(self):
        """A new segment exposes the endpoints and traffic it was built with."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=2)
        assert (seg.src, seg.dst, seg.traffic) == ((0, 0), (1, 0), 2)

    def test_segment_normalization(self):
        """Endpoints given in descending order are swapped so that src < dst."""
        seg = Segment(src=(1, 0), dst=(0, 0), traffic=1)
        assert (seg.src, seg.dst) == ((0, 0), (1, 0))

    def test_segment_is_blocked(self):
        """Traffic 0 reports the segment as blocked; traffic 1 does not."""
        closed = Segment(src=(0, 0), dst=(1, 0), traffic=0)
        open_ = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        assert closed.is_blocked is True
        assert open_.is_blocked is False

    def test_segment_get_key(self):
        """get_key() returns the (src, dst) endpoint pair."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        expected_key = ((0, 0), (1, 0))
        assert seg.get_key() == expected_key
class TestGrid:
    """Unit tests for the ``Grid`` model."""

    def test_grid_creation(self):
        """A new grid records its dimensions and starts with no segments."""
        grid = Grid(width=5, height=5)
        assert grid.width == 5
        assert grid.height == 5
        assert len(grid.segments) == 0

    def test_add_segment(self):
        """add_segment stores one segment that get_segment can retrieve."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 2)
        assert len(grid.segments) == 1
        segment = grid.get_segment((0, 0), (1, 0))
        assert segment is not None
        assert segment.traffic == 2

    def test_add_segment_reversed(self):
        """A segment added with reversed endpoints is retrievable in either order."""
        grid = Grid(width=3, height=3)
        grid.add_segment((1, 0), (0, 0), 2)
        # Adding with reversed endpoints must not create a second entry.
        assert len(grid.segments) == 1
        # Lookup must not depend on argument order, so probe both directions.
        for a, b in (((0, 0), (1, 0)), ((1, 0), (0, 0))):
            segment = grid.get_segment(a, b)
            assert segment is not None, (a, b)
            assert segment.traffic == 2

    def test_get_traffic(self):
        """get_traffic is order-independent and returns 0 for a missing segment."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 3)
        assert grid.get_traffic((0, 0), (1, 0)) == 3
        assert grid.get_traffic((1, 0), (0, 0)) == 3  # Reversed
        assert grid.get_traffic((0, 0), (0, 1)) == 0  # Non-existent

    def test_is_blocked(self):
        """Zero-traffic and missing segments are blocked; traffic 1 is not."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 0)
        grid.add_segment((0, 0), (0, 1), 1)
        assert grid.is_blocked((0, 0), (1, 0)) is True
        assert grid.is_blocked((0, 0), (0, 1)) is False
        assert grid.is_blocked((1, 1), (2, 1)) is True  # Non-existent = blocked

    def test_is_valid_position(self):
        """Positions are valid only within 0 <= x < width and 0 <= y < height."""
        grid = Grid(width=3, height=3)
        assert grid.is_valid_position((0, 0)) is True
        assert grid.is_valid_position((2, 2)) is True
        assert grid.is_valid_position((3, 0)) is False
        assert grid.is_valid_position((0, 3)) is False
        assert grid.is_valid_position((-1, 0)) is False
        # Negative y must be rejected just like negative x.
        assert grid.is_valid_position((0, -1)) is False

    def test_get_neighbors(self, simple_grid):
        """Corner cells have 2 neighbours and interior cells have 4."""
        # Corner (0,0) has 2 neighbors
        neighbors = simple_grid.get_neighbors((0, 0))
        assert len(neighbors) == 2
        assert (1, 0) in neighbors
        assert (0, 1) in neighbors
        # Center (1,1) has 4 neighbors
        neighbors = simple_grid.get_neighbors((1, 1))
        assert len(neighbors) == 4

    def test_get_neighbors_with_blocked(self, grid_with_blocked):
        """A blocked segment removes exactly that neighbour from the result."""
        # (1,1) normally has 4 neighbors but one path is blocked
        neighbors = grid_with_blocked.get_neighbors((1, 1))
        assert (2, 1) not in neighbors  # Blocked
        assert (0, 1) in neighbors
        assert (1, 0) in neighbors
        assert (1, 2) in neighbors
        # Only the blocked neighbour should be missing.
        assert len(neighbors) == 3

    def test_to_dict(self, simple_grid):
        """to_dict() includes the dimensions and a non-empty segment list."""
        result = simple_grid.to_dict()
        assert result["width"] == 3
        assert result["height"] == 3
        assert "segments" in result
        assert len(result["segments"]) > 0
class TestStore:
    """Unit tests for the ``Store`` entity."""

    def test_store_creation(self):
        """A store keeps the id and position passed to its constructor."""
        shop = Store(id=1, position=(5, 3))
        assert (shop.id, shop.position) == (1, (5, 3))

    def test_store_to_dict(self):
        """to_dict() emits the id and splits the position into x/y keys."""
        payload = Store(id=2, position=(1, 2)).to_dict()
        pos = payload["position"]
        assert payload["id"] == 2
        assert (pos["x"], pos["y"]) == (1, 2)
class TestDestination:
    """Unit tests for the ``Destination`` entity."""

    def test_destination_creation(self):
        """A destination keeps the id and position passed to its constructor."""
        target = Destination(id=1, position=(3, 4))
        assert (target.id, target.position) == (1, (3, 4))

    def test_destination_to_dict(self):
        """to_dict() emits the id and splits the position into x/y keys."""
        payload = Destination(id=3, position=(2, 5)).to_dict()
        pos = payload["position"]
        assert payload["id"] == 3
        assert (pos["x"], pos["y"]) == (2, 5)
class TestTunnel:
    """Unit tests for the ``Tunnel`` entity."""

    def test_tunnel_creation(self):
        """Both entrances are stored exactly as given."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        assert (t.entrance1, t.entrance2) == ((0, 0), (5, 5))

    def test_tunnel_cost(self):
        """Cost is the Manhattan distance between the two entrances."""
        t = Tunnel(entrance1=(0, 0), entrance2=(3, 4))
        assert t.cost == 7  # |3 - 0| + |4 - 0|

    def test_tunnel_cost_same_row(self):
        """Entrances on one row cost only their horizontal distance."""
        t = Tunnel(entrance1=(0, 5), entrance2=(10, 5))
        assert t.cost == 10  # |10 - 0| + 0

    def test_tunnel_cost_same_column(self):
        """Entrances in one column cost only their vertical distance."""
        t = Tunnel(entrance1=(3, 0), entrance2=(3, 7))
        assert t.cost == 7  # 0 + |7 - 0|

    def test_get_other_entrance(self):
        """Each entrance maps to the opposite one."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        assert t.get_other_entrance((0, 0)) == (5, 5)
        assert t.get_other_entrance((5, 5)) == (0, 0)

    def test_get_other_entrance_invalid(self):
        """A position that is not an entrance raises ValueError."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        with pytest.raises(ValueError):
            t.get_other_entrance((1, 1))

    def test_has_entrance_at(self):
        """has_entrance_at is True for either entrance and False elsewhere."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        cases = [((0, 0), True), ((5, 5), True), ((1, 1), False)]
        for position, expected in cases:
            assert t.has_entrance_at(position) is expected

    def test_tunnel_to_dict(self):
        """to_dict() emits both entrances as x/y maps plus the cost."""
        data = Tunnel(entrance1=(1, 2), entrance2=(4, 6)).to_dict()
        first, second = data["entrance1"], data["entrance2"]
        assert (first["x"], first["y"]) == (1, 2)
        assert (second["x"], second["y"]) == (4, 6)
        assert data["cost"] == 7  # |4 - 1| + |6 - 2|
class TestSearchState:
    """Unit tests for the ``SearchState`` model."""

    def test_search_state_creation(
        self, simple_grid, sample_stores, sample_destinations, sample_tunnels
    ):
        """SearchState holds the grid and entity collections it is built from."""
        state = SearchState(
            grid=simple_grid,
            stores=sample_stores,
            destinations=sample_destinations,
            tunnels=sample_tunnels,
        )
        assert state.grid == simple_grid
        # NOTE(review): the expected sizes mirror the fixtures, presumably
        # defined in conftest.py; update them together.
        sizes = (len(state.stores), len(state.destinations), len(state.tunnels))
        assert sizes == (2, 3, 2)
class TestPathResult:
    """Unit tests for the ``PathResult`` model."""

    def test_path_result_creation(self):
        """A successful result exposes its plan, cost, expansion count and path."""
        route = [(0, 0), (1, 0), (2, 0), (2, 1)]
        outcome = PathResult(
            plan="right,right,up",
            cost=5.0,
            nodes_expanded=10,
            path=route,
        )
        assert outcome.plan == "right,right,up"
        assert outcome.cost == 5.0
        assert outcome.nodes_expanded == 10
        assert len(outcome.path) == 4

    def test_path_result_no_solution(self):
        """The no-solution result uses an empty plan, infinite cost and empty path."""
        unreachable = float("inf")
        outcome = PathResult(plan="", cost=unreachable, nodes_expanded=50, path=[])
        assert outcome.plan == ""
        assert outcome.cost == unreachable
        assert len(outcome.path) == 0
class TestSearchStep:
    """Unit tests for the ``SearchStep`` model."""

    def test_search_step_creation(self):
        """Constructor arguments are exposed as attributes on the step."""
        step = SearchStep(
            step_number=5,
            current_node=(2, 3),
            action="right",
            frontier=[(3, 3), (2, 4)],
            explored=[(0, 0), (1, 0), (2, 0)],
            current_path=[(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (2, 3)],
            path_cost=6.0,
        )
        assert step.step_number == 5
        assert step.current_node == (2, 3)
        assert step.action == "right"
        assert len(step.frontier) == 2
        assert len(step.explored) == 3
        # Check the remaining constructor fields as well, so every argument
        # passed above is checked.
        assert len(step.current_path) == 6
        assert step.path_cost == 6.0

    def test_search_step_to_dict(self):
        """to_dict() uses camelCase keys and splits the node into x/y."""
        step = SearchStep(
            step_number=0,
            current_node=(0, 0),
            action=None,
            frontier=[(1, 0)],
            explored=[],
            current_path=[(0, 0)],
            path_cost=0.0,
        )
        result = step.to_dict()
        assert result["stepNumber"] == 0
        assert result["currentNode"]["x"] == 0
        assert result["currentNode"]["y"] == 0
        assert result["action"] is None