# AmongUS / src/amongus_env/engine.py
# Uploaded by 5h4dy via huggingface_hub (commit 6c3d778, verified).
from __future__ import annotations
import re
from collections import Counter
from typing import Iterable, Optional
from .models import (
Action,
CallMeeting,
Claim,
ClaimKind,
CompleteTask,
GameConfig,
Kill,
Move,
Observation,
PassMeeting,
Phase,
PlayerRole,
PlayerState,
ReportBody,
Speak,
TaskState,
Vote,
Winner,
)
# Room adjacency map: each key names a room and its value is the set of rooms
# a single Move can reach from it. Every connection is listed in both
# directions, so the map is effectively an undirected graph.
ROOM_GRAPH = dict(
    Cafeteria={"Electrical", "MedBay", "Admin"},
    Electrical={"Cafeteria", "Storage", "Security"},
    MedBay={"Cafeteria", "Security"},
    Admin={"Cafeteria", "Storage"},
    Storage={"Admin", "Electrical", "Navigation"},
    Navigation={"Storage"},
    Security={"Electrical", "MedBay"},
)
# Template task list. The engine's reset() hands every crewmate its own
# model_copy() of each entry, so these instances are never mutated in play.
DEFAULT_TASKS = [
    TaskState(room=task_room, name=task_name)
    for task_room, task_name in (
        ("Electrical", "Fix Wiring"),
        ("MedBay", "Submit Scan"),
        ("Admin", "Swipe Card"),
    )
]
class AmongUsEngine:
    """Turn-based Among Us simulation driven by a single controlled player.

    Only the controlled player (``config.controlled_player_id``) acts through
    :meth:`step`. Nothing in this class moves the other players (``reset``
    places everyone in Cafeteria), lets them kill, or completes their tasks;
    they take part only by voting (see :meth:`_bot_votes`).

    NOTE(review): because nothing here completes bot tasks, the "all tasks
    done" crewmate win in :meth:`_check_win_conditions` appears unreachable
    unless other code mutates ``tasks_by_player`` — confirm whether the task
    win should count only the controlled player's tasks.
    """

    def __init__(
        self,
        seed: int = 0,
        controlled_player_id: str = "red",
        player_ids: Optional[list[str]] = None,
        impostor_ids: Optional[Iterable[str]] = None,
    ) -> None:
        """Configure a match; call :meth:`reset` before stepping or observing.

        Args:
            seed: Stored on the ``GameConfig``; not read anywhere in this class.
            controlled_player_id: Id of the player whose actions ``step`` applies.
            player_ids: All participants; defaults to red, blue, green, yellow
                (an empty list also falls back to this default).
            impostor_ids: Ids that get the impostor role; defaults to
                ``blue`` (an empty list/tuple/set also falls back to it).
        """
        self.config = GameConfig(
            seed=seed,
            controlled_player_id=controlled_player_id,
            player_ids=player_ids or ["red", "blue", "green", "yellow"],
        )
        self.impostor_ids = set(impostor_ids or ["blue"])
        # Per-match state below is (re)initialised by reset().
        self.players: dict[str, PlayerState] = {}
        self.tasks_by_player: dict[str, list[TaskState]] = {}
        # Every room each player has been in, in order; claims are judged on it.
        self.location_history: dict[str, list[str]] = {}
        self.discussion_log: list[str] = []
        self.claims: list[Claim] = []
        # Ids of killed players whose bodies have not been reported yet.
        self.dead_bodies: set[str] = set()
        self.message_log: list[str] = []
        self.phase = Phase.TASKS
        self.winner: Optional[Winner] = None
        self.done = False
        # Moves remaining before each impostor may kill again (ticks on Move).
        self.kill_cooldowns: dict[str, int] = {}
        self.last_reward = 0.0
        self.voting_open = False
        self.meeting_turns_remaining = 0

    @property
    def controlled_player(self) -> PlayerState:
        """State of the player driven by :meth:`step` (requires :meth:`reset`)."""
        return self.players[self.config.controlled_player_id]

    def reset(self) -> Observation:
        """Start a fresh match and return the controlled player's first view.

        Everyone starts in Cafeteria, each crewmate receives private copies of
        ``DEFAULT_TASKS``, and every impostor's kill cooldown starts at 0.
        """
        self.players = {}
        self.tasks_by_player = {}
        self.location_history = {}
        self.discussion_log = []
        self.claims = []
        self.dead_bodies = set()
        self.message_log = ["Match reset"]
        self.phase = Phase.TASKS
        self.winner = None
        self.done = False
        self.last_reward = 0.0
        self.voting_open = False
        self.meeting_turns_remaining = 0
        self.kill_cooldowns = {player_id: 0 for player_id in self.impostor_ids}
        for player_id in self.config.player_ids:
            role = (
                PlayerRole.IMPOSTOR
                if player_id in self.impostor_ids
                else PlayerRole.CREWMATE
            )
            self.players[player_id] = PlayerState(
                player_id=player_id,
                role=role,
                location="Cafeteria",
            )
            self.location_history[player_id] = ["Cafeteria"]
            if role is PlayerRole.CREWMATE:
                # Copies, so completing a task never touches the template.
                self.tasks_by_player[player_id] = [
                    task.model_copy() for task in DEFAULT_TASKS
                ]
        return self.observe()

    def observe(self, reward: Optional[float] = None) -> Observation:
        """Build the controlled player's view of the game.

        Args:
            reward: Reward to report; defaults to ``last_reward``.

        Returns:
            An ``Observation`` with the other living players in the same room
            (sorted), the player's own tasks, the last 20 entries of the
            message log, discussion log and claims, plus phase/meeting/terminal
            flags.
        """
        player = self.controlled_player
        visible_players = [
            other.player_id
            for other in self.players.values()
            if other.player_id != player.player_id
            and not self._is_eliminated(other)
            and other.location == player.location
        ]
        # Hand out copies: _complete_task mutates the engine's TaskState
        # objects in place, which could otherwise retroactively change
        # observations returned earlier (e.g. stored trajectories).
        task_list = [
            task.model_copy()
            for task in self.tasks_by_player.get(player.player_id, [])
        ]
        return Observation(
            role=player.role,
            location=player.location,
            visible_players=sorted(visible_players),
            task_list=task_list,
            message_log=self.message_log[-20:],
            discussion_log=self.discussion_log[-20:],
            claims=self.claims[-20:],
            phase=self.phase,
            reward=self.last_reward if reward is None else reward,
            done=self.done,
            winner=self.winner,
            voting_open=self.voting_open,
            meeting_turns_remaining=self.meeting_turns_remaining,
        )

    def step(self, action: Action) -> Observation:
        """Apply one action for the controlled player and return the new view.

        The reward is the action handler's reward (illegal actions give -0.1)
        plus, if the game was still running, the terminal reward from
        :meth:`_check_win_conditions`. Stepping a finished game is illegal.
        """
        if self.done:
            return self._illegal("Game is already complete")
        if isinstance(action, Move):
            reward = self._move(action.room)
        elif isinstance(action, CompleteTask):
            reward = self._complete_task()
        elif isinstance(action, Kill):
            reward = self._kill(action.target_id)
        elif isinstance(action, ReportBody):
            reward = self._report_body()
        elif isinstance(action, CallMeeting):
            reward = self._call_meeting()
        elif isinstance(action, Vote):
            reward = self._vote(action.target_id)
        elif isinstance(action, Speak):
            reward = self._speak(action.message)
        elif isinstance(action, PassMeeting):
            reward = self._pass_meeting()
        else:
            reward = self._illegal("Unsupported action").reward
        if not self.done:
            reward += self._check_win_conditions()
        self.last_reward = reward
        return self.observe(reward=reward)

    @staticmethod
    def _is_eliminated(player: PlayerState) -> bool:
        """Return True if *player* has been killed or ejected."""
        return not player.alive or player.ejected

    def _move(self, room: str) -> float:
        """Move to an adjacent room during the task phase; ticks kill cooldowns."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot move during meeting").reward
        if self._is_eliminated(player):
            return self._illegal("Eliminated players cannot move").reward
        if room not in ROOM_GRAPH[player.location]:
            return self._illegal(f"Invalid move from {player.location} to {room}").reward
        player.location = room
        self.location_history[player.player_id].append(room)
        self._tick_cooldowns()
        self.message_log.append(f"Moved to {room}")
        return 0.0

    def _complete_task(self) -> float:
        """Complete the first unfinished task in the current room (+0.2).

        Unlike the other actions there is no elimination check, so an
        eliminated crewmate can still finish a task in the room it is in.
        """
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot complete tasks during meeting").reward
        if player.role is not PlayerRole.CREWMATE:
            return self._illegal("Impostors cannot complete real tasks").reward
        for task in self.tasks_by_player.get(player.player_id, []):
            if task.room == player.location and not task.completed:
                task.completed = True
                self.message_log.append(f"Completed task {task.name}")
                return 0.2
        return self._illegal(f"No incomplete task in {player.location}").reward

    def _kill(self, target_id: str) -> float:
        """Kill a living, co-located crewmate as an impostor (+0.5).

        Leaves a reportable body and puts the killer on a 2-move cooldown.
        """
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("Cannot kill during meeting").reward
        if player.role is not PlayerRole.IMPOSTOR:
            return self._illegal("Crewmates cannot kill").reward
        if self._is_eliminated(player):
            return self._illegal("Eliminated players cannot kill").reward
        if self.kill_cooldowns.get(player.player_id, 0) > 0:
            return self._illegal("Kill is on cooldown").reward
        target = self.players.get(target_id)
        if (
            target is None
            or target.role is PlayerRole.IMPOSTOR
            or self._is_eliminated(target)
            or target.location != player.location
        ):
            return self._illegal(f"Cannot kill {target_id}").reward
        target.alive = False
        self.dead_bodies.add(target_id)
        self.kill_cooldowns[player.player_id] = 2
        self.message_log.append(f"Killed {target_id}")
        return 0.5

    def _report_body(self) -> float:
        """Report an unreported body in the current room, starting a meeting."""
        player = self.controlled_player
        if self.phase is not Phase.TASKS:
            return self._illegal("No reportable body here").reward
        if self._is_eliminated(player):
            return self._illegal("Eliminated players cannot report bodies").reward
        body_here = any(
            self.players[body_id].location == player.location
            for body_id in self.dead_bodies
        )
        if not body_here:
            return self._illegal("No reportable body here").reward
        self.phase = Phase.MEETING
        self._start_meeting_protocol()
        # Every unreported body (not only those in this room) is cleared.
        self.dead_bodies.clear()
        self.message_log.append("Reported body")
        return 0.0

    def _call_meeting(self) -> float:
        """Call an emergency meeting during the task phase."""
        if self.phase is not Phase.TASKS:
            return self._illegal("Meeting already active").reward
        if self._is_eliminated(self.controlled_player):
            return self._illegal("Eliminated players cannot call meetings").reward
        self.phase = Phase.MEETING
        self._start_meeting_protocol()
        self.message_log.append("Emergency meeting called")
        return 0.0

    def _start_meeting_protocol(self) -> None:
        """Open the discussion stage: one turn of talk, voting still closed."""
        self.voting_open = False
        self.meeting_turns_remaining = 1

    def _speak(self, message: str) -> float:
        """Log a discussion message, record any parsed claim, and open voting.

        Returns -1.0 when the message is a false "I was in <room>" claim,
        0.0 otherwise (or -0.1 if speaking is illegal right now).
        """
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot speak outside meeting").reward
        if self.voting_open:
            return self._illegal("Voting is already open").reward
        speaker_id = self.controlled_player.player_id
        entry = f"{speaker_id}: {message}"
        self.discussion_log.append(entry)
        self.message_log.append(entry)
        claim = self._parse_claim(speaker_id=speaker_id, message=message)
        if claim is not None:
            self.claims.append(claim)
            if (
                claim.kind is ClaimKind.SELF_LOCATION
                and claim.truth_value is False
            ):
                self._open_voting()
                return -1.0
        self._open_voting()
        return 0.0

    def _pass_meeting(self) -> float:
        """Skip the discussion turn (logged as "pass") and open voting."""
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot pass outside meeting").reward
        if self.voting_open:
            return self._illegal("Voting is already open").reward
        speaker_id = self.controlled_player.player_id
        entry = f"{speaker_id}: pass"
        self.discussion_log.append(entry)
        self.message_log.append(entry)
        self._open_voting()
        return 0.0

    def _open_voting(self) -> None:
        """End discussion and allow votes."""
        self.meeting_turns_remaining = 0
        self.voting_open = True

    def _parse_claim(self, speaker_id: str, message: str) -> Optional[Claim]:
        """Parse a whole message as a location claim, case-insensitively.

        Recognised forms:
            "I was in <room>" -> SELF_LOCATION; true if the speaker has ever
                been in that room this match.
            "I saw <id> in <room>" -> SAW_PLAYER; true if player <id> has
                ever been in that room. The id is used verbatim, so a
                case-mismatched id is judged false.

        Returns None for any other message.
        """
        room_pattern = "|".join(re.escape(room) for room in ROOM_GRAPH)
        self_location_match = re.fullmatch(
            rf"\s*i\s+was\s+in\s+({room_pattern})\s*",
            message,
            flags=re.IGNORECASE,
        )
        if self_location_match:
            room = self._canonical_room(self_location_match.group(1))
            return Claim(
                kind=ClaimKind.SELF_LOCATION,
                speaker_id=speaker_id,
                room=room,
                truth_value=room in self.location_history.get(speaker_id, []),
            )
        saw_player_match = re.fullmatch(
            rf"\s*i\s+saw\s+([A-Za-z0-9_-]+)\s+in\s+({room_pattern})\s*",
            message,
            flags=re.IGNORECASE,
        )
        if saw_player_match:
            target_id = saw_player_match.group(1)
            room = self._canonical_room(saw_player_match.group(2))
            return Claim(
                kind=ClaimKind.SAW_PLAYER,
                speaker_id=speaker_id,
                target_id=target_id,
                room=room,
                truth_value=room in self.location_history.get(target_id, []),
            )
        return None

    def _canonical_room(self, room: str) -> str:
        """Map a case-insensitively matched room name to its ROOM_GRAPH spelling."""
        lookup = {known_room.lower(): known_room for known_room in ROOM_GRAPH}
        return lookup[room.lower()]

    def _vote(self, target_id: str) -> float:
        """Cast the controlled player's vote, add bot votes, and resolve.

        A player is ejected only if they hold a strict majority of the living,
        non-ejected players and nobody ties them for the top count; otherwise
        nobody is ejected. Either way the meeting ends and play returns to
        the task phase.

        Returns:
            When an impostor is ejected: +0.5 if the controlled player is a
            crewmate, -0.5 if an impostor. When the controlled crewmate is
            ejected: -0.5. Otherwise 0.0, or -0.1 for an illegal vote.
        """
        if self.phase is not Phase.MEETING:
            return self._illegal("Cannot vote outside meeting").reward
        if not self.voting_open:
            return self._illegal("Cannot vote before discussion is complete").reward
        controlled = self.controlled_player
        if self._is_eliminated(controlled):
            # Eliminated players are excluded from the majority denominator
            # below, so their ballot must not be counted either.
            return self._illegal("Eliminated players cannot vote").reward
        target = self.players.get(target_id)
        if target is None or self._is_eliminated(target):
            return self._illegal(f"Cannot vote for {target_id}").reward
        ballots = {controlled.player_id: target_id}
        ballots.update(self._bot_votes())
        active_voters = [
            player
            for player in self.players.values()
            if not self._is_eliminated(player)
        ]
        vote_counts = Counter(ballots.values())
        if not vote_counts:
            return self._no_majority()
        top_target_id, top_votes = vote_counts.most_common(1)[0]
        tied = list(vote_counts.values()).count(top_votes) > 1
        has_majority = top_votes > len(active_voters) / 2
        if tied or not has_majority:
            return self._no_majority()
        target = self.players[top_target_id]
        target.ejected = True
        target.alive = False
        self.message_log.append(f"Ejected {top_target_id}")
        reward = 0.0
        if target.role is PlayerRole.IMPOSTOR:
            reward += 0.5 if controlled.role is PlayerRole.CREWMATE else -0.5
        elif target.player_id == controlled.player_id:
            reward -= 0.5
        self.phase = Phase.TASKS
        self._reset_meeting_protocol()
        return reward

    def _bot_votes(self) -> dict[str, str]:
        """Return ballots from every other living player.

        Bots all vote for the speaker of the most recent false self-location
        claim while that speaker is still in play; otherwise they abstain.
        """
        false_speaker_id = self._latest_false_self_location_speaker()
        if false_speaker_id is None:
            return {}
        target = self.players.get(false_speaker_id)
        if target is None or self._is_eliminated(target):
            return {}
        return {
            player.player_id: false_speaker_id
            for player in self.players.values()
            if player.player_id != self.controlled_player.player_id
            and not self._is_eliminated(player)
        }

    def _latest_false_self_location_speaker(self) -> Optional[str]:
        """Speaker of the newest false "I was in <room>" claim, if any."""
        for claim in reversed(self.claims):
            if (
                claim.kind is ClaimKind.SELF_LOCATION
                and claim.truth_value is False
            ):
                return claim.speaker_id
        return None

    def _no_majority(self) -> float:
        """End the meeting without ejecting anyone."""
        self.message_log.append("No majority; nobody ejected")
        self.phase = Phase.TASKS
        self._reset_meeting_protocol()
        return 0.0

    def _reset_meeting_protocol(self) -> None:
        """Clear meeting flags when a meeting (or the game) ends."""
        self.voting_open = False
        self.meeting_turns_remaining = 0

    def _check_win_conditions(self) -> float:
        """Finish the game if a side has won.

        Crewmates win when no impostor remains or when every assigned task is
        complete (at least one task must exist). Impostors win when living
        impostors are at least as many as living crewmates.

        Returns:
            The terminal reward from :meth:`_finish`, or 0.0 if play goes on.
        """
        alive_crewmates = [
            player
            for player in self.players.values()
            if player.role is PlayerRole.CREWMATE and not self._is_eliminated(player)
        ]
        alive_impostors = [
            player
            for player in self.players.values()
            if player.role is PlayerRole.IMPOSTOR and not self._is_eliminated(player)
        ]
        if not alive_impostors:
            return self._finish(Winner.CREWMATES)
        all_tasks = [
            task
            for tasks in self.tasks_by_player.values()
            for task in tasks
        ]
        # all() over an empty list is True; without the emptiness guard a
        # match with no crewmate tasks would be handed to the crewmates.
        if all_tasks and all(task.completed for task in all_tasks):
            return self._finish(Winner.CREWMATES)
        if len(alive_impostors) >= len(alive_crewmates):
            return self._finish(Winner.IMPOSTORS)
        return 0.0

    def _finish(self, winner: Winner) -> float:
        """Mark the game complete; return 1.0 if the controlled side won, else 0.0."""
        self.done = True
        self.phase = Phase.COMPLETE
        self._reset_meeting_protocol()
        self.winner = winner
        self.message_log.append(f"{winner.value} win")
        return 1.0 if self._controlled_side_won(winner) else 0.0

    def _controlled_side_won(self, winner: Winner) -> bool:
        """True if *winner* is the controlled player's team."""
        role = self.controlled_player.role
        return (
            (winner is Winner.CREWMATES and role is PlayerRole.CREWMATE)
            or (winner is Winner.IMPOSTORS and role is PlayerRole.IMPOSTOR)
        )

    def _illegal(self, message: str) -> Observation:
        """Log *message*, set the -0.1 penalty, and return the current view."""
        self.message_log.append(message)
        self.last_reward = -0.1
        return self.observe(reward=-0.1)

    def _tick_cooldowns(self) -> None:
        """Decrease every impostor's kill cooldown by one, never below zero."""
        for player_id, cooldown in list(self.kill_cooldowns.items()):
            self.kill_cooldowns[player_id] = max(0, cooldown - 1)