| """ |
| Using Environments |
| ================== |
| |
| **Part 2 of 5** in the OpenEnv Getting Started Series |
| |
| This notebook covers how to use OpenEnv environments: connecting to them, |
| creating AI policies, running evaluations, and working with different games. |
| |
| .. note:: |
| **Time**: ~15 minutes | **Difficulty**: Beginner-Intermediate | **GPU Required**: No |
| |
| What You'll Learn |
| ----------------- |
| |
| - **Connection Methods**: Hub, Docker, and direct URL connections |
| - **Available Environments**: OpenSpiel games, coding, browsing, and more |
| - **Creating Policies**: Random, heuristic, and learning-based strategies |
| - **Running Evaluations**: Measuring and comparing policy performance |
| """ |
|
|
| |
| |
| |
| |
| |
|
|
| import random |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
import nest_asyncio

# Allow re-entrant use of an already-running asyncio event loop (as in
# Jupyter/Colab notebooks).
nest_asyncio.apply()

# Detect whether we are running inside Google Colab.
try:
    import google.colab  # noqa: F401  (imported only to probe the environment)

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("=" * 70)
    print(" GOOGLE COLAB DETECTED - Installing OpenEnv...")
    print("=" * 70)

    # pip output is captured to keep the notebook quiet, so the exit code must
    # be checked explicitly instead of reporting success unconditionally.
    install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "openenv-core"],
        capture_output=True,
        text=True,
    )
    if install.returncode == 0:
        print(" OpenEnv installed!")
    else:
        print(f" OpenEnv installation FAILED (pip exit code {install.returncode}):")
        print(install.stderr.strip())
    print("=" * 70)
else:
    print("=" * 70)
    print(" RUNNING LOCALLY")
    print("=" * 70)

    # Make a local repository checkout importable (assumes this script lives
    # three directories below the repo root - TODO confirm for your layout).
    src_path = Path.cwd().parent.parent.parent / "src"
    if src_path.exists():
        sys.path.insert(0, str(src_path))
    envs_path = Path.cwd().parent.parent.parent / "envs"
    if envs_path.exists():
        # `envs/` itself so `import openspiel_env` (used below) resolves as a
        # top-level package; its parent so `import envs.<name>` also works.
        sys.path.insert(0, str(envs_path))
        sys.path.insert(0, str(envs_path.parent))

print("=" * 70)
print()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
print("=" * 70)
print(" CONNECTION METHODS")
print("=" * 70)

# Try to load the OpenSpiel client and models. Later cells check IMPORTS_OK
# and fall back to printed examples when the package is missing.
IMPORTS_OK = False
try:
    from openspiel_env.client import OpenSpielEnv
    from openspiel_env.models import OpenSpielAction, OpenSpielObservation, OpenSpielState
except ImportError as import_error:
    print(f"✗ Import error: {import_error}")
else:
    IMPORTS_OK = True
    print("✓ Imports successful")
|
|
| |
| |
| |
| |
| |
| |
|
|
print("\n" + "-" * 70)
print("METHOD 1: FROM HUGGING FACE HUB")
print("-" * 70)

if not IMPORTS_OK:
    # Package unavailable: describe the expected API instead of introspecting it.
    print("\n(OpenEnv not installed - showing expected signature)")
    print("\nSignature: OpenSpielEnv.from_hub(repo_id, *, use_docker=True, ...)")
else:
    import inspect

    if not hasattr(OpenSpielEnv, "from_hub"):
        print("\nfrom_hub method not available in this version")
    else:
        # Show the installed client's real signature and first docstring line.
        sig = inspect.signature(OpenSpielEnv.from_hub)
        print(f"\nSignature: OpenSpielEnv.from_hub{sig}")
        if OpenSpielEnv.from_hub.__doc__:
            doc_lines = OpenSpielEnv.from_hub.__doc__.strip().split("\n")[:3]
            print(f"Purpose: {doc_lines[0].strip()}")

    print(
        "\n".join(
            [
                "\nUsage:",
                " env = OpenSpielEnv.from_hub('openenv/openspiel-env')",
                "\nWhat happens:",
                " 1. Pulls Docker image from HF registry",
                " 2. Starts container on available port",
                " 3. Connects via WebSocket",
                " 4. Cleans up on close()",
            ]
        )
    )
|
|
| |
| |
| |
| |
| |
|
|
print("\n" + "-" * 70)
print("METHOD 2: FROM DOCKER IMAGE")
print("-" * 70)

if not IMPORTS_OK:
    # Package unavailable: describe the expected API instead of introspecting it.
    print("\n(OpenEnv not installed - showing expected signature)")
    print("\nSignature: OpenSpielEnv.from_docker_image(image, provider=None, ...)")
else:
    # `inspect` is imported in the Method 1 cell under the same IMPORTS_OK guard.
    if not hasattr(OpenSpielEnv, "from_docker_image"):
        print("\nfrom_docker_image method not available in this version")
    else:
        sig = inspect.signature(OpenSpielEnv.from_docker_image)
        print(f"\nSignature: OpenSpielEnv.from_docker_image{sig}")
        if OpenSpielEnv.from_docker_image.__doc__:
            doc_lines = OpenSpielEnv.from_docker_image.__doc__.strip().split("\n")[:3]
            print(f"Purpose: {doc_lines[0].strip()}")

    print(
        "\n".join(
            [
                "\nUsage:",
                " # Build image first:",
                " # docker build -t openspiel-env:latest ./envs/openspiel_env/server",
                " env = OpenSpielEnv.from_docker_image('openspiel-env:latest')",
            ]
        )
    )
|
|
| |
| |
| |
| |
| |
|
|
print("\n" + "-" * 70)
print("METHOD 3: DIRECT URL CONNECTION")
print("-" * 70)

if not IMPORTS_OK:
    # Package unavailable: describe the expected constructor instead.
    print("\n(OpenEnv not installed - showing expected signature)")
    print("\nSignature: OpenSpielEnv(base_url, connect_timeout_s=10.0, ...)")
else:
    # Introspect the installed client's constructor.
    sig = inspect.signature(OpenSpielEnv.__init__)
    print(f"\nSignature: OpenSpielEnv{sig}")
    print(
        "\n".join(
            [
                "\nUsage:",
                " # Start server first:",
                " # docker run -p 8000:8000 openenv/openspiel-env:latest",
                " env = OpenSpielEnv(base_url='http://localhost:8000')",
                "\nNote: Does NOT manage container lifecycle - you control the server",
            ]
        )
    )
|
|
| |
| |
| |
| |
| |
| |
|
|
print("\n" + "-" * 70)
print("CONTEXT MANAGER SUPPORT")
print("-" * 70)

if not IMPORTS_OK:
    print("\n(OpenEnv not installed)")
    print("Context managers are supported for automatic cleanup")
else:
    # A class supports `with` when it defines both protocol methods.
    has_enter, has_exit = (
        hasattr(OpenSpielEnv, dunder) for dunder in ("__enter__", "__exit__")
    )
    print(f"\n__enter__ method: {'✓ Present' if has_enter else '✗ Missing'}")
    print(f"__exit__ method: {'✓ Present' if has_exit else '✗ Missing'}")

    if has_enter and has_exit:
        print(
            "\n".join(
                [
                    "\n✓ Context manager supported! Use with 'with' statement:",
                    " with OpenSpielEnv(base_url='...') as env:",
                    " result = env.reset()",
                    " # ... use env ...",
                    " # Automatically cleaned up",
                ]
            )
        )
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
print("\n".join(["=" * 70, " THE ENVIRONMENT LOOP - LIVE DEMO", "=" * 70]))
print()

# Size of the simulated Catch board used by the demo below.
GRID_HEIGHT = 10  # rows the ball falls through
GRID_WIDTH = 5  # columns the ball / paddle can occupy
|
|
| |
class DemoObservation:
    """Local stand-in for an environment observation used by the demo loop.

    Stores the flattened ``info_state`` grid, the list of ``legal_actions``
    and a ``done`` flag (default ``False``) exactly as passed in.
    """

    def __init__(self, info_state, legal_actions, done=False):
        self.info_state, self.legal_actions, self.done = info_state, legal_actions, done
|
|
class DemoResult:
    """Local stand-in for a reset/step result used by the demo loop.

    Bundles an ``observation`` with its ``reward`` (default ``0.0``) and a
    ``done`` flag (default ``False``).
    """

    def __init__(self, observation, reward=0.0, done=False):
        self.observation, self.reward, self.done = observation, reward, done
|
|
| |
# Start a fresh episode: the ball drops in a random column while the paddle
# starts in the middle of the bottom row.
ball_col = random.randint(0, GRID_WIDTH - 1)
paddle_col = GRID_WIDTH // 2

print("Episode Starting:")
print(f" Ball column: {ball_col}")
print(f" Paddle column: {paddle_col}")
print()

step_count = 0
total_reward = 0.0

# Catch action ids -> display names; loop-invariant, so built once.
action_names = {0: "LEFT", 1: "STAY", 2: "RIGHT"}

print("Step | Ball Row | Paddle | Action | Info State (first 10)")
print("-" * 65)

for ball_row in range(GRID_HEIGHT):
    # Observation: flattened row-major grid with 1.0 at the ball's cell and at
    # the paddle's cell on the bottom row.
    info_state = [0.0] * (GRID_HEIGHT * GRID_WIDTH)
    info_state[ball_row * GRID_WIDTH + ball_col] = 1.0
    info_state[(GRID_HEIGHT - 1) * GRID_WIDTH + paddle_col] = 1.0
    obs = DemoObservation(info_state=info_state, legal_actions=[0, 1, 2])

    # Policy: step the paddle one column toward the ball.
    if paddle_col < ball_col:
        action_id = 2
    elif paddle_col > ball_col:
        action_id = 0
    else:
        action_id = 1

    info_preview = [f"{v:.0f}" for v in obs.info_state[:10]]
    print(f" {step_count:2d} | {ball_row:2d} | {paddle_col} | {action_names[action_id]:<5} | {info_preview}")

    # Environment step: move the paddle, clamped to the grid edges.
    if action_id == 0:
        paddle_col = max(0, paddle_col - 1)
    elif action_id == 2:
        paddle_col = min(GRID_WIDTH - 1, paddle_col + 1)

    step_count += 1

# Terminal reward: 1.0 only if the paddle ended up under the ball. Wrap it in
# a result object, mirroring what a live env.step() would hand back.
caught = paddle_col == ball_col
reward = 1.0 if caught else 0.0
total_reward += reward
final_result = DemoResult(observation=obs, reward=reward, done=True)

print("-" * 65)
print()
print("Episode Complete:")
print(f" Steps: {step_count}")
print(f" Ball landed at: column {ball_col}")
print(f" Paddle position: column {paddle_col}")
print(f" Reward: {final_result.reward}")
print(f" Result: {'CAUGHT! ✓' if caught else 'MISSED! ✗'}")
print()
print("This is the exact same loop you'd run with a live server,")
print("just using local simulation for the game logic.")
|
|
| |
| |
| |
| |
| |
| |
|
|
| import random |
| from typing import List |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class PolicyResult:
    """Aggregate metrics from evaluating one policy over a batch of episodes."""

    name: str  # display name of the evaluated policy
    episodes: int  # number of episodes played
    wins: int  # number of episodes counted as wins
    total_reward: float  # reward accumulated over all episodes
    avg_steps: float  # mean number of steps per episode

    @property
    def win_rate(self) -> float:
        """Fraction of episodes won; 0.0 when no episodes were played."""
        if self.episodes <= 0:
            return 0.0
        return self.wins / self.episodes
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class RandomPolicy:
    """
    Baseline policy that ignores the state and acts uniformly at random.

    Each call draws one element from ``observation.legal_actions``.
    Expected win rate for Catch: ~20% (1 in 5 columns)
    """

    name = "Random"

    def choose_action(self, observation) -> int:
        """Return a uniformly random element of ``observation.legal_actions``."""
        candidates = observation.legal_actions
        return random.choice(candidates)
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class SmartCatchPolicy:
    """
    Heuristic Catch policy that steers the paddle toward the ball.

    Reads ``observation.info_state`` as a flattened row-major grid with
    ``grid_width`` columns: the ball is the first cell equal to ~1.0 above the
    bottom row, the paddle is the first such cell in the bottom row.
    Returns 0 (LEFT), 1 (STAY) or 2 (RIGHT).
    Expected win rate: ~100% (optimal for Catch)
    """

    name = "Smart (Heuristic)"

    def __init__(self, grid_width: int = 5):
        self.grid_width = grid_width

    def choose_action(self, observation) -> int:
        """Move one column toward the ball; STAY when aligned or if either piece isn't found."""
        width = self.grid_width
        cells = observation.info_state

        def is_marked(value) -> bool:
            return abs(value - 1.0) < 0.01

        # Ball column: first marked cell above the bottom row (so a ball that has
        # already reached the bottom row is not found).
        ball_col = next(
            (pos % width for pos, value in enumerate(cells[:-width]) if is_marked(value)),
            None,
        )
        # Paddle column: first marked cell within the bottom row.
        paddle_col = next(
            (pos for pos, value in enumerate(cells[-width:]) if is_marked(value)),
            None,
        )

        if ball_col is None or paddle_col is None:
            return 1
        if ball_col > paddle_col:
            return 2
        if ball_col < paddle_col:
            return 0
        return 1
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class EpsilonGreedyPolicy:
    """
    Epsilon-greedy policy - balances exploration and exploitation.

    On the t-th call (t counts every action chosen by this instance, across
    episodes) it explores with probability ``epsilon * decay**t`` by picking a
    random legal action; otherwise it exploits by deferring to
    ``SmartCatchPolicy``. Because t keeps growing, exploration fades over the
    lifetime of the instance.

    Args:
        epsilon: Initial exploration probability.
        decay: Multiplicative decay applied per action chosen.
        grid_width: Number of columns in the Catch grid; forwarded to the
            heuristic so it parses observations from non-default grid sizes
            correctly (defaults to 5, the previous hard-wired value).
    """

    name = "Epsilon-Greedy"

    def __init__(self, epsilon: float = 0.3, decay: float = 0.99, grid_width: int = 5):
        self.epsilon = epsilon
        self.decay = decay
        self.smart_policy = SmartCatchPolicy(grid_width=grid_width)
        self.steps = 0

    def choose_action(self, observation) -> int:
        """Choose action with epsilon-greedy strategy."""
        self.steps += 1

        # Exploration probability shrinks geometrically with every action taken.
        current_epsilon = self.epsilon * (self.decay**self.steps)

        if random.random() < current_epsilon:
            # Explore: any legal action, uniformly.
            return random.choice(observation.legal_actions)
        # Exploit: use the heuristic.
        return self.smart_policy.choose_action(observation)
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class AlwaysStayPolicy:
    """
    Deliberately weak baseline that never moves the paddle.

    It wins only when the ball happens to land in the paddle's starting column.
    Expected win rate: ~20% (same as random)
    """

    name = "Always Stay"

    def choose_action(self, observation) -> int:
        """Ignore the observation and return action 1 (STAY)."""
        stay_action = 1
        return stay_action
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
def evaluate_policy_live(
    policy,
    env,
    num_episodes: int = 50,
    game_name: str = "catch",
) -> PolicyResult:
    """
    Evaluate a policy against a live environment.

    Each episode calls ``env.reset()``, then repeatedly wraps the policy's
    chosen action id in ``OpenSpielAction`` and sends it with ``env.step`` until
    the returned result is ``done``. Only the reward on each episode's final
    result is recorded; an episode counts as a win when that reward is positive.

    Args:
        policy: Policy object with choose_action method
        env: Connected OpenSpielEnv client
        num_episodes: Number of episodes to run
        game_name: Name of the game to play

    Returns:
        PolicyResult with evaluation metrics. ``avg_steps`` is 0.0 when no
        episodes are run (instead of raising ZeroDivisionError).
    """
    wins = 0
    total_reward = 0.0
    total_steps = 0

    for _ in range(num_episodes):
        result = env.reset()
        episode_steps = 0

        while not result.done:
            action_id = policy.choose_action(result.observation)
            action = OpenSpielAction(action_id=action_id, game_name=game_name)
            result = env.step(action)
            episode_steps += 1

        # The reward may be falsy (e.g. None); treat that as 0.
        final_reward = result.reward or 0.0
        total_reward += final_reward
        total_steps += episode_steps
        if final_reward > 0:
            wins += 1

    return PolicyResult(
        name=policy.name,
        episodes=num_episodes,
        wins=wins,
        total_reward=total_reward,
        avg_steps=total_steps / num_episodes if num_episodes > 0 else 0.0,
    )
|
|
|
|
def evaluate_policy_simulated(
    policy,
    num_episodes: int = 50,
    grid_height: int = 10,
    grid_width: int = 5,
) -> PolicyResult:
    """
    Evaluate a policy using local simulation (no server needed).

    Simulates Catch locally. Each episode drops the ball down a random column
    with the paddle starting in the middle column. For every row the policy is
    shown a flattened ``grid_height * grid_width`` grid (1.0 at the ball and at
    the paddle on the bottom row). Action 0 moves the paddle left, 2 moves it
    right, and anything else leaves it in place. The episode is won (reward
    1.0) if the paddle ends in the ball's column.

    Args:
        policy: Policy object with choose_action method
        num_episodes: Number of episodes to run
        grid_height: Height of the game grid
        grid_width: Width of the game grid

    Returns:
        PolicyResult with evaluation metrics. ``avg_steps`` is 0.0 when no
        episodes are run (instead of raising ZeroDivisionError).
    """
    wins = 0
    total_reward = 0.0
    total_steps = 0

    # Minimal stand-in exposing the observation fields the policies read.
    class MockObservation:
        def __init__(self, info_state, legal_actions):
            self.info_state = info_state
            self.legal_actions = legal_actions

    for _ in range(num_episodes):
        ball_col = random.randint(0, grid_width - 1)
        paddle_col = grid_width // 2

        for step in range(grid_height):
            # Row-major grid: ball in the current row, paddle in the bottom row.
            info_state = [0.0] * (grid_height * grid_width)
            info_state[step * grid_width + ball_col] = 1.0
            info_state[(grid_height - 1) * grid_width + paddle_col] = 1.0

            observation = MockObservation(
                info_state=info_state, legal_actions=[0, 1, 2]
            )
            action = policy.choose_action(observation)

            # Apply the move, clamped to the grid edges.
            if action == 0:
                paddle_col = max(0, paddle_col - 1)
            elif action == 2:
                paddle_col = min(grid_width - 1, paddle_col + 1)

            total_steps += 1

        # Terminal reward: caught iff the paddle is under the ball.
        if paddle_col == ball_col:
            wins += 1
            total_reward += 1.0

    return PolicyResult(
        name=policy.name,
        episodes=num_episodes,
        wins=wins,
        total_reward=total_reward,
        avg_steps=total_steps / num_episodes if num_episodes > 0 else 0.0,
    )
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
# Contestants for the comparison, from naive baselines to the heuristic.
policies = [
    RandomPolicy(),
    AlwaysStayPolicy(),
    SmartCatchPolicy(),
    EpsilonGreedyPolicy(epsilon=0.3),
]

# Probe for a running server; without one the competition runs in simulation.
SERVER_URL = "http://localhost:8000"
USE_LIVE = False

if IMPORTS_OK:
    try:
        test_env = OpenSpielEnv(base_url=SERVER_URL)
        with test_env.sync() as client:
            pass  # entering the sync session serves as the connectivity check
    except Exception as e:
        USE_LIVE = False
        print(f"✗ No server running at {SERVER_URL}: {e}")
    else:
        USE_LIVE = True
        print(f"✓ Connected to server at {SERVER_URL}")
|
|
print("=" * 70)
print(" POLICY COMPETITION - LIVE SERVER" if USE_LIVE else " POLICY COMPETITION - SIMULATION MODE")
print("=" * 70)
print()

NUM_EPISODES = 50
print(f"Running {NUM_EPISODES} episodes per policy...\n")

# One PolicyResult per policy, in evaluation order.
results = []

for policy in policies:
    print(f" Evaluating {policy.name}...", end=" ", flush=True)

    if not USE_LIVE:
        result = evaluate_policy_simulated(policy, NUM_EPISODES)
    else:
        # Open a fresh client session for each policy.
        env = OpenSpielEnv(base_url=SERVER_URL)
        with env.sync() as client:
            result = evaluate_policy_live(policy, client, NUM_EPISODES)

    results.append(result)
    print(f"Win rate: {result.win_rate * 100:.1f}%")
|
|
| |
| |
| |
|
|
print()
print("=" * 70)
print(" FINAL RESULTS")
print("=" * 70)
print()

# Rank policies from best to worst win rate (sorts `results` in place).
results.sort(key=lambda r: r.win_rate, reverse=True)

print(f"{'Rank':<6}{'Policy':<20}{'Win Rate':<12}{'Avg Steps':<12}{'Wins'}")
print("-" * 60)

for i, result in enumerate(results):
    rank = f"#{i + 1}"
    # Visual win-rate bar: one block per 5% (20 blocks = 100%). It was
    # previously computed but never printed; show it after the Wins column.
    bar = "█" * int(result.win_rate * 20)
    wins_col = f"{result.wins}/{result.episodes}"
    print(
        f"{rank:<6}{result.name:<20}{result.win_rate * 100:>5.1f}%{'':<5}"
        f"{result.avg_steps:>6.1f}{'':<6}{wins_col:<8}{bar}"
    )

print()
print("-" * 70)
print()
print("Key Insights:")
print(" • Random/AlwaysStay: ~20% (baseline - relies on luck)")
print(" • Smart Heuristic: ~100% (optimal for Catch)")
print(" • Epsilon-Greedy: ~85%+ (balances exploration/exploitation)")
print()
| print() |
|
|
| |
| |
| |
| |
| |
| |
|
|
print("=" * 70)
print(" SWITCHING GAMES - ACTUAL ACTION INSTANCES")
print("=" * 70)
print()

if IMPORTS_OK:
    from openspiel_env.models import OpenSpielAction as ActionModel

    # Catch: move the paddle left, keep it still, or move it right.
    print("CATCH GAME ACTIONS:")
    print("-" * 40)
    catch_actions = {
        0: "Move LEFT",
        1: "STAY in place",
        2: "Move RIGHT",
    }
    for action_id, description in catch_actions.items():
        action = ActionModel(action_id=action_id, game_name="catch")
        print(f" {action} # {description}")

    print()

    # 2048: slide all tiles in one direction.
    print("2048 GAME ACTIONS:")
    print("-" * 40)
    game_2048_actions = {
        0: "Slide UP",
        1: "Slide RIGHT",
        2: "Slide DOWN",
        3: "Slide LEFT",
    }
    for action_id, description in game_2048_actions.items():
        action = ActionModel(action_id=action_id, game_name="2048")
        print(f" {action} # {description}")

    print()

    # Tic-Tac-Toe: one action per board cell, numbered row by row.
    print("TIC-TAC-TOE ACTIONS:")
    print("-" * 40)
    print(" Grid positions 0-8 (left-to-right, top-to-bottom):")
    print(" 0 | 1 | 2")
    print(" ---|---|---")
    print(" 3 | 4 | 5")
    print(" ---|---|---")
    print(" 6 | 7 | 8")
    print()
    example_cells = {0: "top-left", 4: "center", 8: "bottom-right"}
    for pos, label in example_cells.items():
        action = ActionModel(action_id=pos, game_name="tic_tac_toe")
        print(f" {action} # {label}")

    print()

    # Blackjack: OpenSpiel's ActionType enum numbers HIT as 0 and STAND as 1.
    print("BLACKJACK ACTIONS:")
    print("-" * 40)
    blackjack_actions = {
        0: "HIT (request another card)",
        1: "STAND (keep current hand)",
    }
    for action_id, description in blackjack_actions.items():
        action = ActionModel(action_id=action_id, game_name="blackjack")
        print(f" {action} # {description}")

else:
    from dataclasses import dataclass

    # Minimal stand-in with the same two fields used above; its generated
    # repr is what gets printed.
    @dataclass
    class ActionDemo:
        action_id: int
        game_name: str

    print("CATCH GAME ACTIONS:")
    print("-" * 40)
    for action_id, desc in [(0, "LEFT"), (1, "STAY"), (2, "RIGHT")]:
        print(f" {ActionDemo(action_id=action_id, game_name='catch')} # {desc}")

    print()
    print("2048 GAME ACTIONS:")
    print("-" * 40)
    for action_id, desc in [(0, "UP"), (1, "RIGHT"), (2, "DOWN"), (3, "LEFT")]:
        print(f" {ActionDemo(action_id=action_id, game_name='2048')} # {desc}")

print()
print("-" * 70)
print("Each game has its own action space - check legal_actions in observation!")
|
|
| |
| |
| |
| |
| |
| |
|
|
print("=" * 70)
print(" MULTI-PLAYER GAMES - OBSERVATION STRUCTURE")
print("=" * 70)
print()

if IMPORTS_OK:
    from openspiel_env.models import OpenSpielObservation as ObsModel

    # Single-player game: you are always player 0 and there is no opponent.
    print("SINGLE-PLAYER OBSERVATION (Catch):")
    print("-" * 50)
    single_player_obs = ObsModel(
        # 10x5 row-major grid (same encoding as the demo above): ball in the
        # top row and paddle in the bottom row, both in the middle column.
        info_state=[0.0, 0.0, 1.0, 0.0, 0.0] + [0.0] * 40 + [0.0, 0.0, 1.0, 0.0, 0.0],
        legal_actions=[0, 1, 2],
        game_phase="playing",
        current_player_id=0,
        opponent_last_action=None,
    )
    print(f" current_player_id: {single_player_obs.current_player_id} # Always 0 (you)")
    print(f" opponent_last_action: {single_player_obs.opponent_last_action} # None (no opponent)")
    print(f" legal_actions: {single_player_obs.legal_actions}")
    print(f" game_phase: {single_player_obs.game_phase!r}")
    print()

    # Two-player game, your move: the opponent's last move is reported.
    print("MULTI-PLAYER OBSERVATION (Tic-Tac-Toe, YOUR turn):")
    print("-" * 50)
    your_turn_obs = ObsModel(
        info_state=[1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0],
        legal_actions=[1, 2, 3, 5, 6, 7, 8],
        game_phase="playing",
        current_player_id=0,
        opponent_last_action=4,
    )
    print(f" current_player_id: {your_turn_obs.current_player_id} # 0 = YOUR turn")
    print(f" opponent_last_action: {your_turn_obs.opponent_last_action} # Opponent played position 4 (center)")
    print(f" legal_actions: {your_turn_obs.legal_actions}")
    print(f" game_phase: {your_turn_obs.game_phase!r}")
    print()

    # Two-player game, opponent's move: you have no legal actions.
    print("MULTI-PLAYER OBSERVATION (Tic-Tac-Toe, OPPONENT's turn):")
    print("-" * 50)
    opponent_turn_obs = ObsModel(
        info_state=[1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0],
        legal_actions=[],
        game_phase="playing",
        current_player_id=1,
        opponent_last_action=None,
    )
    print(f" current_player_id: {opponent_turn_obs.current_player_id} # 1 = OPPONENT's turn")
    print(f" legal_actions: {opponent_turn_obs.legal_actions} # Empty - wait for opponent")
    print(f" game_phase: {opponent_turn_obs.game_phase!r}")
    print()

    # Game over: no current player and no legal actions.
    print("TERMINAL OBSERVATION (Game Over):")
    print("-" * 50)
    terminal_obs = ObsModel(
        info_state=[1.0, 1.0, 1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0],
        legal_actions=[],
        game_phase="terminal",
        current_player_id=-1,
        opponent_last_action=4,
    )
    print(f" current_player_id: {terminal_obs.current_player_id} # -1 = Game over")
    print(f" game_phase: {terminal_obs.game_phase!r}")
    print(f" legal_actions: {terminal_obs.legal_actions} # Empty - game ended")

else:
    from dataclasses import dataclass
    from typing import List, Optional

    # Minimal stand-in holding the observation fields shown below.
    @dataclass
    class ObsDemo:
        current_player_id: int
        opponent_last_action: Optional[int]
        legal_actions: List[int]
        game_phase: str

    single = ObsDemo(
        current_player_id=0, opponent_last_action=None, legal_actions=[0, 1, 2], game_phase="playing"
    )
    print("SINGLE-PLAYER (Catch):")
    print(f" current_player_id: {single.current_player_id} # Always your turn")
    print(f" opponent_last_action: {single.opponent_last_action}")
    print()

    yours = ObsDemo(
        current_player_id=0,
        opponent_last_action=4,
        legal_actions=[1, 2, 3, 5, 6, 7, 8],
        game_phase="playing",
    )
    print("MULTI-PLAYER - YOUR TURN (Tic-Tac-Toe):")
    print(f" current_player_id: {yours.current_player_id} # 0 = your turn")
    print(f" opponent_last_action: {yours.opponent_last_action} # What opponent just played")
    print(f" legal_actions: {yours.legal_actions} # Available moves")
    print()

    theirs = ObsDemo(
        current_player_id=1, opponent_last_action=None, legal_actions=[], game_phase="playing"
    )
    print("MULTI-PLAYER - OPPONENT'S TURN:")
    print(f" current_player_id: {theirs.current_player_id} # Wait for opponent")
    print(f" legal_actions: {theirs.legal_actions} # Can't move during opponent's turn")

print()
print("-" * 70)
print("KEY INSIGHT: Only act when current_player_id == 0 (your turn)!")
print("The environment automatically handles opponent moves.")
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|