"""
Connect4 Multi-Agent Environment — Server Side.

Adapted for an autonomous driving scenario:
  - Agent 1 = "Ego vehicle" (LLM being trained)
  - Agent 2 = "Opponent vehicle" (rule-based or another LLM)

The board represents a grid intersection control problem:
  - Winning = successfully navigating without collision
  - Rewards shaped for RL post-training
"""

from typing import Optional

import numpy as np

from openenv.core.environment import Environment

from ..models import (
    Connect4Action,
    Connect4Observation,
    Connect4State,
)

ROWS = 6
COLS = 7
EMPTY = 0
AGENT1 = 1  # Ego vehicle / LLM under training
AGENT2 = 2  # Opponent / rule-based agent

# One (row, col) unit step per line orientation:
# horizontal, vertical, diagonal "\" and diagonal "/".
_DIRECTIONS = ((0, 1), (1, 0), (1, 1), (1, -1))


class Connect4Environment(Environment):
    """
    Connect4 as a multi-agent driving coordination environment.

    Observation:
      - Board state (6x7 grid)
      - Current player turn
      - Legal moves
      - Last move played
      - Game status

    Reward shaping (for RL):
      +10.0  → Win  (ego agent connects 4)
      -10.0  → Loss (opponent connects 4)
      +0.5   → Blocking opponent's winning move
      +0.2   → Creating a 3-in-a-row
      -0.1   → Invalid move attempt
       0.0   → Draw

    The shaping bonuses (+0.5 / +0.2) are added to the terminal ±10.0 when a
    winning move also earns them; a draw always returns exactly 0.0.
    """

    def __init__(self):
        super().__init__()
        self.board: np.ndarray = np.zeros((ROWS, COLS), dtype=int)
        self.current_player: int = AGENT1
        self.done: bool = False
        self.winner: Optional[int] = None
        self.last_move: Optional[int] = None
        # Chronological list of (player, column) tuples.
        self.move_history: list = []

    # ------------------------------------------------------------------ #
    #  OpenEnv API                                                         #
    # ------------------------------------------------------------------ #

    def reset(self) -> Connect4Observation:
        """Clear the board and all game state; AGENT1 moves first."""
        self.board = np.zeros((ROWS, COLS), dtype=int)
        self.current_player = AGENT1
        self.done = False
        self.winner = None
        self.last_move = None
        self.move_history = []
        return self._make_observation(
            "Game reset. Your turn — you are Player 1 (Ego Vehicle)."
        )

    def step(self, action: Connect4Action) -> tuple[Connect4Observation, float, bool]:
        """
        Play ``action.column`` for the current player.

        Returns:
            ``(observation, reward, done)``. An invalid move costs -0.1 and
            leaves the board and the turn unchanged. Stepping a finished
            game returns reward 0.0 with ``done=True``.
        """
        if self.done:
            obs = self._make_observation(
                "Game already finished. Call reset() to start a new game."
            )
            return obs, 0.0, True

        col = action.column
        player = self.current_player

        # ---- validate move ----
        if not self._is_valid(col):
            obs = self._make_observation(
                f"Invalid move: column {col} is full or out of range."
            )
            return obs, -0.1, False

        # ---- blocking bonus must be evaluated before the piece lands ----
        # NOTE(review): shaping bonuses are credited to whichever player is
        # moving, while the ±10 terminal reward below is always from the ego
        # (AGENT1) perspective — confirm this mix is what the trainer expects.
        reward = self._blocking_bonus(col)

        # ---- place piece ----
        row = self._drop_piece(col, player)
        self.last_move = col
        self.move_history.append((player, col))

        # ---- 3-in-a-row bonus ----
        if self._count_streak(row, col, player) >= 3:
            reward += 0.2

        # ---- check win ----
        if self._check_win(player):
            self.done = True
            self.winner = player
            if player == AGENT1:
                reward += 10.0
                msg = "🏆 Ego vehicle wins! Successful navigation."
            else:
                reward += -10.0
                msg = "💥 Opponent wins. Collision occurred."
            return self._make_observation(msg), reward, True

        # ---- check draw (shaping bonuses are deliberately discarded) ----
        if self._board_full():
            self.done = True
            obs = self._make_observation(
                "🤝 Draw. Stalemate — no collision, no winner."
            )
            return obs, 0.0, True

        # ---- switch player ----
        self.current_player = AGENT2 if player == AGENT1 else AGENT1
        msg = f"Move accepted (col {col}). Now Player {self.current_player}'s turn."
        return self._make_observation(msg), reward, False

    def state(self) -> Connect4State:
        """Return a snapshot of the episode-level state."""
        # NOTE(review): _episode_id and _step_count are never assigned in this
        # class; presumably the Environment base class maintains them — confirm.
        return Connect4State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_player=self.current_player,
            done=self.done,
            winner=self.winner,
            # Copy so later moves do not mutate a state already handed out.
            move_history=list(self.move_history),
        )

    # ------------------------------------------------------------------ #
    #  Internal helpers                                                    #
    # ------------------------------------------------------------------ #

    def _make_observation(self, message: str) -> Connect4Observation:
        """Bundle the current board and game status with ``message``."""
        return Connect4Observation(
            board=self.board.tolist(),
            board_str=self._render_board(),
            current_player=self.current_player,
            legal_moves=self._legal_moves(),
            last_move=self.last_move,
            done=self.done,
            winner=self.winner,
            message=message,
        )

    def _render_board(self) -> str:
        """Render the grid as text: rows top-down, a rule, then column indices."""
        symbols = {EMPTY: ".", AGENT1: "X", AGENT2: "O"}
        lines = [
            " ".join(symbols[cell] for cell in board_row)
            for board_row in self.board.tolist()
        ]
        lines.append("-" * (COLS * 2 - 1))
        lines.append(" ".join(str(c) for c in range(COLS)))
        return "\n".join(lines)

    def _is_valid(self, col: int) -> bool:
        """Return True if ``col`` is on the board and its top cell is empty."""
        # The range guard matters: a negative index would otherwise wrap
        # around and silently test a different column.
        return 0 <= col < COLS and bool(self.board[0][col] == EMPTY)

    def _legal_moves(self) -> list[int]:
        """Return the columns that can still accept a piece."""
        return [c for c in range(COLS) if self._is_valid(c)]

    def _landing_row(self, col: int) -> int:
        """Return the row a piece dropped in ``col`` would occupy, or -1."""
        if not 0 <= col < COLS:
            return -1
        # Row ROWS - 1 is the bottom of the board; scan upwards from there.
        for row in range(ROWS - 1, -1, -1):
            if self.board[row][col] == EMPTY:
                return row
        return -1

    def _drop_piece(self, col: int, player: int) -> int:
        """Place ``player``'s piece in ``col``; return its row, or -1 if full."""
        row = self._landing_row(col)
        if row >= 0:
            self.board[row][col] = player
        return row

    def _check_win(self, player: int) -> bool:
        """Return True if ``player`` has four in a row anywhere on the board."""
        board = self.board
        for r in range(ROWS):
            for c in range(COLS):
                if board[r][c] != player:
                    continue
                for dr, dc in _DIRECTIONS:
                    end_r, end_c = r + 3 * dr, c + 3 * dc
                    if not (0 <= end_r < ROWS and 0 <= end_c < COLS):
                        continue
                    if all(board[r + i * dr][c + i * dc] == player for i in range(1, 4)):
                        return True
        return False

    def _board_full(self) -> bool:
        """Return True when no column can accept another piece."""
        return not self._legal_moves()

    def _count_streak(self, row: int, col: int, player: int) -> int:
        """
        Count max consecutive pieces for player around (row, col).

        The cell at (row, col) itself is counted as ``player``'s without
        being inspected, so this also answers "how long a line would
        ``player`` have if they played here?".
        """
        best = 1
        for dr, dc in _DIRECTIONS:
            count = 1
            # Walk outwards from the cell in both senses of this orientation.
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r][c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            best = max(best, count)
        return best

    def _blocking_bonus(self, col: int) -> float:
        """
        +0.5 if placing here blocks opponent's 4-in-a-row.

        Must be called before the current player's piece is placed. All four
        orientations (horizontal, vertical, both diagonals) are considered.
        """
        opponent = AGENT2 if self.current_player == AGENT1 else AGENT1
        row = self._landing_row(col)
        if row < 0:
            return 0.0
        # Would the opponent complete a line of four by playing this cell?
        return 0.5 if self._count_streak(row, col, opponent) >= 4 else 0.0