Spaces:
Build error
Build error
| """ | |
| Connect4 Multi-Agent Environment — Server Side | |
| Adapted for autonomous driving scenario: | |
| - Agent 1 = "Ego vehicle" (LLM being trained) | |
| - Agent 2 = "Opponent vehicle" (rule-based or another LLM) | |
| The board represents a grid intersection control problem: | |
| - Winning = successfully navigating without collision | |
| - Rewards shaped for RL post-training | |
| """ | |
| import numpy as np | |
| from typing import Optional | |
| from openenv.core.environment import Environment | |
| from ..models import ( | |
| Connect4Action, Connect4Observation, Connect4State | |
| ) | |
# Board geometry for a standard Connect 4 grid.
ROWS = 6
COLS = 7
# Cell values stored in the board array.
EMPTY = 0
AGENT1 = 1  # Ego vehicle / LLM under training
AGENT2 = 2  # Opponent / rule-based agent
class Connect4Environment(Environment):
    """
    Connect4 as a multi-agent driving coordination environment.

    Agent 1 ("ego vehicle") is the LLM under training; Agent 2 is the
    opponent (rule-based or another LLM).

    Observation:
        - Board state (6x7 grid, row 0 is the top)
        - Current player turn
        - Legal moves
        - Last move played
        - Game status

    Reward shaping (for RL):
        +10.0 -> Win (ego agent connects 4)
        -10.0 -> Loss (opponent connects 4)
        +0.5  -> Blocking opponent's winning move (any direction)
        +0.2  -> Creating a 3-in-a-row
        -0.1  -> Invalid move attempt
         0.0  -> Draw
    """

    def __init__(self):
        super().__init__()
        # 6x7 grid of EMPTY/AGENT1/AGENT2 cells; pieces "fall" toward row ROWS-1.
        self.board: np.ndarray = np.zeros((ROWS, COLS), dtype=int)
        self.current_player: int = AGENT1
        self.done: bool = False
        self.winner: Optional[int] = None
        self.last_move: Optional[int] = None
        self.move_history: list = []

    # ------------------------------------------------------------------ #
    # OpenEnv API                                                         #
    # ------------------------------------------------------------------ #
    def reset(self) -> Connect4Observation:
        """Start a fresh game; Agent 1 (ego vehicle) moves first."""
        self.board = np.zeros((ROWS, COLS), dtype=int)
        self.current_player = AGENT1
        self.done = False
        self.winner = None
        self.last_move = None
        self.move_history = []
        return self._make_observation("Game reset. Your turn — you are Player 1 (Ego Vehicle).")

    def step(self, action: Connect4Action) -> tuple[Connect4Observation, float, bool]:
        """Apply `action` for the current player.

        Returns (observation, reward, done). Invalid moves cost -0.1 and
        do NOT advance the turn; all other rewards follow the shaping
        scheme documented on the class.
        """
        if self.done:
            obs = self._make_observation("Game already finished. Call reset() to start a new game.")
            return obs, 0.0, True

        col = action.column
        reward = 0.0

        # ---- validate move ----
        if col < 0 or col >= COLS or not self._is_valid(col):
            obs = self._make_observation(f"Invalid move: column {col} is full or out of range.")
            return obs, -0.1, False

        # ---- blocking bonus must be computed BEFORE placing the piece,
        # otherwise the opponent's threat cell is already occupied ----
        reward += self._blocking_bonus(col)

        # ---- place piece ----
        row = self._drop_piece(col, self.current_player)
        self.last_move = col
        self.move_history.append((self.current_player, col))

        # ---- 3-in-a-row bonus ----
        if self._count_streak(row, col, self.current_player) >= 3:
            reward += 0.2

        # ---- check win ----
        if self._check_win(self.current_player):
            self.done = True
            self.winner = self.current_player
            # Terminal reward is expressed from the ego agent's perspective.
            reward += 10.0 if self.current_player == AGENT1 else -10.0
            msg = ("🏆 Ego vehicle wins! Successful navigation."
                   if self.current_player == AGENT1
                   else "💥 Opponent wins. Collision occurred.")
            obs = self._make_observation(msg)
            return obs, reward, True

        # ---- check draw ----
        if self._board_full():
            self.done = True
            obs = self._make_observation("🤝 Draw. Stalemate — no collision, no winner.")
            return obs, 0.0, True

        # ---- switch player ----
        self.current_player = AGENT2 if self.current_player == AGENT1 else AGENT1
        msg = f"Move accepted (col {col}). Now Player {self.current_player}'s turn."
        obs = self._make_observation(msg)
        return obs, reward, False

    def state(self) -> Connect4State:
        """Snapshot of the episode for the OpenEnv server.

        NOTE(review): `_episode_id` / `_step_count` are assumed to be
        maintained by the `Environment` base class — confirm upstream.
        """
        return Connect4State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_player=self.current_player,
            done=self.done,
            winner=self.winner,
            move_history=self.move_history,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers                                                    #
    # ------------------------------------------------------------------ #
    def _make_observation(self, message: str) -> Connect4Observation:
        """Build the observation payload from the current game state."""
        return Connect4Observation(
            board=self.board.tolist(),
            board_str=self._render_board(),
            current_player=self.current_player,
            legal_moves=self._legal_moves(),
            last_move=self.last_move,
            done=self.done,
            winner=self.winner,
            message=message,
        )

    def _render_board(self) -> str:
        """ASCII rendering: '.'=empty, 'X'=Agent 1, 'O'=Agent 2, plus a column ruler."""
        symbols = {EMPTY: ".", AGENT1: "X", AGENT2: "O"}
        rows = []
        for r in range(ROWS):
            rows.append(" ".join(symbols[self.board[r][c]] for c in range(COLS)))
        rows.append("-" * (COLS * 2 - 1))
        rows.append(" ".join(str(c) for c in range(COLS)))
        return "\n".join(rows)

    def _is_valid(self, col: int) -> bool:
        """A column is playable while its top cell is still empty."""
        return self.board[0][col] == EMPTY

    def _legal_moves(self) -> list[int]:
        """Columns that can still accept a piece."""
        return [c for c in range(COLS) if self._is_valid(c)]

    def _drop_piece(self, col: int, player: int) -> int:
        """Drop `player`'s piece into `col`; return the landing row, or -1 if full."""
        for row in range(ROWS - 1, -1, -1):
            if self.board[row][col] == EMPTY:
                self.board[row][col] = player
                return row
        return -1

    @staticmethod
    def _has_four(board, player: int) -> bool:
        """True if `player` has 4-in-a-row anywhere on `board`.

        Checks all four directions: horizontal, vertical, and both
        diagonals. Shared by `_check_win` and `_blocking_bonus` so the
        two stay consistent.
        """
        # Horizontal
        for r in range(ROWS):
            for c in range(COLS - 3):
                if all(board[r][c + i] == player for i in range(4)):
                    return True
        # Vertical
        for r in range(ROWS - 3):
            for c in range(COLS):
                if all(board[r + i][c] == player for i in range(4)):
                    return True
        # Diagonal /
        for r in range(3, ROWS):
            for c in range(COLS - 3):
                if all(board[r - i][c + i] == player for i in range(4)):
                    return True
        # Diagonal \
        for r in range(ROWS - 3):
            for c in range(COLS - 3):
                if all(board[r + i][c + i] == player for i in range(4)):
                    return True
        return False

    def _check_win(self, player: int) -> bool:
        """True if `player` currently has 4-in-a-row on the live board."""
        return self._has_four(self.board, player)

    def _board_full(self) -> bool:
        """The board is full once every top-row cell is occupied."""
        return all(self.board[0][c] != EMPTY for c in range(COLS))

    def _count_streak(self, row: int, col: int, player: int) -> int:
        """Count the max consecutive pieces for `player` through (row, col)."""
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        best = 1
        for dr, dc in directions:
            count = 1
            # Extend outward in both directions along this axis.
            for sign in [1, -1]:
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r][c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            best = max(best, count)
        return best

    def _blocking_bonus(self, col: int) -> float:
        """+0.5 if placing here blocks an opponent 4-in-a-row.

        Simulates the opponent dropping into `col` on a copy of the board
        and checks whether that would complete four in a row.

        BUGFIX: the previous implementation only scanned horizontal and
        vertical lines, so blocking a diagonal winning threat was never
        rewarded. Now all four directions are checked via `_has_four`.
        """
        opponent = AGENT2 if self.current_player == AGENT1 else AGENT1
        test_board = self.board.copy()
        for row in range(ROWS - 1, -1, -1):
            if test_board[row][col] == EMPTY:
                test_board[row][col] = opponent
                break
        return 0.5 if self._has_four(test_board, opponent) else 0.0