# Source: SearchAlgorithms/backend/tests/test_models.py
# Last commit: 47bba68 by Kacemath — "feat: update with latest changes"
"""Tests for models (Grid, entities, state)."""
import pytest
from app.models.grid import Grid, Segment
from app.models.entities import Store, Destination, Tunnel
from app.models.state import SearchState, PathResult, SearchStep
class TestSegment:
    """Unit tests for construction and helper methods of ``Segment``."""

    def test_segment_creation(self):
        """A segment exposes the endpoints and traffic it was built with."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=2)
        assert (seg.src, seg.dst) == ((0, 0), (1, 0))
        assert seg.traffic == 2

    def test_segment_normalization(self):
        """Endpoints passed in descending order are swapped so that src < dst."""
        seg = Segment(src=(1, 0), dst=(0, 0), traffic=1)
        assert seg.src == (0, 0)
        assert seg.dst == (1, 0)

    def test_segment_is_blocked(self):
        """Zero traffic marks a segment as blocked; positive traffic does not."""
        closed = Segment(src=(0, 0), dst=(1, 0), traffic=0)
        open_ = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        # Identity checks on purpose: the property must return real bools.
        assert closed.is_blocked is True
        assert open_.is_blocked is False

    def test_segment_get_key(self):
        """The lookup key is the (src, dst) endpoint pair."""
        seg = Segment(src=(0, 0), dst=(1, 0), traffic=1)
        expected_key = ((0, 0), (1, 0))
        assert seg.get_key() == expected_key
class TestGrid:
    """Tests for the Grid model: segment storage, traffic lookup, bounds and neighbors."""

    def test_grid_creation(self):
        """Test basic grid creation: dimensions are stored and no segments exist yet."""
        grid = Grid(width=5, height=5)
        assert grid.width == 5
        assert grid.height == 5
        assert len(grid.segments) == 0

    def test_add_segment(self):
        """Test that an added segment is stored and retrievable with its traffic."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 2)
        assert len(grid.segments) == 1
        segment = grid.get_segment((0, 0), (1, 0))
        assert segment is not None
        assert segment.traffic == 2

    def test_add_segment_reversed(self):
        """Test adding a segment with reversed coordinates.

        The segment is stored once and is reachable regardless of the
        endpoint order used for the lookup.
        """
        grid = Grid(width=3, height=3)
        grid.add_segment((1, 0), (0, 0), 2)
        # Reversed input must not create a second, differently-keyed segment.
        assert len(grid.segments) == 1
        # Lookup with ascending endpoints (the normalized order) finds it ...
        segment = grid.get_segment((0, 0), (1, 0))
        assert segment is not None
        assert segment.traffic == 2
        # ... and so does a lookup in the order the segment was added.
        assert grid.get_traffic((1, 0), (0, 0)) == 2

    def test_get_traffic(self):
        """Test traffic lookup in both directions and for a missing segment."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 3)
        assert grid.get_traffic((0, 0), (1, 0)) == 3
        assert grid.get_traffic((1, 0), (0, 0)) == 3  # Reversed
        assert grid.get_traffic((0, 0), (0, 1)) == 0  # Non-existent

    def test_is_blocked(self):
        """Test blocked detection: zero traffic and missing segments are blocked."""
        grid = Grid(width=3, height=3)
        grid.add_segment((0, 0), (1, 0), 0)
        grid.add_segment((0, 0), (0, 1), 1)
        assert grid.is_blocked((0, 0), (1, 0)) is True
        assert grid.is_blocked((0, 0), (0, 1)) is False
        assert grid.is_blocked((1, 1), (2, 1)) is True  # Non-existent = blocked

    def test_is_valid_position(self):
        """Test that positions are valid only inside [0, width) x [0, height)."""
        grid = Grid(width=3, height=3)
        assert grid.is_valid_position((0, 0)) is True
        assert grid.is_valid_position((2, 2)) is True
        assert grid.is_valid_position((3, 0)) is False
        assert grid.is_valid_position((0, 3)) is False
        assert grid.is_valid_position((-1, 0)) is False

    def test_get_neighbors(self, simple_grid):
        """Test neighbor counts for a corner and a center cell."""
        # Corner (0,0) has 2 neighbors
        neighbors = simple_grid.get_neighbors((0, 0))
        assert len(neighbors) == 2
        assert (1, 0) in neighbors
        assert (0, 1) in neighbors
        # Center (1,1) has 4 neighbors
        neighbors = simple_grid.get_neighbors((1, 1))
        assert len(neighbors) == 4

    def test_get_neighbors_with_blocked(self, grid_with_blocked):
        """Test that neighbors exclude blocked paths."""
        # (1,1) normally has 4 neighbors but one path is blocked, leaving 3.
        neighbors = grid_with_blocked.get_neighbors((1, 1))
        assert (2, 1) not in neighbors  # Blocked
        assert (0, 1) in neighbors
        assert (1, 0) in neighbors
        assert (1, 2) in neighbors
        # Pin the count so no unexpected extra neighbor slips through.
        assert len(neighbors) == 3

    def test_to_dict(self, simple_grid):
        """Test grid serialization includes dimensions and a non-empty segment list."""
        result = simple_grid.to_dict()
        assert result["width"] == 3
        assert result["height"] == 3
        assert "segments" in result
        assert len(result["segments"]) > 0
class TestStore:
    """Unit tests for the ``Store`` entity."""

    def test_store_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        shop = Store(id=1, position=(5, 3))
        assert shop.id == 1
        assert shop.position == (5, 3)

    def test_store_to_dict(self):
        """to_dict() emits the id and splits the position into x/y keys."""
        payload = Store(id=2, position=(1, 2)).to_dict()
        assert payload["id"] == 2
        pos = payload["position"]
        assert (pos["x"], pos["y"]) == (1, 2)
class TestDestination:
    """Unit tests for the ``Destination`` entity."""

    def test_destination_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        target = Destination(id=1, position=(3, 4))
        assert target.id == 1
        assert target.position == (3, 4)

    def test_destination_to_dict(self):
        """to_dict() emits the id and splits the position into x/y keys."""
        payload = Destination(id=3, position=(2, 5)).to_dict()
        assert payload["id"] == 3
        pos = payload["position"]
        assert (pos["x"], pos["y"]) == (2, 5)
class TestTunnel:
    """Unit tests for the ``Tunnel`` entity: cost, entrance lookup and serialization."""

    def test_tunnel_creation(self):
        """Both entrances are stored exactly as given."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        assert (t.entrance1, t.entrance2) == ((0, 0), (5, 5))

    def test_tunnel_cost(self):
        """Cost is the Manhattan distance between the two entrances."""
        t = Tunnel(entrance1=(0, 0), entrance2=(3, 4))
        # |3 - 0| + |4 - 0| == 7
        assert t.cost == 7

    def test_tunnel_cost_same_row(self):
        """Entrances on the same row cost their horizontal distance."""
        assert Tunnel(entrance1=(0, 5), entrance2=(10, 5)).cost == 10

    def test_tunnel_cost_same_column(self):
        """Entrances in the same column cost their vertical distance."""
        assert Tunnel(entrance1=(3, 0), entrance2=(3, 7)).cost == 7

    def test_get_other_entrance(self):
        """Asking from one entrance yields the opposite one, in both directions."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        for here, there in (((0, 0), (5, 5)), ((5, 5), (0, 0))):
            assert t.get_other_entrance(here) == there

    def test_get_other_entrance_invalid(self):
        """A position that is not an entrance raises ValueError."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        with pytest.raises(ValueError):
            t.get_other_entrance((1, 1))

    def test_has_entrance_at(self):
        """Entrance membership is True only for the two entrance positions."""
        t = Tunnel(entrance1=(0, 0), entrance2=(5, 5))
        assert t.has_entrance_at((0, 0)) is True
        assert t.has_entrance_at((5, 5)) is True
        assert t.has_entrance_at((1, 1)) is False

    def test_tunnel_to_dict(self):
        """to_dict() serializes both entrances as x/y and includes the cost."""
        data = Tunnel(entrance1=(1, 2), entrance2=(4, 6)).to_dict()
        first, second = data["entrance1"], data["entrance2"]
        assert (first["x"], first["y"]) == (1, 2)
        assert (second["x"], second["y"]) == (4, 6)
        # |4 - 1| + |6 - 2| == 7
        assert data["cost"] == 7
class TestSearchState:
    """Unit tests for the ``SearchState`` container."""

    def test_search_state_creation(
        self,
        simple_grid,
        sample_stores,
        sample_destinations,
        sample_tunnels,
    ):
        """The state exposes the grid and entity collections it was built from."""
        state = SearchState(
            grid=simple_grid,
            stores=sample_stores,
            destinations=sample_destinations,
            tunnels=sample_tunnels,
        )
        assert state.grid == simple_grid
        # Expected sizes come from the conftest fixtures.
        expected_sizes = (("stores", 2), ("destinations", 3), ("tunnels", 2))
        for attr_name, size in expected_sizes:
            assert len(getattr(state, attr_name)) == size
class TestPathResult:
    """Unit tests for the ``PathResult`` record."""

    def test_path_result_creation(self):
        """Fields of a successful result are stored as given."""
        route = [(0, 0), (1, 0), (2, 0), (2, 1)]
        res = PathResult(
            plan="right,right,up",
            cost=5.0,
            nodes_expanded=10,
            path=route,
        )
        assert res.plan == "right,right,up"
        assert res.cost == 5.0
        assert res.nodes_expanded == 10
        assert len(res.path) == 4

    def test_path_result_no_solution(self):
        """A no-solution result keeps an empty plan, infinite cost and empty path."""
        res = PathResult(plan="", cost=float("inf"), nodes_expanded=50, path=[])
        assert res.plan == ""
        assert res.cost == float("inf")
        assert len(res.path) == 0
class TestSearchStep:
    """Unit tests for the ``SearchStep`` snapshot."""

    def test_search_step_creation(self):
        """A step exposes its number, current node, action, frontier and explored set."""
        walked = [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (2, 3)]
        step = SearchStep(
            step_number=5,
            current_node=(2, 3),
            action="right",
            frontier=[(3, 3), (2, 4)],
            explored=[(0, 0), (1, 0), (2, 0)],
            current_path=walked,
            path_cost=6.0,
        )
        assert (step.step_number, step.current_node, step.action) == (5, (2, 3), "right")
        assert len(step.frontier) == 2
        assert len(step.explored) == 3

    def test_search_step_to_dict(self):
        """to_dict() uses camelCase keys and serializes the current node as x/y."""
        initial = SearchStep(
            step_number=0,
            current_node=(0, 0),
            action=None,
            frontier=[(1, 0)],
            explored=[],
            current_path=[(0, 0)],
            path_cost=0.0,
        )
        data = initial.to_dict()
        assert data["stepNumber"] == 0
        node = data["currentNode"]
        assert (node["x"], node["y"]) == (0, 0)
        assert data["action"] is None