File size: 8,095 Bytes
185e2d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
Connect4 Multi-Agent Environment — Server Side
Adapted for autonomous driving scenario:
  - Agent 1 = "Ego vehicle" (LLM being trained)
  - Agent 2 = "Opponent vehicle" (rule-based or another LLM)

The board represents a grid intersection control problem:
  - Winning = successfully navigating without collision
  - Rewards shaped for RL post-training
"""

import numpy as np
from typing import Optional
from openenv.core.environment import Environment
from ..models import (
    Connect4Action, Connect4Observation, Connect4State
)


# Standard Connect 4 board dimensions: 6 rows x 7 columns.
ROWS = 6
COLS = 7
# Cell contents: 0 = empty, otherwise the id of the player occupying it.
EMPTY = 0
AGENT1 = 1   # Ego vehicle / LLM under training
AGENT2 = 2   # Opponent / rule-based agent


class Connect4Environment(Environment):
    """
    Connect4 as a multi-agent driving coordination environment.

    Observation:
      - Board state (6x7 grid)
      - Current player turn
      - Legal moves
      - Last move played
      - Game status

    Reward shaping (for RL):
      +10.0  → Win (ego agent connects 4)
      -10.0  → Loss (opponent connects 4)
       +0.5  → Blocking opponent's winning move
       +0.2  → Creating a 3-in-a-row
       -0.1  → Invalid move attempt
        0.0  → Draw
    """

    def __init__(self):
        super().__init__()
        # board[0] is the TOP row; pieces "fall" toward board[ROWS-1].
        self.board: np.ndarray = np.zeros((ROWS, COLS), dtype=int)
        self.current_player: int = AGENT1
        self.done: bool = False
        self.winner: Optional[int] = None          # None until a win; never set on draw
        self.last_move: Optional[int] = None       # column of the most recent valid move
        self.move_history: list = []               # list of (player, column) tuples

    # ------------------------------------------------------------------ #
    #  OpenEnv API                                                         #
    # ------------------------------------------------------------------ #

    def reset(self) -> Connect4Observation:
        """Clear the board and start a new game with AGENT1 to move."""
        self.board = np.zeros((ROWS, COLS), dtype=int)
        self.current_player = AGENT1
        self.done = False
        self.winner = None
        self.last_move = None
        self.move_history = []
        return self._make_observation("Game reset. Your turn — you are Player 1 (Ego Vehicle).")

    def step(self, action: Connect4Action) -> tuple[Connect4Observation, float, bool]:
        """Apply one move for the current player.

        Returns (observation, reward, done). Rewards follow the shaping
        scheme in the class docstring; note the terminal win reward is
        signed from the ego agent's (AGENT1's) perspective, while the
        shaping bonuses (+0.5 block, +0.2 streak) accrue to whichever
        player made the move.
        """
        if self.done:
            # Stepping a finished game is a no-op, not an error.
            obs = self._make_observation("Game already finished. Call reset() to start a new game.")
            return obs, 0.0, True

        col = action.column
        reward = 0.0

        # ---- validate move ----
        if col < 0 or col >= COLS or not self._is_valid(col):
            # Invalid moves are penalized but do NOT forfeit the turn.
            obs = self._make_observation(f"Invalid move: column {col} is full or out of range.")
            return obs, -0.1, False

        # ---- check for blocking bonus before placing ----
        # Must be evaluated before the piece lands, while the cell is empty.
        reward += self._blocking_bonus(col)

        # ---- place piece ----
        row = self._drop_piece(col, self.current_player)
        self.last_move = col
        self.move_history.append((self.current_player, col))

        # ---- 3-in-a-row bonus ----
        if self._count_streak(row, col, self.current_player) >= 3:
            reward += 0.2

        # ---- check win ----
        if self._check_win(self.current_player):
            self.done = True
            self.winner = self.current_player
            # Terminal reward is ego-centric: +10 if AGENT1 won, -10 if AGENT2 won.
            reward += 10.0 if self.current_player == AGENT1 else -10.0
            msg = ("🏆 Ego vehicle wins! Successful navigation." 
                   if self.current_player == AGENT1 
                   else "💥 Opponent wins. Collision occurred.")
            obs = self._make_observation(msg)
            return obs, reward, True

        # ---- check draw ----
        if self._board_full():
            self.done = True
            obs = self._make_observation("🤝 Draw. Stalemate — no collision, no winner.")
            return obs, 0.0, True

        # ---- switch player ----
        self.current_player = AGENT2 if self.current_player == AGENT1 else AGENT1
        msg = f"Move accepted (col {col}). Now Player {self.current_player}'s turn."
        obs = self._make_observation(msg)
        return obs, reward, False

    def state(self) -> Connect4State:
        """Snapshot of the episode for the OpenEnv server protocol."""
        # _episode_id / _step_count are maintained by the Environment base
        # class — presumably incremented by its reset/step wrappers; verify
        # against openenv.core.environment.
        return Connect4State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_player=self.current_player,
            done=self.done,
            winner=self.winner,
            move_history=self.move_history,
        )

    # ------------------------------------------------------------------ #
    #  Internal helpers                                                    #
    # ------------------------------------------------------------------ #

    def _make_observation(self, message: str) -> Connect4Observation:
        """Bundle the current game state into an observation payload."""
        return Connect4Observation(
            board=self.board.tolist(),
            board_str=self._render_board(),
            current_player=self.current_player,
            legal_moves=self._legal_moves(),
            last_move=self.last_move,
            done=self.done,
            winner=self.winner,
            message=message,
        )

    def _render_board(self) -> str:
        """ASCII board: '.' empty, 'X' = AGENT1, 'O' = AGENT2, plus column labels."""
        symbols = {EMPTY: ".", AGENT1: "X", AGENT2: "O"}
        rows = []
        for r in range(ROWS):
            rows.append(" ".join(symbols[self.board[r][c]] for c in range(COLS)))
        rows.append("-" * (COLS * 2 - 1))
        rows.append(" ".join(str(c) for c in range(COLS)))
        return "\n".join(rows)

    def _is_valid(self, col: int) -> bool:
        """A column accepts a piece iff its top cell is empty."""
        return self.board[0][col] == EMPTY

    def _legal_moves(self) -> list[int]:
        """All columns that can still accept a piece."""
        return [c for c in range(COLS) if self._is_valid(c)]

    def _drop_piece(self, col: int, player: int) -> int:
        """Place `player`'s piece in the lowest empty cell of `col`.

        Returns the row it landed in, or -1 if the column is full
        (callers validate first, so -1 should not occur in practice).
        """
        for row in range(ROWS - 1, -1, -1):
            if self.board[row][col] == EMPTY:
                self.board[row][col] = player
                return row
        return -1

    @staticmethod
    def _has_connect4(board, player: int) -> bool:
        """Return True if `player` has 4 in a row anywhere on `board`.

        Shared by _check_win and _blocking_bonus so both use the exact
        same win definition (horizontal, vertical, both diagonals).
        """
        # Horizontal
        for r in range(ROWS):
            for c in range(COLS - 3):
                if all(board[r][c+i] == player for i in range(4)):
                    return True
        # Vertical
        for r in range(ROWS - 3):
            for c in range(COLS):
                if all(board[r+i][c] == player for i in range(4)):
                    return True
        # Diagonal /
        for r in range(3, ROWS):
            for c in range(COLS - 3):
                if all(board[r-i][c+i] == player for i in range(4)):
                    return True
        # Diagonal \
        for r in range(ROWS - 3):
            for c in range(COLS - 3):
                if all(board[r+i][c+i] == player for i in range(4)):
                    return True
        return False

    def _check_win(self, player: int) -> bool:
        """True if `player` has connected four on the live board."""
        return self._has_connect4(self.board, player)

    def _board_full(self) -> bool:
        """The board is full when no top-row cell is empty."""
        return all(self.board[0][c] != EMPTY for c in range(COLS))

    def _count_streak(self, row: int, col: int, player: int) -> int:
        """Count max consecutive pieces for player around (row, col)."""
        directions = [(0,1),(1,0),(1,1),(1,-1)]
        best = 1
        for dr, dc in directions:
            count = 1
            # Extend outward in both directions along this axis.
            for sign in [1, -1]:
                r, c = row + sign*dr, col + sign*dc
                while 0 <= r < ROWS and 0 <= c < COLS and self.board[r][c] == player:
                    count += 1
                    r += sign*dr
                    c += sign*dc
            best = max(best, count)
        return best

    def _blocking_bonus(self, col: int) -> float:
        """+0.5 if placing here blocks opponent's 4-in-a-row.

        Simulates the opponent dropping into `col` on a copy of the board;
        if that hypothetical move would have won, occupying the cell
        ourselves is a block.
        """
        opponent = AGENT2 if self.current_player == AGENT1 else AGENT1
        test_board = self.board.copy()
        for row in range(ROWS - 1, -1, -1):
            if test_board[row][col] == EMPTY:
                test_board[row][col] = opponent
                break
        # BUG FIX: this previously scanned only horizontal and vertical
        # lines, so blocking a DIAGONAL winning threat earned no bonus —
        # inconsistent with _check_win and the documented reward shaping.
        # Delegating to the shared win scan covers all four directions.
        return 0.5 if self._has_connect4(test_board, opponent) else 0.0