# openra-rl / tests / test_models.py
# github-actions[bot]
# Sync from GitHub ac82c3e
# 02f4a63
"""Tests for OpenRA-RL Pydantic models."""
import pytest
from openra_env.models import (
ActionType,
BuildingInfoModel,
CommandModel,
EconomyInfo,
MapInfoModel,
MilitaryInfo,
OpenRAAction,
OpenRAObservation,
OpenRAState,
ProductionInfoModel,
UnitInfoModel,
)
class TestActionType:
    """Checks on the string values and membership of the ``ActionType`` enum."""

    def test_enum_values(self):
        """A sample of members compare equal to their string literals."""
        literal_by_member = (
            (ActionType.NO_OP, "no_op"),
            (ActionType.MOVE, "move"),
            (ActionType.ATTACK, "attack"),
            (ActionType.BUILD, "build"),
            (ActionType.TRAIN, "train"),
        )
        for member, literal in literal_by_member:
            assert member == literal

    def test_enum_from_string(self):
        """Calling the enum with a string value yields the matching member."""
        for literal, member in (
            ("move", ActionType.MOVE),
            ("no_op", ActionType.NO_OP),
        ):
            assert ActionType(literal) == member

    def test_all_action_types_exist(self):
        """The set of enum values matches the expected set exactly."""
        expected = {
            "no_op",
            "move",
            "attack_move",
            "attack",
            "stop",
            "harvest",
            "build",
            "train",
            "deploy",
            "sell",
            "repair",
            "place_building",
            "cancel_production",
            "set_rally_point",
            "guard",
            "set_stance",
            "enter_transport",
            "unload",
            "power_down",
            "set_primary",
            "surrender",
        }
        actual = {member.value for member in ActionType}
        assert actual == expected
class TestCommandModel:
    """Construction and serialization checks for ``CommandModel``."""

    def test_minimal_command(self):
        """Only ``action`` is supplied; the other checked fields hold defaults."""
        command = CommandModel(action=ActionType.NO_OP)

        assert command.action == ActionType.NO_OP
        assert command.actor_id == 0
        assert command.target_x == 0
        assert command.queued is False

    def test_move_command(self):
        """A move command keeps its actor id and target coordinates."""
        fields = {
            "action": ActionType.MOVE,
            "actor_id": 42,
            "target_x": 100,
            "target_y": 200,
        }
        command = CommandModel(**fields)

        assert command.action == ActionType.MOVE
        assert command.actor_id == 42
        assert command.target_x == 100
        assert command.target_y == 200

    def test_attack_command(self):
        """An attack command keeps the id of the actor being targeted."""
        command = CommandModel(action=ActionType.ATTACK, actor_id=10, target_actor_id=99)

        assert command.target_actor_id == 99

    def test_build_command(self):
        """A build command keeps the requested item type."""
        command = CommandModel(action=ActionType.BUILD, item_type="powr")

        assert command.item_type == "powr"

    def test_serialization_roundtrip(self):
        """Rebuilding a command from its ``model_dump()`` output gives an equal model."""
        original = CommandModel(
            action=ActionType.MOVE,
            actor_id=5,
            target_x=10,
            target_y=20,
            queued=True,
        )

        rebuilt = CommandModel(**original.model_dump())

        assert rebuilt == original
class TestOpenRAAction:
    """Checks that ``OpenRAAction`` carries a list of ``CommandModel`` entries."""

    def test_empty_action(self):
        """An action built with no arguments has an empty command list."""
        empty = OpenRAAction()

        assert empty.commands == []

    def test_single_command(self):
        """A one-element command list is stored with length one."""
        no_op = CommandModel(action=ActionType.NO_OP)
        action = OpenRAAction(commands=[no_op])

        assert len(action.commands) == 1

    def test_multiple_commands(self):
        """Several commands are stored in the order they were given."""
        move = CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20)
        attack = CommandModel(action=ActionType.ATTACK, actor_id=2, target_actor_id=99)
        build = CommandModel(action=ActionType.BUILD, item_type="powr")

        action = OpenRAAction(commands=[move, attack, build])

        assert len(action.commands) == 3
        # Order matters: compare the stored action types position by position.
        stored_types = [entry.action for entry in action.commands]
        assert stored_types == [ActionType.MOVE, ActionType.ATTACK, ActionType.BUILD]

    def test_serialization_roundtrip(self):
        """An action rebuilt from ``model_dump()`` keeps its single command."""
        original = OpenRAAction(
            commands=[
                CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20),
            ]
        )

        rebuilt = OpenRAAction(**original.model_dump())

        assert len(rebuilt.commands) == 1
        assert rebuilt.commands[0].actor_id == 1
class TestEconomyInfo:
    """Default and explicit-value checks for ``EconomyInfo``."""

    def test_defaults(self):
        """Every checked field is 0 on a model built with no arguments."""
        economy = EconomyInfo()

        zero_fields = (
            "cash",
            "ore",
            "power_provided",
            "power_drained",
            "resource_capacity",
            "harvester_count",
        )
        for field_name in zero_fields:
            assert getattr(economy, field_name) == 0

    def test_with_values(self):
        """Values passed to the constructor are read back unchanged."""
        economy = EconomyInfo(
            cash=5000,
            power_provided=100,
            power_drained=60,
            harvester_count=2,
        )

        assert economy.cash == 5000
        assert economy.power_provided == 100
        assert economy.power_drained == 60
        assert economy.harvester_count == 2
class TestMilitaryInfo:
    """Default and explicit-value checks for ``MilitaryInfo``."""

    def test_defaults(self):
        """Kill, loss and army-value fields are 0 with no arguments."""
        military = MilitaryInfo()

        for field_name in ("units_killed", "units_lost", "army_value"):
            assert getattr(military, field_name) == 0

    def test_with_values(self):
        """Values passed to the constructor are read back unchanged."""
        military = MilitaryInfo(units_killed=5, units_lost=2, army_value=3000)

        assert military.units_killed == 5
        assert military.units_lost == 2
        assert military.army_value == 3000
class TestUnitInfoModel:
    """Construction checks for ``UnitInfoModel``."""

    def test_required_fields(self):
        """With only ``actor_id`` and ``type`` given, hp is 1.0 and the unit is idle."""
        soldier = UnitInfoModel(actor_id=1, type="e1")

        assert soldier.actor_id == 1
        assert soldier.type == "e1"
        assert soldier.hp_percent == 1.0
        assert soldier.is_idle is True

    def test_full_unit(self):
        """A unit built with every field keeps its hp, idle and attack flags."""
        fields = dict(
            actor_id=42,
            type="1tnk",
            pos_x=1024,
            pos_y=2048,
            cell_x=4,
            cell_y=8,
            hp_percent=0.75,
            is_idle=False,
            current_activity="Attack",
            owner="Nod",
            can_attack=True,
        )
        tank = UnitInfoModel(**fields)

        assert tank.hp_percent == 0.75
        assert tank.is_idle is False
        assert tank.can_attack is True
class TestBuildingInfoModel:
    """Construction checks for ``BuildingInfoModel``."""

    def test_required_fields(self):
        """With only ``actor_id`` and ``type`` given, the building is powered."""
        power_plant = BuildingInfoModel(actor_id=10, type="powr")

        assert power_plant.actor_id == 10
        assert power_plant.type == "powr"
        assert power_plant.is_powered is True

    def test_producing_building(self):
        """A building given production fields reports the flag and the item."""
        barracks = BuildingInfoModel(
            actor_id=20,
            type="barr",
            is_producing=True,
            production_progress=0.5,
            producing_item="e1",
        )

        assert barracks.is_producing is True
        assert barracks.producing_item == "e1"
class TestProductionInfoModel:
    """Construction checks for ``ProductionInfoModel``."""

    def test_required_fields(self):
        """With only queue type and item given, progress is 0.0 and not paused."""
        entry = ProductionInfoModel(queue_type="Infantry", item="e1")

        assert entry.queue_type == "Infantry"
        assert entry.item == "e1"
        assert entry.progress == 0.0
        assert entry.paused is False
class TestMapInfoModel:
    """Default and explicit-value checks for ``MapInfoModel``."""

    def test_defaults(self):
        """Width and height are 0 and the map name is empty with no arguments."""
        m = MapInfoModel()
        assert m.width == 0
        assert m.height == 0
        assert m.map_name == ""

    def test_with_values(self):
        """Every value passed to the constructor is read back unchanged."""
        m = MapInfoModel(width=128, height=128, map_name="Allied vs Soviet")
        assert m.width == 128
        # height is passed in above, so verify it too rather than leaving a
        # constructed value unchecked.
        assert m.height == 128
        assert m.map_name == "Allied vs Soviet"
class TestOpenRAObservation:
    """Construction and serialization checks for ``OpenRAObservation``."""

    def test_default_observation(self):
        """An observation built with no arguments is empty and not done."""
        obs = OpenRAObservation()
        assert obs.tick == 0
        assert obs.units == []
        assert obs.buildings == []
        assert obs.done is False
        assert obs.result == ""

    def test_full_observation(self):
        """A fully populated observation exposes every value it was given."""
        obs = OpenRAObservation(
            tick=150,
            economy=EconomyInfo(cash=5000, power_provided=100),
            military=MilitaryInfo(units_killed=3),
            units=[
                UnitInfoModel(actor_id=1, type="e1"),
                UnitInfoModel(actor_id=2, type="1tnk"),
            ],
            buildings=[
                BuildingInfoModel(actor_id=10, type="powr"),
            ],
            production=[
                ProductionInfoModel(queue_type="Infantry", item="e1", progress=0.5),
            ],
            visible_enemies=[
                UnitInfoModel(actor_id=99, type="e1", owner="Enemy"),
            ],
            map_info=MapInfoModel(width=128, height=128),
            available_production=["e1", "e3", "1tnk"],
            done=False,
            reward=0.5,
            result="",
        )
        assert obs.tick == 150
        assert len(obs.units) == 2
        assert len(obs.buildings) == 1
        assert len(obs.production) == 1
        assert len(obs.visible_enemies) == 1
        assert obs.economy.cash == 5000
        assert obs.available_production == ["e1", "e3", "1tnk"]
        # The nested models and the reward are supplied above; check them so a
        # dropped or mis-wired field cannot pass unnoticed.
        assert obs.economy.power_provided == 100
        assert obs.military.units_killed == 3
        assert obs.map_info.width == 128
        assert obs.map_info.height == 128
        assert obs.visible_enemies[0].owner == "Enemy"
        assert obs.reward == 0.5

    def test_terminal_observation(self):
        """A terminal observation keeps its done flag, result and reward."""
        obs = OpenRAObservation(done=True, result="win", reward=1.0)
        assert obs.done is True
        assert obs.result == "win"
        # reward is passed in above, so verify it alongside done/result.
        assert obs.reward == 1.0

    def test_serialization_roundtrip(self):
        """An observation rebuilt from ``model_dump()`` keeps the checked fields."""
        obs = OpenRAObservation(
            tick=100,
            economy=EconomyInfo(cash=3000),
            units=[UnitInfoModel(actor_id=1, type="e1")],
        )
        data = obs.model_dump()
        restored = OpenRAObservation(**data)
        assert restored.tick == 100
        assert restored.economy.cash == 3000
        assert len(restored.units) == 1
class TestOpenRAState:
    """Default, explicit-value and serialization checks for ``OpenRAState``."""

    def test_defaults(self):
        """A state built with no arguments has zero counters and a bot_normal opponent."""
        state = OpenRAState()
        assert state.game_tick == 0
        assert state.map_name == ""
        assert state.opponent_type == "bot_normal"
        assert state.step_count == 0

    def test_with_values(self):
        """Every value passed to the constructor is read back unchanged."""
        state = OpenRAState(
            episode_id="abc123",
            step_count=50,
            game_tick=500,
            map_name="Test Map",
            opponent_type="bot_hard",
        )
        assert state.episode_id == "abc123"
        assert state.step_count == 50
        assert state.game_tick == 500
        # map_name and opponent_type are passed in above; verify them so that
        # overriding the defaults is actually covered.
        assert state.map_name == "Test Map"
        assert state.opponent_type == "bot_hard"

    def test_serialization_roundtrip(self):
        """A state rebuilt from ``model_dump()`` keeps its episode id and tick."""
        state = OpenRAState(episode_id="test", game_tick=100)
        data = state.model_dump()
        restored = OpenRAState(**data)
        assert restored.episode_id == "test"
        assert restored.game_tick == 100