Spaces:
Sleeping
Sleeping
| """Tests for services (parser, grid_generator, metrics).""" | |
| import pytest | |
| from app.services.parser import ( | |
| parse_initial_state, | |
| parse_traffic, | |
| parse_full_state, | |
| format_initial_state, | |
| format_traffic, | |
| ) | |
| from app.services.grid_generator import gen_grid | |
| from app.services.metrics import MetricsCollector, measure_performance | |
| from app.models.entities import Store, Destination, Tunnel | |
class TestParseInitialState:
    """Exercise ``parse_initial_state`` against hand-written state strings.

    The fixture strings appear to follow the layout
    ``width;height;n_destinations;n_stores;destinations;tunnels;stores``
    (inferred from the counts and positions asserted below).
    """

    def test_parse_basic(self):
        """A minimal state yields the expected dimensions and entity counts."""
        parsed = parse_initial_state("5;5;2;1;1,2,3,4;;0,0")
        width, height, stores, destinations, tunnels = parsed

        assert (width, height) == (5, 5)
        assert (len(stores), len(destinations), len(tunnels)) == (1, 2, 0)

    def test_parse_with_tunnels(self):
        """Two tunnels are parsed with their entrance pairs in order."""
        raw = "10;10;3;2;1,1,5,5,8,8;0,0,9,9,2,2,7,7;0,0,9,9"
        width, height, stores, destinations, tunnels = parse_initial_state(raw)

        assert (width, height) == (10, 10)
        assert (len(stores), len(destinations), len(tunnels)) == (2, 3, 2)

        # Each tunnel consumes four consecutive numbers: x1,y1,x2,y2.
        first, second = tunnels[0], tunnels[1]
        assert (first.entrance1, first.entrance2) == ((0, 0), (9, 9))
        assert (second.entrance1, second.entrance2) == ((2, 2), (7, 7))

    def test_parse_store_positions(self):
        """Store coordinates come from the last field, in order."""
        _, _, stores, _, _ = parse_initial_state("5;5;1;2;2,2;;0,0,4,4")

        assert len(stores) == 2
        first_store, second_store = stores[0], stores[1]
        assert first_store.position == (0, 0)
        assert second_store.position == (4, 4)

    def test_parse_destination_positions(self):
        """Destination coordinates are read pairwise, in order."""
        _, _, _, destinations, _ = parse_initial_state("5;5;3;1;1,1,2,2,3,3;;0,0")

        assert len(destinations) == 3
        for index, expected in enumerate([(1, 1), (2, 2), (3, 3)]):
            assert destinations[index].position == expected

    def test_parse_empty_tunnels(self):
        """An empty tunnel field produces no tunnels."""
        *_, tunnels = parse_initial_state("3;3;1;1;1,1;;0,0")
        assert len(tunnels) == 0

    def test_parse_empty_destinations(self):
        """An empty destination field produces no destinations."""
        _, _, _, destinations, _ = parse_initial_state("3;3;0;1;;;0,0")
        assert len(destinations) == 0
class TestParseTraffic:
    """Exercise ``parse_traffic`` on explicit, blocked and empty inputs."""

    def test_parse_basic_traffic(self):
        """Each ``x1,y1,x2,y2,level`` entry is retrievable via ``get_traffic``."""
        grid = parse_traffic("0,0,1,0,2;0,0,0,1,3;1,0,1,1,1", 3, 3)

        assert (grid.width, grid.height) == (3, 3)

        # (source, target) -> traffic level declared in the string above.
        expected_levels = [
            ((0, 0), (1, 0), 2),
            ((0, 0), (0, 1), 3),
            ((1, 0), (1, 1), 1),
        ]
        for source, target, level in expected_levels:
            assert grid.get_traffic(source, target) == level

    def test_parse_blocked_segment(self):
        """A segment with traffic 0 is reported as blocked; traffic 1 is not."""
        grid = parse_traffic("0,0,1,0,0;0,0,0,1,1", 2, 2)

        origin = (0, 0)
        assert grid.is_blocked(origin, (1, 0)) is True
        assert grid.is_blocked(origin, (0, 1)) is False

    def test_parse_empty_traffic(self):
        """An empty traffic string still yields a grid with traffic level 1."""
        grid = parse_traffic("", 3, 3)

        assert (grid.width, grid.height) == (3, 3)

        # With nothing specified, the segments out of the origin read as 1.
        for target in [(1, 0), (0, 1)]:
            assert grid.get_traffic((0, 0), target) == 1
class TestParseFullState:
    """Exercise ``parse_full_state``, which combines both input strings."""

    def test_parse_full_state(self):
        """The combined state exposes the grid size and all entity lists."""
        state = parse_full_state(
            "5;5;2;1;1,1,3,3;;0,0",
            "0,0,1,0,2;0,0,0,1,1",
        )

        grid = state.grid
        assert (grid.width, grid.height) == (5, 5)

        entity_counts = (
            len(state.stores),
            len(state.destinations),
            len(state.tunnels),
        )
        assert entity_counts == (1, 2, 0)
class TestFormatInitialState:
    """Exercise ``format_initial_state`` and its round trip through the parser."""

    def test_format_basic(self):
        """One store, one destination and no tunnels leave the tunnel field empty."""
        result = format_initial_state(
            5,
            5,
            [Store(id=1, position=(0, 0))],
            [Destination(id=1, position=(2, 2))],
            [],
        )
        assert result == "5;5;1;1;2,2;;0,0"

    def test_format_with_tunnels(self):
        """Tunnel entrances are emitted between destinations and stores."""
        store_list = [
            Store(id=1, position=(0, 0)),
            Store(id=2, position=(4, 4)),
        ]
        destination_list = [Destination(id=1, position=(2, 2))]
        tunnel_list = [Tunnel(entrance1=(1, 1), entrance2=(3, 3))]

        result = format_initial_state(5, 5, store_list, destination_list, tunnel_list)

        assert result == "5;5;1;2;2,2;1,1,3,3;0,0,4,4"

    def test_format_roundtrip(self):
        """Parsing a formatted state recovers the size, counts and store positions."""
        store_positions = [(0, 0), (4, 4)]
        stores = [
            Store(id=number, position=position)
            for number, position in enumerate(store_positions, start=1)
        ]
        destinations = [
            Destination(id=1, position=(1, 1)),
            Destination(id=2, position=(3, 3)),
        ]
        tunnels = [Tunnel(entrance1=(0, 4), entrance2=(4, 0))]

        formatted = format_initial_state(5, 5, stores, destinations, tunnels)
        roundtrip = parse_initial_state(formatted)
        width, height, parsed_stores, parsed_dests, parsed_tunnels = roundtrip

        assert (width, height) == (5, 5)
        assert len(parsed_stores) == 2
        assert len(parsed_dests) == 2
        assert len(parsed_tunnels) == 1
        for index, expected in enumerate(store_positions):
            assert parsed_stores[index].position == expected
class TestFormatTraffic:
    """Check ``format_traffic`` output and that ``parse_traffic`` accepts it."""

    def test_format_traffic(self, simple_grid):
        """Formatted traffic is ';'-delimited and parses back into a 3x3 grid."""
        serialized = format_traffic(simple_grid)

        # A separator is only required when there is more than one segment.
        assert ";" in serialized or len(simple_grid.segments) <= 1

        # Feed the string back through the parser and verify the dimensions.
        reparsed = parse_traffic(serialized, 3, 3)
        assert (reparsed.width, reparsed.height) == (3, 3)
class TestGridGenerator:
    """Tests for ``gen_grid``: sizes, limits, determinism and connectivity."""

    def test_gen_grid_basic(self):
        """Basic generation honours the requested size and entity counts."""
        initial_state, traffic, state = gen_grid(
            width=5,
            height=5,
            num_stores=1,
            num_destinations=2,
            obstacle_density=0.0,
            seed=42,
        )
        assert state.grid.width == 5
        assert state.grid.height == 5
        assert len(state.stores) == 1
        assert len(state.destinations) == 2
        # The tunnel count is deliberately not asserted: tunnels may be
        # generated randomly since 0 is treated as falsy.

    def test_gen_grid_with_tunnels(self):
        """Requesting tunnels yields at most that many tunnels."""
        initial_state, traffic, state = gen_grid(
            width=10,
            height=10,
            num_stores=2,
            num_destinations=3,
            num_tunnels=2,
            obstacle_density=0.1,
            seed=42,
        )
        assert len(state.stores) == 2
        assert len(state.destinations) == 3
        # Tunnels might be fewer if generation fails to find valid positions.
        assert len(state.tunnels) <= 2

    def test_gen_grid_reproducible(self):
        """The same seed produces the same initial state and traffic."""
        result1 = gen_grid(width=5, height=5, seed=12345)
        result2 = gen_grid(width=5, height=5, seed=12345)
        assert result1[0] == result2[0]  # Same initial_state
        assert result1[1] == result2[1]  # Same traffic

    def test_gen_grid_stores_limited(self):
        """Stores are limited to a maximum of 3."""
        _, _, state = gen_grid(
            width=10,
            height=10,
            num_stores=10,  # Request 10, should get max 3
            num_destinations=1,
            seed=42,
        )
        assert len(state.stores) <= 3

    def test_gen_grid_destinations_limited(self):
        """Destinations are limited to a maximum of 10."""
        _, _, state = gen_grid(
            width=10,
            height=10,
            num_stores=1,
            num_destinations=20,  # Request 20, should get max 10
            seed=42,
        )
        assert len(state.destinations) <= 10

    def test_gen_grid_connectivity(self):
        """The destination is reachable from the store on a generated grid."""
        # Local import keeps this block self-contained; deque gives O(1)
        # dequeues where list.pop(0) shifts every remaining element.
        from collections import deque

        _, _, state = gen_grid(
            width=5,
            height=5,
            num_stores=1,
            num_destinations=1,
            obstacle_density=0.2,
            seed=42,
        )
        # Reachability is expected to be guaranteed by _ensure_connectivity
        # in grid_generator (per the original note; not visible from here).
        store_pos = state.stores[0].position
        dest_pos = state.destinations[0].position

        # Breadth-first search outward from the store.
        visited = {store_pos}
        queue = deque([store_pos])
        while queue:
            current = queue.popleft()
            if current == dest_pos:
                break
            for neighbor in state.grid.get_neighbors(current):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

        assert dest_pos in visited, "Destination should be reachable from store"
class TestMetricsCollector:
    """Exercise ``MetricsCollector`` and the ``measure_performance`` helper."""

    def test_metrics_collector_basic(self):
        """A start/sample/stop cycle reports a positive runtime and sane metrics."""
        metrics = MetricsCollector()

        metrics.start()
        # Small throwaway workload so some time elapses before sampling.
        _ = [i**2 for i in range(1000)]
        metrics.sample()
        metrics.stop()

        assert metrics.runtime_ms > 0
        assert metrics.memory_kb >= 0
        assert metrics.cpu_percent >= 0

    def test_metrics_collector_multiple_samples(self):
        """Three explicit samples leave at least three entries in each series."""
        sample_count = 3
        metrics = MetricsCollector()

        metrics.start()
        for _ in range(sample_count):
            metrics.sample()
        metrics.stop()

        assert len(metrics.memory_samples) >= 3
        assert len(metrics.cpu_samples) >= 3

    def test_measure_performance_context_manager(self):
        """``measure_performance`` yields a collector with a positive runtime."""
        with measure_performance() as metrics:
            # Small throwaway workload inside the measured region.
            _ = sum(range(10000))
            metrics.sample()

        # NOTE(review): checked after the block exits, presumably where the
        # collector is stopped - confirm against measure_performance.
        assert metrics.runtime_ms > 0