Spaces:
Configuration error
Configuration error
File size: 5,023 Bytes
1b03c75 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | import random
class CustomStrategoEnv:
    """Simplified Stratego environment with real piece symbols.

    The board size is configurable and pieces are placed automatically:
    player 0 fills rows from the top, player 1 fills rows from the bottom,
    and boards of size 8 or more get two 2x2 lakes ("~"). Empty cells are
    ".". Pieces have no ranks here: a move is one orthogonal step onto an
    empty cell or onto an opponent's piece, which it replaces (a capture).

    Piece codes alone do not say which player a piece belongs to, so
    ownership is tracked in ``self.owners``: a grid parallel to
    ``self.board`` holding 0, 1, or None (empty cell / lake). Code that
    assigns ``self.board`` directly must keep ``self.owners`` in sync.
    """

    _EMPTY = "."
    _LAKE = "~"
    # Maximum number of legal moves listed in an observation.
    MAX_LISTED_MOVES = 10
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, env_id="Stratego-v0", board_size=10, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.env_id = env_id
        self.board_size = board_size
        self.turn = 0
        self.done = False
        self.players = [0, 1]
        self.symbols = ["A", "B"]
        self.board = []
        self.owners = []
        self.reset()

    # ---------------------- BOARD SETUP ----------------------
    def _generate_board(self):
        """Return a fresh ``board_size`` x ``board_size`` grid of cell strings.

        Player 0's pieces fill rows top-down and are only placed in rows
        ``< n // 2``; player 1's fill rows bottom-up and are only placed in
        rows ``>= n // 2``. ``_initial_owners`` relies on this split.
        """
        n = self.board_size
        board = [[self._EMPTY for _ in range(n)] for _ in range(n)]
        # Typical Stratego piece codes.
        piece_types = ["BM", "FL", "MN", "SG", "LT", "CP", "MJ", "SP", "GN", "CL"]
        # The number of pieces per player grows with the board area.
        num_pieces = max(6, n * n // 6)
        # Pieces are drawn at random, with repetition.
        p0_pieces = [random.choice(piece_types) for _ in range(num_pieces)]
        p1_pieces = [random.choice(piece_types) for _ in range(num_pieces)]
        # Upper half: player 0, filled row by row from the top.
        for i, piece in enumerate(p0_pieces):
            row, col = divmod(i, n)
            if row < n // 2:
                board[row][col] = piece
        # Lower half: player 1, filled row by row from the bottom.
        for i, piece in enumerate(p1_pieces):
            row, col = n - 1 - i // n, i % n
            if row >= n // 2:
                board[row][col] = piece
        # On large enough boards add lakes: a 2x2 block starting at
        # (n // 3, n // 3) plus its 180-degree rotation about the centre.
        if n >= 8:
            for i in range(n // 3, n // 3 + 2):
                for j in range(n // 3, n // 3 + 2):
                    board[i][j] = self._LAKE
                    board[n - i - 1][n - j - 1] = self._LAKE
        return board

    def _initial_owners(self, board):
        """Return the owner grid for a freshly generated ``board``.

        Right after generation every piece in rows ``< board_size // 2``
        belongs to player 0 and every other piece to player 1; empty cells
        and lakes have owner None.
        """
        half = self.board_size // 2
        return [
            [
                None if cell in (self._EMPTY, self._LAKE) else (0 if i < half else 1)
                for cell in row
            ]
            for i, row in enumerate(board)
        ]

    # ---------------------- API METHODS ----------------------
    def reset(self, num_players=2):
        """Start a new game and return the first observation.

        ``num_players`` is accepted for API compatibility but unused; the
        game is always two-player.
        """
        self.turn = 0
        self.done = False
        self.board = self._generate_board()
        self.owners = self._initial_owners(self.board)
        return self.get_observation()

    def get_observation(self):
        """Return ``(player, text)`` for the player whose turn it is.

        The text shows the board and up to ``MAX_LISTED_MOVES`` of that
        player's legal moves, in random order.
        """
        player = self.turn % 2
        board_text = "\n".join(" ".join(row) for row in self.board)
        legal_moves = self._get_legal_moves(player)
        obs = (
            f"Player {player} ({self.symbols[player]}) turn.\n"
            f"Board:\n{board_text}\n\n"
            f"Legal moves:\n{', '.join(legal_moves)}"
        )
        return player, obs

    def step(self, action):
        """Apply ``action`` for the current player and advance the turn.

        ``action`` is ``"<src> <dst>"`` using cell labels such as ``"B3"``.
        An invalid move is not an error: the board is left unchanged and the
        turn simply passes. The game is marked done once either player has
        no pieces left.

        Returns ``(done, info)``; ``info`` is always an empty dict.
        """
        self._apply_move(action)
        p0_count, p1_count = self._piece_counts()
        if p0_count == 0 or p1_count == 0:
            self.done = True
        self.turn += 1
        return self.done, {}

    def close(self):
        """Return ``(rewards, info)``; rewards are always zero for both players."""
        rewards = {0: 0, 1: 0}
        info = {"board_size": self.board_size}
        return rewards, info

    # ---------------------- MOVE LOGIC ----------------------
    def _piece_counts(self):
        """Return ``(pieces owned by player 0, pieces owned by player 1)``."""
        flat = [owner for row in self.owners for owner in row]
        return flat.count(0), flat.count(1)

    def _can_enter(self, player, i, j):
        """Return True if ``player`` may move a piece onto cell ``(i, j)``.

        Allowed targets are empty cells and cells holding an opponent's
        piece (a capture). Lakes (owner None) and the player's own pieces
        are blocked.
        """
        if self.board[i][j] == self._EMPTY:
            return True
        return self.owners[i][j] not in (None, player)

    def _get_legal_moves(self, player):
        """Return up to ``MAX_LISTED_MOVES`` legal moves for ``player``.

        Each move is ``"<src> <dst>"``: one orthogonal step by one of the
        player's own pieces onto an empty cell or an opponent's piece. The
        full list is shuffled with the module-level ``random`` state before
        truncation.
        """
        n = self.board_size
        moves = []
        for i in range(n):
            for j in range(n):
                if self.owners[i][j] != player:
                    continue  # empty, lake, or the opponent's piece
                for di, dj in self._DIRECTIONS:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < n and 0 <= nj < n and self._can_enter(player, ni, nj):
                        moves.append(
                            f"{self._pos_to_label(i, j)} {self._pos_to_label(ni, nj)}"
                        )
        random.shuffle(moves)
        return moves[:self.MAX_LISTED_MOVES]

    def _apply_move(self, action):
        """Try to apply ``action`` for the player to move (``turn % 2``).

        Returns True if the move was applied. Returns False, leaving the
        board untouched, if the action is malformed or out of bounds, if the
        source is not one of the player's pieces, if the move is not a single
        orthogonal step (this also rejects ``src == dst``), or if the target
        is a lake or one of the player's own pieces.
        """
        parts = action.strip().split()
        if len(parts) != 2:
            return False
        src_label, dst_label = parts
        si, sj = self._label_to_pos(src_label)
        ti, tj = self._label_to_pos(dst_label)
        n = self.board_size
        if not (0 <= si < n and 0 <= sj < n and 0 <= ti < n and 0 <= tj < n):
            return False
        player = self.turn % 2
        if self.owners[si][sj] != player:
            return False
        if abs(si - ti) + abs(sj - tj) != 1:
            return False
        if not self._can_enter(player, ti, tj):
            return False
        # Moving onto an opponent's piece replaces it (capture).
        self.board[ti][tj] = self.board[si][sj]
        self.owners[ti][tj] = player
        self.board[si][sj] = self._EMPTY
        self.owners[si][sj] = None
        return True

    # ---------------------- UTILS ----------------------
    def _pos_to_label(self, i, j):
        """Return the label of cell ``(i, j)``: row character, then column number.

        Rows use ``chr(65 + i)`` ("A", "B", ...); past 26 rows this continues
        into ASCII punctuation and then lower-case letters.
        """
        return f"{chr(65 + i)}{j}"

    def _label_to_pos(self, label):
        """Parse a label from ``_pos_to_label`` into ``(row, col)``.

        A row character that does not name a row on this board is retried
        upper-cased, so "c4" is accepted as "C4". Generated labels,
        including lower-case rows on boards larger than 32, still round-trip
        exactly. Returns ``(-1, -1)`` for labels that cannot be parsed.
        Out-of-range results are left for the caller's bounds check.
        """
        try:
            letter = label[0]
            row = ord(letter) - 65
            if row >= self.board_size:
                row = ord(letter.upper()) - 65
            col = int(label[1:])
        except (IndexError, TypeError, ValueError):
            return -1, -1
        return row, col
|