"""Tests for OpenRA-RL Pydantic models.""" import pytest from openra_env.models import ( ActionType, BuildingInfoModel, CommandModel, EconomyInfo, MapInfoModel, MilitaryInfo, OpenRAAction, OpenRAObservation, OpenRAState, ProductionInfoModel, UnitInfoModel, ) class TestActionType: def test_enum_values(self): assert ActionType.NO_OP == "no_op" assert ActionType.MOVE == "move" assert ActionType.ATTACK == "attack" assert ActionType.BUILD == "build" assert ActionType.TRAIN == "train" def test_enum_from_string(self): assert ActionType("move") == ActionType.MOVE assert ActionType("no_op") == ActionType.NO_OP def test_all_action_types_exist(self): expected = { "no_op", "move", "attack_move", "attack", "stop", "harvest", "build", "train", "deploy", "sell", "repair", "place_building", "cancel_production", "set_rally_point", "guard", "set_stance", "enter_transport", "unload", "power_down", "set_primary", "surrender", } actual = {a.value for a in ActionType} assert actual == expected class TestCommandModel: def test_minimal_command(self): cmd = CommandModel(action=ActionType.NO_OP) assert cmd.action == ActionType.NO_OP assert cmd.actor_id == 0 assert cmd.target_x == 0 assert cmd.queued is False def test_move_command(self): cmd = CommandModel( action=ActionType.MOVE, actor_id=42, target_x=100, target_y=200, ) assert cmd.action == ActionType.MOVE assert cmd.actor_id == 42 assert cmd.target_x == 100 assert cmd.target_y == 200 def test_attack_command(self): cmd = CommandModel( action=ActionType.ATTACK, actor_id=10, target_actor_id=99, ) assert cmd.target_actor_id == 99 def test_build_command(self): cmd = CommandModel( action=ActionType.BUILD, item_type="powr", ) assert cmd.item_type == "powr" def test_serialization_roundtrip(self): cmd = CommandModel( action=ActionType.MOVE, actor_id=5, target_x=10, target_y=20, queued=True, ) data = cmd.model_dump() restored = CommandModel(**data) assert restored == cmd class TestOpenRAAction: def test_empty_action(self): action = OpenRAAction() assert action.commands == [] def test_single_command(self): action = OpenRAAction( commands=[CommandModel(action=ActionType.NO_OP)] ) assert len(action.commands) == 1 def test_multiple_commands(self): action = OpenRAAction( commands=[ CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20), CommandModel(action=ActionType.ATTACK, actor_id=2, target_actor_id=99), CommandModel(action=ActionType.BUILD, item_type="powr"), ] ) assert len(action.commands) == 3 assert action.commands[0].action == ActionType.MOVE assert action.commands[1].action == ActionType.ATTACK assert action.commands[2].action == ActionType.BUILD def test_serialization_roundtrip(self): action = OpenRAAction( commands=[ CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20), ] ) data = action.model_dump() restored = OpenRAAction(**data) assert len(restored.commands) == 1 assert restored.commands[0].actor_id == 1 class TestEconomyInfo: def test_defaults(self): eco = EconomyInfo() assert eco.cash == 0 assert eco.ore == 0 assert eco.power_provided == 0 assert eco.power_drained == 0 assert eco.resource_capacity == 0 assert eco.harvester_count == 0 def test_with_values(self): eco = EconomyInfo(cash=5000, power_provided=100, power_drained=60, harvester_count=2) assert eco.cash == 5000 assert eco.power_provided == 100 assert eco.power_drained == 60 assert eco.harvester_count == 2 class TestMilitaryInfo: def test_defaults(self): mil = MilitaryInfo() assert mil.units_killed == 0 assert mil.units_lost == 0 assert mil.army_value == 0 def test_with_values(self): mil = MilitaryInfo(units_killed=5, units_lost=2, army_value=3000) assert mil.units_killed == 5 assert mil.units_lost == 2 assert mil.army_value == 3000 class TestUnitInfoModel: def test_required_fields(self): unit = UnitInfoModel(actor_id=1, type="e1") assert unit.actor_id == 1 assert unit.type == "e1" assert unit.hp_percent == 1.0 assert unit.is_idle is True def test_full_unit(self): unit = UnitInfoModel( actor_id=42, type="1tnk", pos_x=1024, pos_y=2048, cell_x=4, cell_y=8, hp_percent=0.75, is_idle=False, current_activity="Attack", owner="Nod", can_attack=True, ) assert unit.hp_percent == 0.75 assert unit.is_idle is False assert unit.can_attack is True class TestBuildingInfoModel: def test_required_fields(self): bldg = BuildingInfoModel(actor_id=10, type="powr") assert bldg.actor_id == 10 assert bldg.type == "powr" assert bldg.is_powered is True def test_producing_building(self): bldg = BuildingInfoModel( actor_id=20, type="barr", is_producing=True, production_progress=0.5, producing_item="e1", ) assert bldg.is_producing is True assert bldg.producing_item == "e1" class TestProductionInfoModel: def test_required_fields(self): prod = ProductionInfoModel(queue_type="Infantry", item="e1") assert prod.queue_type == "Infantry" assert prod.item == "e1" assert prod.progress == 0.0 assert prod.paused is False class TestMapInfoModel: def test_defaults(self): m = MapInfoModel() assert m.width == 0 assert m.height == 0 assert m.map_name == "" def test_with_values(self): m = MapInfoModel(width=128, height=128, map_name="Allied vs Soviet") assert m.width == 128 assert m.map_name == "Allied vs Soviet" class TestOpenRAObservation: def test_default_observation(self): obs = OpenRAObservation() assert obs.tick == 0 assert obs.units == [] assert obs.buildings == [] assert obs.done is False assert obs.result == "" def test_full_observation(self): obs = OpenRAObservation( tick=150, economy=EconomyInfo(cash=5000, power_provided=100), military=MilitaryInfo(units_killed=3), units=[ UnitInfoModel(actor_id=1, type="e1"), UnitInfoModel(actor_id=2, type="1tnk"), ], buildings=[ BuildingInfoModel(actor_id=10, type="powr"), ], production=[ ProductionInfoModel(queue_type="Infantry", item="e1", progress=0.5), ], visible_enemies=[ UnitInfoModel(actor_id=99, type="e1", owner="Enemy"), ], map_info=MapInfoModel(width=128, height=128), available_production=["e1", "e3", "1tnk"], done=False, reward=0.5, result="", ) assert obs.tick == 150 assert len(obs.units) == 2 assert len(obs.buildings) == 1 assert len(obs.production) == 1 assert len(obs.visible_enemies) == 1 assert obs.economy.cash == 5000 assert obs.available_production == ["e1", "e3", "1tnk"] def test_terminal_observation(self): obs = OpenRAObservation(done=True, result="win", reward=1.0) assert obs.done is True assert obs.result == "win" def test_serialization_roundtrip(self): obs = OpenRAObservation( tick=100, economy=EconomyInfo(cash=3000), units=[UnitInfoModel(actor_id=1, type="e1")], ) data = obs.model_dump() restored = OpenRAObservation(**data) assert restored.tick == 100 assert restored.economy.cash == 3000 assert len(restored.units) == 1 class TestOpenRAState: def test_defaults(self): state = OpenRAState() assert state.game_tick == 0 assert state.map_name == "" assert state.opponent_type == "bot_normal" assert state.step_count == 0 def test_with_values(self): state = OpenRAState( episode_id="abc123", step_count=50, game_tick=500, map_name="Test Map", opponent_type="bot_hard", ) assert state.episode_id == "abc123" assert state.step_count == 50 assert state.game_tick == 500 def test_serialization_roundtrip(self): state = OpenRAState(episode_id="test", game_tick=100) data = state.model_dump() restored = OpenRAState(**data) assert restored.episode_id == "test" assert restored.game_tick == 100