# stratego-human-vs-ai / stratego / env / custom_env.py
# DarshanScripts's picture
# Upload stratego/env/custom_env.py with huggingface_hub
# 1b03c75 verified
import random
class CustomStrategoEnv:
    """Simplified Stratego environment played on a resizable text board.

    The board is a square grid of strings: ``"."`` is an empty cell, ``"~"``
    is a lake, and every other value is a two-letter piece code taken from
    ``PIECE_TYPES``. Pieces are distributed automatically on every reset.

    Known limitations (the grid stores piece codes only, never an owner):
      * ``_get_legal_moves`` lists moves for every piece on the board,
        whichever player it is asked about.
      * ``step`` ends the game as soon as the top or the bottom half of the
        board contains no piece, regardless of who placed the pieces there.
    """

    EMPTY = "."
    LAKE = "~"
    PIECE_TYPES = ("BM", "FL", "MN", "SG", "LT", "CP", "MJ", "SP", "GN", "CL")
    # Upper bound on the number of moves advertised in an observation.
    MAX_LISTED_MOVES = 10
    # Orthogonal neighbours: up, down, left, right.
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, env_id="Stratego-v0", board_size=10, **kwargs):
        """Create the environment and generate the first board.

        Args:
            env_id: Identifier stored on the instance as ``env_id``.
            board_size: Side length of the square board; must be >= 1.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            ValueError: If ``board_size`` is smaller than 1.
        """
        if board_size < 1:
            raise ValueError(
                f"board_size must be a positive integer, got {board_size!r}"
            )
        self.env_id = env_id
        self.board_size = board_size
        self.turn = 0
        self.done = False
        self.players = [0, 1]
        self.symbols = ["A", "B"]
        self.board = []
        self.reset()

    # ---------------------- BOARD SETUP ----------------------
    def _generate_board(self):
        """Build and return a fresh board with random pieces and lakes.

        Each player receives ``max(6, n * n // 6)`` randomly chosen piece
        codes (with repetition). Player 0 fills rows from the top edge
        downwards, player 1 from the bottom edge upwards; a piece whose row
        would leave that player's half is dropped. Boards of size 8 or more
        also get two 2x2 lakes mirrored through the board centre.
        """
        n = self.board_size
        board = [[self.EMPTY] * n for _ in range(n)]

        # The number of pieces per player grows with the board size.
        num_pieces = max(6, n * n // 6)
        # Draw player 0's pieces first, then player 1's, so that seeded runs
        # consume the random stream in a stable order.
        p0_pieces = [random.choice(self.PIECE_TYPES) for _ in range(num_pieces)]
        p1_pieces = [random.choice(self.PIECE_TYPES) for _ in range(num_pieces)]

        half = n // 2
        for index, (top_piece, bottom_piece) in enumerate(zip(p0_pieces, p1_pieces)):
            depth, col = divmod(index, n)
            # Upper half: player 0.
            if depth < half:
                board[depth][col] = top_piece
            # Lower half: player 1 (mirrored row, counted from the bottom).
            bottom_row = n - 1 - depth
            if bottom_row >= half:
                board[bottom_row][col] = bottom_piece

        # Add lakes only when the board is large enough to hold them.
        if n >= 8:
            start = n // 3
            for i in range(start, start + 2):
                for j in range(start, start + 2):
                    board[i][j] = self.LAKE
                    board[n - i - 1][n - j - 1] = self.LAKE
        return board

    # ---------------------- API METHODS ----------------------
    def reset(self, num_players=2):
        """Restart the game on a new board and return the first observation.

        Args:
            num_players: Accepted for call compatibility and ignored.

        Returns:
            The ``(player, observation_text)`` tuple from ``get_observation``.
        """
        self.turn = 0
        self.done = False
        self.board = self._generate_board()
        return self.get_observation()

    def get_observation(self):
        """Return ``(player, text)`` describing the position for the mover.

        ``player`` is ``turn % 2``. ``text`` contains the player banner, the
        board rendered one row per line, and a comma-separated sample of
        legal moves.
        """
        player = self.turn % 2
        board_text = "\n".join(" ".join(row) for row in self.board)
        legal_moves = self._get_legal_moves(player)
        obs = (
            f"Player {player} ({self.symbols[player]}) turn.\n"
            f"Board:\n{board_text}\n\n"
            f"Legal moves:\n{', '.join(legal_moves)}"
        )
        return player, obs

    def step(self, action):
        """Apply ``action``, update the termination flag and advance the turn.

        An invalid action is ignored; the turn still passes to the other
        player.

        Args:
            action: Move text of the form ``"<src> <dst>"``, e.g. ``"A0 B0"``.

        Returns:
            ``(done, info)`` where ``info`` is always an empty dict.
        """
        self._apply_move(action)

        # The game ends once either half of the board holds no piece.
        half = self.board_size // 2
        top_pieces = self._count_pieces(self.board[:half])
        bottom_pieces = self._count_pieces(self.board[half:])
        if top_pieces == 0 or bottom_pieces == 0:
            self.done = True

        self.turn += 1
        return self.done, {}

    def close(self):
        """Return ``(rewards, info)``: zero rewards and the board size."""
        rewards = {0: 0, 1: 0}
        info = {"board_size": self.board_size}
        return rewards, info

    # ---------------------- MOVE LOGIC ----------------------
    def _get_legal_moves(self, player):
        """Return up to ``MAX_LISTED_MOVES`` randomly sampled move strings.

        A move is listed for every piece that has an empty, in-bounds,
        orthogonally adjacent cell.

        Args:
            player: Currently unused; the board has no ownership data, so
                moves of all pieces are returned.
        """
        n = self.board_size
        moves = []
        for i in range(n):
            for j in range(n):
                if self.board[i][j] in (self.EMPTY, self.LAKE):
                    continue  # not a real piece
                for di, dj in self._DIRECTIONS:
                    ni, nj = i + di, j + dj
                    if self._in_bounds(ni, nj) and self.board[ni][nj] == self.EMPTY:
                        moves.append(
                            f"{self._pos_to_label(i, j)} {self._pos_to_label(ni, nj)}"
                        )
        random.shuffle(moves)
        return moves[: self.MAX_LISTED_MOVES]

    def _apply_move(self, action):
        """Try to execute ``action`` on the board.

        The move is rejected (board untouched) when the action is not a
        string of exactly two labels, when either label is off the board,
        when source and destination are the same cell, when the source holds
        no piece, or when the destination is a lake. A destination that
        holds another piece is overwritten.

        Returns:
            True if the piece was moved, False otherwise.
        """
        if not isinstance(action, str):
            return False
        parts = action.split()
        if len(parts) != 2:
            return False

        src = self._label_to_pos(parts[0])
        dst = self._label_to_pos(parts[1])
        if not (self._in_bounds(*src) and self._in_bounds(*dst)):
            return False
        # Moving a piece onto its own cell would otherwise erase it below.
        if src == dst:
            return False

        (si, sj), (di, dj) = src, dst
        if self.board[si][sj] in (self.EMPTY, self.LAKE):
            return False
        # Lakes are terrain; a piece must never replace one.
        if self.board[di][dj] == self.LAKE:
            return False

        self.board[di][dj] = self.board[si][sj]
        self.board[si][sj] = self.EMPTY
        return True

    # ---------------------- UTILS ----------------------
    def _in_bounds(self, i, j):
        """Return True when ``(i, j)`` lies on the board."""
        n = self.board_size
        return 0 <= i < n and 0 <= j < n

    @staticmethod
    def _count_pieces(rows):
        """Count cells in ``rows`` that are neither empty nor a lake."""
        return sum(
            cell not in (CustomStrategoEnv.EMPTY, CustomStrategoEnv.LAKE)
            for row in rows
            for cell in row
        )

    def _pos_to_label(self, i, j):
        """Convert ``(row, col)`` to a label: row letter from 'A', then column."""
        return f"{chr(ord('A') + i)}{j}"

    def _label_to_pos(self, label):
        """Convert a label such as ``"C5"`` to ``(row, col)``.

        The row letter is case-insensitive. Returns ``(-1, -1)`` when the
        label cannot be parsed.
        """
        try:
            row = ord(label[0].upper()) - ord("A")
            col = int(label[1:])
        except (IndexError, ValueError, TypeError, AttributeError):
            return -1, -1
        return row, col