# HackathonMarch2026 / connect4_environment.py
# Uploaded by helshahaby ("Upload 6 files"), commit 185e2d2.
"""
Connect4 Multi-Agent Environment — Server Side
Adapted for autonomous driving scenario:
- Agent 1 = "Ego vehicle" (LLM being trained)
- Agent 2 = "Opponent vehicle" (rule-based or another LLM)
The board represents a grid intersection control problem:
- Winning = successfully navigating without collision
- Rewards shaped for RL post-training
"""
import uuid
from typing import Optional

import numpy as np
from openenv.core.environment import Environment

from ..models import (
    Connect4Action,
    Connect4Observation,
    Connect4State,
)
# Board geometry: a 6-row x 7-column grid (row 0 is the top row).
ROWS, COLS = 6, 7

# Cell values stored in the board array.
#   EMPTY  → unoccupied cell
#   AGENT1 → Ego vehicle / LLM under training
#   AGENT2 → Opponent / rule-based agent
EMPTY, AGENT1, AGENT2 = 0, 1, 2
class Connect4Environment(Environment):
    """
    Connect4 as a multi-agent driving coordination environment.

    Observation:
      - Board state (6x7 grid)
      - Current player turn
      - Legal moves
      - Last move played
      - Game status

    Reward shaping (for RL), returned by ``step`` for the move just played:
      +10.0 → Win (ego agent connects 4)
      -10.0 → Loss (opponent connects 4)
       +0.5 → Blocking opponent's winning move (any direction)
       +0.2 → Creating a 3-in-a-row
       -0.1 → Invalid move attempt
        0.0 → Draw

    NOTE(review): the +0.5 / +0.2 shaping bonuses are added for whichever
    player moves (including AGENT2), while the win/loss terms are signed from
    the ego agent's perspective. Confirm this is intended before training.
    """

    def __init__(self):
        super().__init__()
        self.board: np.ndarray = np.zeros((ROWS, COLS), dtype=int)
        self.current_player: int = AGENT1
        self.done: bool = False
        self.winner: Optional[int] = None
        self.last_move: Optional[int] = None
        # (player, column) for every accepted move, in order.
        self.move_history: list = []
        # Episode bookkeeping read by state(); set here so state() works
        # without relying on the base class to provide these attributes.
        self._episode_id: str = str(uuid.uuid4())
        self._step_count: int = 0

    # ------------------------------------------------------------------ #
    #                           OpenEnv API                               #
    # ------------------------------------------------------------------ #

    def reset(self) -> Connect4Observation:
        """Start a new episode: empty board, fresh episode id, Player 1 to move."""
        self.board = np.zeros((ROWS, COLS), dtype=int)
        self.current_player = AGENT1
        self.done = False
        self.winner = None
        self.last_move = None
        self.move_history = []
        self._episode_id = str(uuid.uuid4())
        self._step_count = 0
        return self._make_observation("Game reset. Your turn — you are Player 1 (Ego Vehicle).")

    def step(self, action: Connect4Action) -> tuple[Connect4Observation, float, bool]:
        """Drop the current player's piece into ``action.column``.

        Returns:
            ``(observation, reward, done)``. An invalid move leaves the board
            and the turn unchanged and yields -0.1. Calling ``step`` after the
            game has ended yields reward 0.0 and ``done=True``.
        """
        if self.done:
            obs = self._make_observation("Game already finished. Call reset() to start a new game.")
            return obs, 0.0, True

        # Count every step attempted on a live episode, including invalid ones.
        self._step_count += 1

        col = action.column
        reward = 0.0

        # ---- validate move ----
        if col < 0 or col >= COLS or not self._is_valid(col):
            obs = self._make_observation(f"Invalid move: column {col} is full or out of range.")
            return obs, -0.1, False

        # ---- check for blocking bonus before placing ----
        reward += self._blocking_bonus(col)

        # ---- place piece ----
        row = self._drop_piece(col, self.current_player)
        self.last_move = col
        self.move_history.append((self.current_player, col))

        # ---- 3-in-a-row bonus ----
        if self._count_streak(row, col, self.current_player) >= 3:
            reward += 0.2

        # ---- check win ----
        if self._check_win(self.current_player):
            self.done = True
            self.winner = self.current_player
            reward += 10.0 if self.current_player == AGENT1 else -10.0
            msg = ("🏆 Ego vehicle wins! Successful navigation."
                   if self.current_player == AGENT1
                   else "💥 Opponent wins. Collision occurred.")
            obs = self._make_observation(msg)
            return obs, reward, True

        # ---- check draw ----
        # A draw returns 0.0 outright; any shaping earned on the final move
        # is intentionally discarded, matching the documented reward table.
        if self._board_full():
            self.done = True
            obs = self._make_observation("🤝 Draw. Stalemate — no collision, no winner.")
            return obs, 0.0, True

        # ---- switch player ----
        self.current_player = AGENT2 if self.current_player == AGENT1 else AGENT1
        msg = f"Move accepted (col {col}). Now Player {self.current_player}'s turn."
        obs = self._make_observation(msg)
        return obs, reward, False

    def state(self) -> Connect4State:
        """Return a snapshot of episode metadata and game status."""
        return Connect4State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_player=self.current_player,
            done=self.done,
            winner=self.winner,
            # Copy so later moves do not mutate an already-returned snapshot.
            move_history=list(self.move_history),
        )

    # ------------------------------------------------------------------ #
    #                         Internal helpers                            #
    # ------------------------------------------------------------------ #

    def _make_observation(self, message: str) -> Connect4Observation:
        """Build an observation of the current game state carrying ``message``."""
        return Connect4Observation(
            board=self.board.tolist(),
            board_str=self._render_board(),
            current_player=self.current_player,
            legal_moves=self._legal_moves(),
            last_move=self.last_move,
            done=self.done,
            winner=self.winner,
            message=message,
        )

    def _render_board(self) -> str:
        """Render the board as text: '.' empty, 'X' Player 1, 'O' Player 2,
        top row first, followed by a separator and the column indices."""
        symbols = {EMPTY: ".", AGENT1: "X", AGENT2: "O"}
        rows = []
        for r in range(ROWS):
            rows.append(" ".join(symbols[self.board[r][c]] for c in range(COLS)))
        rows.append("-" * (COLS * 2 - 1))
        rows.append(" ".join(str(c) for c in range(COLS)))
        return "\n".join(rows)

    def _is_valid(self, col: int) -> bool:
        """Return True if ``col`` still has room (its top cell, row 0, is empty).

        Assumes ``col`` is already within ``[0, COLS)``.
        """
        return self.board[0][col] == EMPTY

    def _legal_moves(self) -> list[int]:
        """Return the indices of all columns that are not full."""
        return [c for c in range(COLS) if self._is_valid(c)]

    def _drop_piece(self, col: int, player: int) -> int:
        """Place ``player``'s piece in the lowest empty cell of ``col``.

        Returns the row it landed in, or -1 if the column was full.
        """
        for row in range(ROWS - 1, -1, -1):
            if self.board[row][col] == EMPTY:
                self.board[row][col] = player
                return row
        return -1

    def _check_win(self, player: int, board: Optional[np.ndarray] = None) -> bool:
        """Return True if ``player`` has four in a row on ``board``.

        Checks horizontal, vertical and both diagonal directions. ``board``
        defaults to the live game board; pass a copy to test hypothetical moves.
        """
        b = self.board if board is None else board
        # Horizontal
        for r in range(ROWS):
            for c in range(COLS - 3):
                if all(b[r][c + i] == player for i in range(4)):
                    return True
        # Vertical
        for r in range(ROWS - 3):
            for c in range(COLS):
                if all(b[r + i][c] == player for i in range(4)):
                    return True
        # Diagonal /
        for r in range(3, ROWS):
            for c in range(COLS - 3):
                if all(b[r - i][c + i] == player for i in range(4)):
                    return True
        # Diagonal \
        for r in range(ROWS - 3):
            for c in range(COLS - 3):
                if all(b[r + i][c + i] == player for i in range(4)):
                    return True
        return False

    def _board_full(self) -> bool:
        """Return True when every column's top cell is occupied."""
        return all(self.board[0][c] != EMPTY for c in range(COLS))

    def _count_streak(self, row: int, col: int, player: int) -> int:
        """Return the longest line of ``player`` pieces passing through (row, col).

        Scans horizontally, vertically and along both diagonals. The cell
        (row, col) itself is always counted as 1, so call this after placing
        ``player``'s piece there.
        """
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        best = 1
        for dr, dc in directions:
            count = 1
            # Walk outward in both senses of the direction.
            for sign in [1, -1]:
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r][c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            best = max(best, count)
        return best

    def _blocking_bonus(self, col: int) -> float:
        """Return +0.5 if playing ``col`` blocks an opponent four-in-a-row.

        Simulates the opponent dropping a piece into ``col`` on a copy of the
        board and checks whether that would complete four in a row in any
        direction (horizontal, vertical or diagonal). Returns 0.0 otherwise.
        """
        opponent = AGENT2 if self.current_player == AGENT1 else AGENT1
        test_board = self.board.copy()
        for row in range(ROWS - 1, -1, -1):
            if test_board[row][col] == EMPTY:
                test_board[row][col] = opponent
                break
        else:
            # Column is full: there is no opponent move here to block.
            return 0.0
        return 0.5 if self._check_win(opponent, test_board) else 0.0