# Source: SearchAlgorithms/backend/tests/test_services.py
# Last commit 47bba68 by Kacemath: "feat: update with latest changes"
"""Tests for services (parser, grid_generator, metrics)."""
from collections import deque

import pytest

from app.models.entities import Store, Destination, Tunnel
from app.services.grid_generator import gen_grid
from app.services.metrics import MetricsCollector, measure_performance
from app.services.parser import (
    parse_initial_state,
    parse_traffic,
    parse_full_state,
    format_initial_state,
    format_traffic,
)
class TestParseInitialState:
    """Unit tests for ``parse_initial_state``.

    The encoded string has seven ``;``-separated fields, in the order
    ``width;height;num_destinations;num_stores;destination_coords;tunnel_coords;store_coords``.
    """

    def test_parse_basic(self):
        """A minimal state yields the declared dimensions and entity counts."""
        parsed = parse_initial_state("5;5;2;1;1,2,3,4;;0,0")
        width, height, stores, destinations, tunnels = parsed
        assert (width, height) == (5, 5)
        assert len(stores) == 1
        assert len(destinations) == 2
        assert len(tunnels) == 0

    def test_parse_with_tunnels(self):
        """Every four tunnel coordinates become one tunnel with two entrances."""
        width, height, stores, destinations, tunnels = parse_initial_state(
            "10;10;3;2;1,1,5,5,8,8;0,0,9,9,2,2,7,7;0,0,9,9"
        )
        assert (width, height) == (10, 10)
        assert len(stores) == 2
        assert len(destinations) == 3
        assert len(tunnels) == 2
        expected_entrances = [((0, 0), (9, 9)), ((2, 2), (7, 7))]
        for tunnel, (end_a, end_b) in zip(tunnels, expected_entrances):
            assert tunnel.entrance1 == end_a
            assert tunnel.entrance2 == end_b

    def test_parse_store_positions(self):
        """Store coordinates come from the last field, in order."""
        _, _, stores, _, _ = parse_initial_state("5;5;1;2;2,2;;0,0,4,4")
        assert [store.position for store in stores] == [(0, 0), (4, 4)]

    def test_parse_destination_positions(self):
        """Destination coordinates come from the fifth field, in order."""
        _, _, _, destinations, _ = parse_initial_state("5;5;3;1;1,1,2,2,3,3;;0,0")
        assert [dest.position for dest in destinations] == [(1, 1), (2, 2), (3, 3)]

    def test_parse_empty_tunnels(self):
        """An empty tunnel field produces no tunnels."""
        *_, tunnels = parse_initial_state("3;3;1;1;1,1;;0,0")
        assert len(tunnels) == 0

    def test_parse_empty_destinations(self):
        """An empty destination field produces no destinations."""
        _, _, _, destinations, _ = parse_initial_state("3;3;0;1;;;0,0")
        assert len(destinations) == 0
class TestParseTraffic:
    """Unit tests for ``parse_traffic``.

    Each ``;``-separated entry is ``x1,y1,x2,y2,traffic``. A traffic value of 0
    marks the segment as blocked.
    """

    def test_parse_basic_traffic(self):
        """Each encoded segment's traffic level can be read back from the grid."""
        grid = parse_traffic("0,0,1,0,2;0,0,0,1,3;1,0,1,1,1", 3, 3)
        assert (grid.width, grid.height) == (3, 3)
        expected_levels = {
            ((0, 0), (1, 0)): 2,
            ((0, 0), (0, 1)): 3,
            ((1, 0), (1, 1)): 1,
        }
        for (start, end), level in expected_levels.items():
            assert grid.get_traffic(start, end) == level

    def test_parse_blocked_segment(self):
        """A traffic value of 0 blocks the segment; non-zero values do not."""
        grid = parse_traffic("0,0,1,0,0;0,0,0,1,1", 2, 2)
        assert grid.is_blocked((0, 0), (1, 0)) is True
        assert grid.is_blocked((0, 0), (0, 1)) is False

    def test_parse_empty_traffic(self):
        """An empty traffic string still builds a grid, using default traffic."""
        grid = parse_traffic("", 3, 3)
        assert (grid.width, grid.height) == (3, 3)
        # Segments with no explicit entry default to traffic level 1
        # (spot-checked on two edges leaving the origin).
        for neighbour in ((1, 0), (0, 1)):
            assert grid.get_traffic((0, 0), neighbour) == 1
class TestParseFullState:
    """Tests for ``parse_full_state``, which takes both the initial-state and traffic strings."""

    def test_parse_full_state(self):
        """Both the initial-state string and the traffic string must reach the result."""
        initial_state = "5;5;2;1;1,1,3,3;;0,0"
        traffic_str = "0,0,1,0,2;0,0,0,1,1"
        state = parse_full_state(initial_state, traffic_str)
        # Dimensions and entity counts come from the initial-state string.
        assert state.grid.width == 5
        assert state.grid.height == 5
        assert len(state.stores) == 1
        assert len(state.destinations) == 2
        assert len(state.tunnels) == 0
        assert state.stores[0].position == (0, 0)
        assert [d.position for d in state.destinations] == [(1, 1), (3, 3)]
        # Traffic levels must come from traffic_str. The level 2 segment
        # differs from the default of 1, so this proves traffic_str was applied.
        assert state.grid.get_traffic((0, 0), (1, 0)) == 2
        assert state.grid.get_traffic((0, 0), (0, 1)) == 1
class TestFormatInitialState:
    """Tests for ``format_initial_state`` and its round trip through ``parse_initial_state``."""

    def test_format_basic(self):
        """Test formatting a state with one store, one destination and no tunnels."""
        stores = [Store(id=1, position=(0, 0))]
        destinations = [Destination(id=1, position=(2, 2))]
        tunnels = []
        result = format_initial_state(5, 5, stores, destinations, tunnels)
        # Field order: width;height;num_destinations;num_stores;
        # destination_coords;tunnel_coords;store_coords
        assert result == "5;5;1;1;2,2;;0,0"

    def test_format_with_tunnels(self):
        """Test formatting a state with a tunnel (entrance1 coords, then entrance2)."""
        stores = [Store(id=1, position=(0, 0)), Store(id=2, position=(4, 4))]
        destinations = [Destination(id=1, position=(2, 2))]
        tunnels = [Tunnel(entrance1=(1, 1), entrance2=(3, 3))]
        result = format_initial_state(5, 5, stores, destinations, tunnels)
        assert result == "5;5;1;2;2,2;1,1,3,3;0,0,4,4"

    def test_format_roundtrip(self):
        """Parsing a formatted state recovers every field, and re-formatting is stable."""
        stores = [Store(id=1, position=(0, 0)), Store(id=2, position=(4, 4))]
        destinations = [
            Destination(id=1, position=(1, 1)),
            Destination(id=2, position=(3, 3)),
        ]
        tunnels = [Tunnel(entrance1=(0, 4), entrance2=(4, 0))]
        formatted = format_initial_state(5, 5, stores, destinations, tunnels)
        width, height, parsed_stores, parsed_dests, parsed_tunnels = parse_initial_state(
            formatted
        )
        assert width == 5
        assert height == 5
        assert len(parsed_stores) == 2
        assert len(parsed_dests) == 2
        assert len(parsed_tunnels) == 1
        assert parsed_stores[0].position == (0, 0)
        assert parsed_stores[1].position == (4, 4)
        # Destinations and tunnel entrances must survive the round trip too.
        assert [d.position for d in parsed_dests] == [(1, 1), (3, 3)]
        assert parsed_tunnels[0].entrance1 == (0, 4)
        assert parsed_tunnels[0].entrance2 == (4, 0)
        # format(parse(s)) == s is the direct "inverse" property.
        reformatted = format_initial_state(
            width, height, parsed_stores, parsed_dests, parsed_tunnels
        )
        assert reformatted == formatted
class TestFormatTraffic:
    """Tests for ``format_traffic``."""

    def test_format_traffic(self, simple_grid):
        """Formatted traffic parses back to a grid with the same size and segment levels."""
        result = format_traffic(simple_grid)
        # Multiple segments are joined with ';'.
        assert ";" in result or len(simple_grid.segments) <= 1
        # Parse back with the fixture's own dimensions rather than assuming 3x3.
        parsed_grid = parse_traffic(result, simple_grid.width, simple_grid.height)
        assert parsed_grid.width == simple_grid.width
        assert parsed_grid.height == simple_grid.height
        # Every emitted "x1,y1,x2,y2,traffic" entry must round-trip. Traffic 0
        # encodes a blocked segment.
        for entry in filter(None, result.split(";")):
            x1, y1, x2, y2, level = (int(value) for value in entry.split(","))
            start, end = (x1, y1), (x2, y2)
            if level == 0:
                assert parsed_grid.is_blocked(start, end)
            else:
                assert parsed_grid.get_traffic(start, end) == level
class TestGridGenerator:
    """Tests for ``gen_grid``, which returns ``(initial_state, traffic, state)``."""

    def test_gen_grid_basic(self):
        """Test basic grid generation without obstacles."""
        _, _, state = gen_grid(
            width=5,
            height=5,
            num_stores=1,
            num_destinations=2,
            obstacle_density=0.0,
            seed=42,
        )
        assert state.grid.width == 5
        assert state.grid.height == 5
        assert len(state.stores) == 1
        assert len(state.destinations) == 2
        # The tunnel count is deliberately not asserted, because num_tunnels is
        # left at gen_grid's default. NOTE(review): gen_grid reportedly treats
        # num_tunnels=0 as falsy and may add random tunnels anyway; confirm.

    def test_gen_grid_with_tunnels(self):
        """Test grid generation when tunnels are requested."""
        _, _, state = gen_grid(
            width=10,
            height=10,
            num_stores=2,
            num_destinations=3,
            num_tunnels=2,
            obstacle_density=0.1,
            seed=42,
        )
        assert len(state.stores) == 2
        assert len(state.destinations) == 3
        # There may be fewer tunnels than requested if generation fails to
        # find valid positions.
        assert len(state.tunnels) <= 2

    def test_gen_grid_reproducible(self):
        """Test that the same seed produces the same grid."""
        result1 = gen_grid(width=5, height=5, seed=12345)
        result2 = gen_grid(width=5, height=5, seed=12345)
        assert result1[0] == result2[0]  # Same initial_state
        assert result1[1] == result2[1]  # Same traffic

    def test_gen_grid_stores_limited(self):
        """Test that stores are limited to max 3."""
        _, _, state = gen_grid(
            width=10,
            height=10,
            num_stores=10,  # Request 10, should get max 3
            num_destinations=1,
            seed=42,
        )
        assert len(state.stores) <= 3

    def test_gen_grid_destinations_limited(self):
        """Test that destinations are limited to max 10."""
        _, _, state = gen_grid(
            width=10,
            height=10,
            num_stores=1,
            num_destinations=20,  # Request 20, should get max 10
            seed=42,
        )
        assert len(state.destinations) <= 10

    def test_gen_grid_connectivity(self):
        """Test that the destination is reachable from the store despite obstacles."""
        _, _, state = gen_grid(
            width=5,
            height=5,
            num_stores=1,
            num_destinations=1,
            obstacle_density=0.2,
            seed=42,
        )
        # gen_grid is expected to guarantee this (via _ensure_connectivity in
        # grid_generator).
        store_pos = state.stores[0].position
        dest_pos = state.destinations[0].position
        # Breadth-first search from the store. A deque gives O(1) pops from the
        # left, unlike list.pop(0). Assumes get_neighbors yields only positions
        # reachable through non-blocked segments; confirm.
        visited = {store_pos}
        queue = deque([store_pos])
        while queue:
            current = queue.popleft()
            if current == dest_pos:
                break
            for neighbor in state.grid.get_neighbors(current):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        assert dest_pos in visited, "Destination should be reachable from store"
class TestMetricsCollector:
    """Tests for ``MetricsCollector`` and the ``measure_performance`` context manager."""

    def test_metrics_collector_basic(self):
        """A start/sample/stop cycle records positive runtime and non-negative usage."""
        collector = MetricsCollector()
        collector.start()
        # Burn a little CPU between start() and sample().
        _ = [value**2 for value in range(1000)]
        collector.sample()
        collector.stop()
        assert collector.runtime_ms > 0
        for reading in (collector.memory_kb, collector.cpu_percent):
            assert reading >= 0

    def test_metrics_collector_multiple_samples(self):
        """Each call to sample() adds at least one memory and one CPU sample."""
        sample_count = 3
        collector = MetricsCollector()
        collector.start()
        for _ in range(sample_count):
            collector.sample()
        collector.stop()
        # Only a lower bound is checked, because start()/stop() may record
        # extra samples of their own.
        assert len(collector.memory_samples) >= sample_count
        assert len(collector.cpu_samples) >= sample_count

    def test_measure_performance_context_manager(self):
        """measure_performance yields a collector that reports runtime after the block."""
        with measure_performance() as metrics:
            _ = sum(range(10000))
            metrics.sample()
        # Runtime is checked after the context manager has exited.
        assert metrics.runtime_ms > 0