Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Opponent policies for multi-player OpenSpiel games.

These policies are used to control non-agent players in multi-player games,
allowing single-agent RL training against fixed or adaptive opponents.
"""
import random
from typing import Any, Optional, Protocol
class OpponentPolicy(Protocol):
    """Structural interface that every opponent policy satisfies.

    Because this is a ``typing.Protocol``, any object exposing a compatible
    ``select_action`` method conforms; subclassing is not required.
    """

    def select_action(
        self,
        legal_actions: list[int],
        observations: dict[str, Any],
    ) -> int:
        """
        Choose the opponent's next action.

        Args:
            legal_actions: Action IDs the opponent may currently play.
            observations: Current observations from the environment.

        Returns:
            The selected action ID.
        """
        ...
class RandomOpponent:
    """Random opponent that selects uniformly from legal actions.

    By default actions are drawn from the module-level ``random`` generator,
    so seeding it with ``random.seed`` also affects this opponent. Passing
    ``seed`` gives the opponent its own private generator instead, making its
    action sequence reproducible independently of other users of ``random``.
    """

    # Default source of randomness. Kept on the class (not the instance) so
    # that an unseeded opponent carries no per-instance state.
    _rng = random

    def __init__(self, seed: Optional[int] = None):
        """
        Initialize random opponent.

        Args:
            seed: Optional seed for a private random generator. When None
                (the default), the shared module-level generator is used.
        """
        if seed is not None:
            self._rng = random.Random(seed)

    def select_action(self, legal_actions: list[int], observations: dict[str, Any]) -> int:
        """
        Select a random legal action.

        Args:
            legal_actions: List of legal action IDs.
            observations: Current observations from the environment (unused).

        Returns:
            One element of ``legal_actions`` chosen uniformly at random.

        Raises:
            ValueError: If ``legal_actions`` is empty.
        """
        if not legal_actions:
            raise ValueError("No legal actions available")
        return self._rng.choice(legal_actions)
class FixedActionOpponent:
    """Deterministic opponent that picks by position in the legal-action list.

    The position rule is fixed at construction time: the first, last, or
    middle entry of whatever legal actions are offered.
    """

    def __init__(self, action_selector: str = "first"):
        """
        Create the opponent.

        Args:
            action_selector: Position to play: "first", "last", or "middle".
                The value is stored as given without validation; any other
                value behaves like "first" when an action is selected.
        """
        self.action_selector = action_selector

    def select_action(self, legal_actions: list[int], observations: dict[str, Any]) -> int:
        """
        Return the legal action at the configured position.

        Args:
            legal_actions: List of legal action IDs.
            observations: Current observations from the environment (unused).

        Returns:
            The first, last, or middle (index ``len // 2``) legal action.

        Raises:
            ValueError: If ``legal_actions`` is empty.
        """
        if not legal_actions:
            raise ValueError("No legal actions available")

        mode = self.action_selector
        if mode == "last":
            position = -1
        elif mode == "middle":
            position = len(legal_actions) // 2
        else:
            # "first" and any unrecognised selector both resolve to index 0.
            position = 0
        return legal_actions[position]
def get_opponent_policy(policy_name: str) -> OpponentPolicy:
    """
    Build the opponent policy registered under ``policy_name``.

    Args:
        policy_name: One of "random", "first", "last", or "middle".

    Returns:
        A ``RandomOpponent`` for "random"; otherwise a ``FixedActionOpponent``
        configured with ``policy_name`` as its action selector.

    Raises:
        ValueError: If policy_name is not recognized.
    """
    if policy_name == "random":
        return RandomOpponent()

    fixed_selectors = ("first", "last", "middle")
    if policy_name not in fixed_selectors:
        raise ValueError(f"Unknown opponent policy: {policy_name}")

    return FixedActionOpponent(action_selector=policy_name)