"""Single-agent engine for a small Among Us-style social deduction game."""

from __future__ import annotations

import re
from collections import Counter
from typing import Iterable, Optional

from .models import (
    Action,
    CallMeeting,
    Claim,
    ClaimKind,
    CompleteTask,
    GameConfig,
    Kill,
    Move,
    Observation,
    PassMeeting,
    Phase,
    PlayerRole,
    PlayerState,
    ReportBody,
    Speak,
    TaskState,
    Vote,
    Winner,
)

# Undirected adjacency map: a player may move from a room to any listed room.
ROOM_GRAPH = {
    "Cafeteria": {"Electrical", "MedBay", "Admin"},
    "Electrical": {"Cafeteria", "Storage", "Security"},
    "MedBay": {"Cafeteria", "Security"},
    "Admin": {"Cafeteria", "Storage"},
    "Storage": {"Admin", "Electrical", "Navigation"},
    "Navigation": {"Storage"},
    "Security": {"Electrical", "MedBay"},
}

# Template task list; every crewmate receives its own copies on reset().
DEFAULT_TASKS = [
    TaskState(room="Electrical", name="Fix Wiring"),
    TaskState(room="MedBay", name="Submit Scan"),
    TaskState(room="Admin", name="Swipe Card"),
]

# Claim parsing helpers, built once instead of on every Speak action.
_ROOM_LOOKUP = {known_room.lower(): known_room for known_room in ROOM_GRAPH}
_ROOM_PATTERN = "|".join(re.escape(room) for room in ROOM_GRAPH)
_SELF_LOCATION_RE = re.compile(
    rf"\s*i\s+was\s+in\s+({_ROOM_PATTERN})\s*",
    flags=re.IGNORECASE,
)
_SAW_PLAYER_RE = re.compile(
    rf"\s*i\s+saw\s+([A-Za-z0-9_-]+)\s+in\s+({_ROOM_PATTERN})\s*",
    flags=re.IGNORECASE,
)


class AmongUsEngine:
    """Turn-based game state machine driven by one controlled player.

    Only the controlled player acts through :meth:`step`; the other players
    are passive except for the scripted ballots produced by ``_bot_votes``.
    ``reset()`` must be called before ``observe()`` or ``step()`` because the
    player table is empty until then.
    """

    def __init__(
        self,
        seed: int = 0,
        controlled_player_id: str = "red",
        player_ids: Optional[list[str]] = None,
        impostor_ids: Optional[Iterable[str]] = None,
    ) -> None:
        """Store the match configuration and initialise empty match state.

        Args:
            seed: Stored on the ``GameConfig``; not otherwise used here.
            controlled_player_id: Id of the player that ``step`` acts for.
            player_ids: All player ids; defaults to red/blue/green/yellow.
            impostor_ids: Ids given the impostor role; defaults to ``blue``.
        """
        self.config = GameConfig(
            seed=seed,
            controlled_player_id=controlled_player_id,
            player_ids=player_ids or ["red", "blue", "green", "yellow"],
        )
        self.impostor_ids = set(impostor_ids or ["blue"])
        self.players: dict[str, PlayerState] = {}
        self.tasks_by_player: dict[str, list[TaskState]] = {}
        self.location_history: dict[str, list[str]] = {}
        self.discussion_log: list[str] = []
        self.claims: list[Claim] = []
        self.dead_bodies: set[str] = set()
        self.message_log: list[str] = []
        self.phase = Phase.TASKS
        self.winner: Optional[Winner] = None
        self.done = False
        self.kill_cooldowns: dict[str, int] = {}
        self.last_reward = 0.0
        self.voting_open = False
        self.meeting_turns_remaining = 0

    @property
    def controlled_player(self) -> PlayerState:
        """Return the state of the player that actions are applied to."""
        return self.players[self.config.controlled_player_id]

    def reset(self) -> Observation:
        """Start a fresh match and return the initial observation.

        Every player starts alive in the Cafeteria; crewmates receive copies
        of ``DEFAULT_TASKS`` and impostors start with no kill cooldown.
        """
        self.players = {}
        self.tasks_by_player = {}
        self.location_history = {}
        self.discussion_log = []
        self.claims = []
        self.dead_bodies = set()
        self.message_log = ["Match reset"]
        self.phase = Phase.TASKS
        self.winner = None
        self.done = False
        self.last_reward = 0.0
        self.voting_open = False
        self.meeting_turns_remaining = 0
        self.kill_cooldowns = {player_id: 0 for player_id in self.impostor_ids}

        for player_id in self.config.player_ids:
            role = (
                PlayerRole.IMPOSTOR
                if player_id in self.impostor_ids
                else PlayerRole.CREWMATE
            )
            self.players[player_id] = PlayerState(
                player_id=player_id,
                role=role,
                location="Cafeteria",
            )
            self.location_history[player_id] = ["Cafeteria"]
            if role is PlayerRole.CREWMATE:
                # Copy so completing a task never mutates the shared template.
                self.tasks_by_player[player_id] = [
                    task.model_copy() for task in DEFAULT_TASKS
                ]
        return self.observe()

    def observe(self, reward: Optional[float] = None) -> Observation:
        """Build the controlled player's view of the current state.

        Args:
            reward: Reward to report; falls back to ``last_reward`` if None.

        Returns:
            An ``Observation`` listing living, non-ejected players sharing
            the controlled player's room plus the last 20 entries of each log.
        """
        player = self.controlled_player
        visible_players = [
            other.player_id
            for other in self.players.values()
            if other.player_id != player.player_id
            and other.alive
            and not other.ejected
            and other.location == player.location
        ]
        return Observation(
            role=player.role,
            location=player.location,
            visible_players=sorted(visible_players),
            task_list=self.tasks_by_player.get(player.player_id, []),
            message_log=self.message_log[-20:],
            discussion_log=self.discussion_log[-20:],
            claims=self.claims[-20:],
            phase=self.phase,
            reward=self.last_reward if reward is None else reward,
            done=self.done,
            winner=self.winner,
            voting_open=self.voting_open,
            meeting_turns_remaining=self.meeting_turns_remaining,
        )

    def step(self, action: Action) -> Observation:
        """Apply one action for the controlled player.

        Illegal or unsupported actions are logged and cost -0.1. After the
        action is handled, win conditions are evaluated and any terminal
        reward is added to the action's reward.
        """
        if self.done:
            return self._illegal("Game is already complete")

        if isinstance(action, Move):
            reward = self._move(action.room)
        elif isinstance(action, CompleteTask):
            reward = self._complete_task()
        elif isinstance(action, Kill):
            reward = self._kill(action.target_id)
        elif isinstance(action, ReportBody):
            reward = self._report_body()
        elif isinstance(action, CallMeeting):
            reward = self._call_meeting()
        elif isinstance(action, Vote):
            reward = self._vote(action.target_id)
        elif isinstance(action, Speak):
            reward = self._speak(action.message)
        elif isinstance(action, PassMeeting):
            reward = self._pass_meeting()
        else:
            reward = self._illegal("Unsupported action").reward

        if not self.done:
            reward += self._check_win_conditions()
        self.last_reward = reward
        return self.observe(reward=reward)

    def _move(self, room: str) -> float:
        """Move to an adjacent room and tick kill cooldowns; reward 0.0."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot move during meeting").reward
        if not player.alive or player.ejected:
            return self._illegal("Eliminated players cannot move").reward
        if room not in ROOM_GRAPH[player.location]:
            return self._illegal(
                f"Invalid move from {player.location} to {room}"
            ).reward
        player.location = room
        self.location_history[player.player_id].append(room)
        self._tick_cooldowns()
        self.message_log.append(f"Moved to {room}")
        return 0.0

    def _complete_task(self) -> float:
        """Complete the first unfinished task in the current room (+0.2)."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot complete tasks during meeting").reward
        if player.role is not PlayerRole.CREWMATE:
            return self._illegal("Impostors cannot complete real tasks").reward
        for task in self.tasks_by_player.get(player.player_id, []):
            if task.room == player.location and not task.completed:
                task.completed = True
                self.message_log.append(f"Completed task {task.name}")
                return 0.2
        return self._illegal(f"No incomplete task in {player.location}").reward

    def _kill(self, target_id: str) -> float:
        """Kill a living crewmate in the same room (+0.5, cooldown of 2)."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot kill during meeting").reward
        # Mirror _move: an eliminated impostor must not keep acting.
        if not player.alive or player.ejected:
            return self._illegal("Eliminated players cannot kill").reward
        if player.role is not PlayerRole.IMPOSTOR:
            return self._illegal("Crewmates cannot kill").reward
        if self.kill_cooldowns.get(player.player_id, 0) > 0:
            return self._illegal("Kill is on cooldown").reward
        target = self.players.get(target_id)
        if (
            target is None
            or target.role is PlayerRole.IMPOSTOR
            or not target.alive
            or target.ejected
            or target.location != player.location
        ):
            return self._illegal(f"Cannot kill {target_id}").reward
        target.alive = False
        self.dead_bodies.add(target_id)
        self.kill_cooldowns[player.player_id] = 2
        self.message_log.append(f"Killed {target_id}")
        return 0.5

    def _report_body(self) -> float:
        """Start a meeting if an unreported body is in the current room."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("No reportable body here").reward
        # Mirror _move: eliminated players cannot trigger meetings.
        if not player.alive or player.ejected:
            return self._illegal(
                "Eliminated players cannot report bodies"
            ).reward
        body_here = any(
            self.players[player_id].location == player.location
            for player_id in self.dead_bodies
        )
        if not body_here:
            return self._illegal("No reportable body here").reward
        self.phase = Phase.MEETING
        self._start_meeting_protocol()
        # All outstanding bodies are cleared once any one is reported.
        self.dead_bodies.clear()
        self.message_log.append("Reported body")
        return 0.0

    def _call_meeting(self) -> float:
        """Start an emergency meeting from the task phase; reward 0.0."""
        if self.phase is not Phase.TASKS:
            return self._illegal("Meeting already active").reward
        player = self.controlled_player
        # Mirror _move: eliminated players cannot trigger meetings.
        if not player.alive or player.ejected:
            return self._illegal(
                "Eliminated players cannot call meetings"
            ).reward
        self.phase = Phase.MEETING
        self._start_meeting_protocol()
        self.message_log.append("Emergency meeting called")
        return 0.0

    def _start_meeting_protocol(self) -> None:
        """Open the discussion stage: one turn before voting becomes legal."""
        self.voting_open = False
        self.meeting_turns_remaining = 1

    def _speak(self, message: str) -> float:
        """Record a discussion message, then open voting.

        A message parsed as a self-location claim that contradicts the
        speaker's location history costs -1.0; anything else scores 0.0.
        """
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot speak outside meeting").reward
        if self.voting_open:
            return self._illegal("Voting is already open").reward
        speaker_id = self.controlled_player.player_id
        entry = f"{speaker_id}: {message}"
        self.discussion_log.append(entry)
        self.message_log.append(entry)

        reward = 0.0
        claim = self._parse_claim(speaker_id=speaker_id, message=message)
        if claim is not None:
            self.claims.append(claim)
            if (
                claim.kind is ClaimKind.SELF_LOCATION
                and claim.truth_value is False
            ):
                reward = -1.0
        self._open_voting()
        return reward

    def _pass_meeting(self) -> float:
        """Skip the discussion turn, logging a pass, and open voting."""
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot pass outside meeting").reward
        if self.voting_open:
            return self._illegal("Voting is already open").reward
        speaker_id = self.controlled_player.player_id
        entry = f"{speaker_id}: pass"
        self.discussion_log.append(entry)
        self.message_log.append(entry)
        self._open_voting()
        return 0.0

    def _open_voting(self) -> None:
        """End the discussion stage and allow Vote actions."""
        self.meeting_turns_remaining = 0
        self.voting_open = True

    def _parse_claim(self, speaker_id: str, message: str) -> Optional[Claim]:
        """Parse "I was in <room>" / "I saw <player> in <room>" messages.

        Matching is case-insensitive and must cover the whole message.
        ``truth_value`` records whether the room appears anywhere in the
        relevant player's location history.

        Returns:
            The parsed ``Claim``, or None when the message fits neither form.
        """
        self_location_match = _SELF_LOCATION_RE.fullmatch(message)
        if self_location_match:
            room = self._canonical_room(self_location_match.group(1))
            return Claim(
                kind=ClaimKind.SELF_LOCATION,
                speaker_id=speaker_id,
                room=room,
                truth_value=room in self.location_history.get(speaker_id, []),
            )

        saw_player_match = _SAW_PLAYER_RE.fullmatch(message)
        if saw_player_match:
            target_id = saw_player_match.group(1)
            room = self._canonical_room(saw_player_match.group(2))
            return Claim(
                kind=ClaimKind.SAW_PLAYER,
                speaker_id=speaker_id,
                target_id=target_id,
                room=room,
                truth_value=room in self.location_history.get(target_id, []),
            )
        return None

    def _canonical_room(self, room: str) -> str:
        """Map a case-insensitively matched room name to its ROOM_GRAPH key.

        Raises:
            KeyError: If ``room`` matches no known room.
        """
        canonical = _ROOM_LOOKUP.get(room.lower())
        if canonical is not None:
            return canonical
        # re.IGNORECASE accepts a few non-ASCII case variants (e.g. dotless
        # "ı" for "i", long "ſ" for "s") that str.lower() does not fold, so
        # fall back to the same matching rules the claim regex applied
        # instead of crashing on text the regex already accepted.
        for known_room in ROOM_GRAPH:
            if re.fullmatch(
                re.escape(known_room), room, flags=re.IGNORECASE
            ):
                return known_room
        raise KeyError(room.lower())

    def _vote(self, target_id: str) -> float:
        """Cast the controlled player's ballot and resolve the meeting.

        The leading candidate is ejected only with a strict majority of
        living voters and no tie. Ejecting an impostor scores +0.5 for a
        crewmate and -0.5 for an impostor; being ejected as a crewmate
        scores -0.5. The phase returns to TASKS either way.
        """
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot vote outside meeting").reward
        if not self.voting_open:
            return self._illegal(
                "Cannot vote before discussion is complete"
            ).reward
        target = self.players.get(target_id)
        if target is None or target.ejected or not target.alive:
            return self._illegal(f"Cannot vote for {target_id}").reward

        ballots = {self.controlled_player.player_id: target_id}
        ballots.update(self._bot_votes())
        active_voter_count = sum(
            1
            for player in self.players.values()
            if player.alive and not player.ejected
        )
        # ballots always holds the controlled player's vote, so the counter
        # is never empty here.
        vote_counts = Counter(ballots.values())
        top_target_id, top_votes = vote_counts.most_common(1)[0]
        tied = list(vote_counts.values()).count(top_votes) > 1
        has_majority = top_votes > active_voter_count / 2
        if tied or not has_majority:
            return self._no_majority()

        target = self.players[top_target_id]
        target.ejected = True
        target.alive = False
        self.message_log.append(f"Ejected {top_target_id}")

        reward = 0.0
        controlled = self.controlled_player
        if target.role is PlayerRole.IMPOSTOR:
            reward += 0.5 if controlled.role is PlayerRole.CREWMATE else -0.5
        elif target.player_id == controlled.player_id:
            reward -= 0.5
        self.phase = Phase.TASKS
        self._reset_meeting_protocol()
        return reward

    def _bot_votes(self) -> dict[str, str]:
        """Return the scripted ballots of the non-controlled players.

        Every other living player votes for the most recent speaker of a
        false self-location claim, provided that speaker is still in the
        game; otherwise no bot votes.
        """
        false_speaker_id = self._latest_false_self_location_speaker()
        if false_speaker_id is None:
            return {}
        target = self.players.get(false_speaker_id)
        if target is None or not target.alive or target.ejected:
            return {}
        return {
            player.player_id: false_speaker_id
            for player in self.players.values()
            if player.player_id != self.controlled_player.player_id
            and player.alive
            and not player.ejected
        }

    def _latest_false_self_location_speaker(self) -> Optional[str]:
        """Return who made the newest false self-location claim, if any."""
        for claim in reversed(self.claims):
            if (
                claim.kind is ClaimKind.SELF_LOCATION
                and claim.truth_value is False
            ):
                return claim.speaker_id
        return None

    def _no_majority(self) -> float:
        """Close the meeting without an ejection; reward 0.0."""
        self.message_log.append("No majority; nobody ejected")
        self.phase = Phase.TASKS
        self._reset_meeting_protocol()
        return 0.0

    def _reset_meeting_protocol(self) -> None:
        """Clear the per-meeting discussion and voting flags."""
        self.voting_open = False
        self.meeting_turns_remaining = 0

    def _check_win_conditions(self) -> float:
        """End the game if either side has won and return the terminal reward.

        Crewmates win when no impostor remains or every assigned task is
        complete; impostors win when they are at least as numerous as the
        living crewmates. Returns 0.0 when the game continues.
        """
        alive_crewmates = [
            player
            for player in self.players.values()
            if player.role is PlayerRole.CREWMATE
            and player.alive
            and not player.ejected
        ]
        alive_impostors = [
            player
            for player in self.players.values()
            if player.role is PlayerRole.IMPOSTOR
            and player.alive
            and not player.ejected
        ]
        if not alive_impostors:
            return self._finish(Winner.CREWMATES)

        all_tasks = [
            task for tasks in self.tasks_by_player.values() for task in tasks
        ]
        # Require at least one task: all() over nothing is True, which would
        # otherwise award a task victory when no tasks were ever assigned.
        if all_tasks and all(task.completed for task in all_tasks):
            return self._finish(Winner.CREWMATES)
        if len(alive_impostors) >= len(alive_crewmates):
            return self._finish(Winner.IMPOSTORS)
        return 0.0

    def _finish(self, winner: Winner) -> float:
        """Mark the game complete; 1.0 if the controlled side won, else 0.0."""
        self.done = True
        self.phase = Phase.COMPLETE
        self._reset_meeting_protocol()
        self.winner = winner
        self.message_log.append(f"{winner.value} win")
        return 1.0 if self._controlled_side_won(winner) else 0.0

    def _controlled_side_won(self, winner: Winner) -> bool:
        """Return whether ``winner`` is the controlled player's own side."""
        role = self.controlled_player.role
        return (
            (winner is Winner.CREWMATES and role is PlayerRole.CREWMATE)
            or (winner is Winner.IMPOSTORS and role is PlayerRole.IMPOSTOR)
        )

    def _illegal(self, message: str) -> Observation:
        """Log a rejected action and return an observation with reward -0.1."""
        self.message_log.append(message)
        self.last_reward = -0.1
        return self.observe(reward=-0.1)

    def _tick_cooldowns(self) -> None:
        """Decrease every kill cooldown by one, never going below zero."""
        for player_id in self.kill_cooldowns:
            self.kill_cooldowns[player_id] = max(
                0, self.kill_cooldowns[player_id] - 1
            )