# File: openspiel_env/server/opponent_policies.py
# Uploaded by sergiopaniego (HF Staff) using huggingface_hub; commit c65b2a4 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Opponent policies for multi-player OpenSpiel games.
These policies are used to control non-agent players in multi-player games,
allowing single-agent RL training against fixed or adaptive opponents.
"""
import random
from typing import Any, Protocol
class OpponentPolicy(Protocol):
    """Structural interface that every opponent policy satisfies.

    Any object exposing a compatible ``select_action`` method conforms to
    this protocol; explicit subclassing is not required.
    """

    def select_action(self, legal_actions: list[int], observations: dict[str, Any]) -> int:
        """
        Choose the opponent's next move.

        Args:
            legal_actions: Action IDs the opponent is currently allowed to play.
            observations: Current observations from the environment.

        Returns:
            The ID of the chosen action.
        """
        ...
class RandomOpponent:
    """Opponent whose moves are drawn uniformly at random from the legal set."""

    def select_action(self, legal_actions: list[int], observations: dict[str, Any]) -> int:
        """
        Pick one of the legal actions at random.

        Args:
            legal_actions: Candidate action IDs to draw from.
            observations: Current observations (not consulted by this policy).

        Returns:
            A randomly chosen element of ``legal_actions``.

        Raises:
            ValueError: If ``legal_actions`` is empty.
        """
        if legal_actions:
            return random.choice(legal_actions)
        raise ValueError("No legal actions available")
class FixedActionOpponent:
    """Deterministic opponent that picks a fixed position in the legal-action list."""

    def __init__(self, action_selector: str = "first"):
        """
        Create the opponent and remember which position it should play.

        Args:
            action_selector: Position to pick: "first", "last", or "middle".
                Any other value behaves like "first" when an action is
                selected.
        """
        self.action_selector = action_selector

    def select_action(self, legal_actions: list[int], observations: dict[str, Any]) -> int:
        """
        Return the legal action at the configured position.

        Args:
            legal_actions: Candidate action IDs, indexed by position.
            observations: Current observations (not consulted by this policy).

        Returns:
            The last element for "last", the element at index
            ``len(legal_actions) // 2`` for "middle", and the first element
            for "first" or any unrecognized selector.

        Raises:
            ValueError: If ``legal_actions`` is empty.
        """
        if not legal_actions:
            raise ValueError("No legal actions available")

        selector = self.action_selector
        if selector == "last":
            return legal_actions[-1]
        if selector == "middle":
            # Integer division picks the upper-middle element for even lengths.
            midpoint = len(legal_actions) // 2
            return legal_actions[midpoint]
        # "first" and every unrecognized selector resolve to the head of the list.
        return legal_actions[0]
def get_opponent_policy(policy_name: str) -> OpponentPolicy:
    """
    Build the opponent policy registered under ``policy_name``.

    Args:
        policy_name: One of "random", "first", "last", or "middle".

    Returns:
        A ``RandomOpponent`` for "random"; otherwise a ``FixedActionOpponent``
        configured with ``policy_name`` as its selector.

    Raises:
        ValueError: If ``policy_name`` is not one of the names above.
    """
    fixed_selectors = ("first", "last", "middle")

    if policy_name in fixed_selectors:
        return FixedActionOpponent(action_selector=policy_name)
    if policy_name == "random":
        return RandomOpponent()

    raise ValueError(f"Unknown opponent policy: {policy_name}")