File size: 5,023 Bytes
1b03c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import random

class CustomStrategoEnv:
    """
    Simplified Stratego using real piece codes.

    The board size is configurable and pieces are placed automatically:
    player 0 starts in the top rows, player 1 in the bottom rows. Both
    players share the same piece codes, so ownership is tracked separately
    in ``self.owners``, a grid parallel to ``self.board`` holding 0, 1, or
    None (empty square or lake).

    Rules in this simplified version: a piece moves exactly one square
    horizontally or vertically, onto an empty square or onto an enemy
    piece. Moving onto an enemy piece captures it; the attacker always wins,
    because ranks are not compared. Lakes ("~") cannot be entered. The game
    ends when a player has no pieces left.

    Squares are labelled with a row letter followed by a column number
    ("A0" is the top-left corner). Row letters are produced with
    ``chr(65 + row)``, so labels are only plain letters for boards of at
    most 26 rows.
    """

    _EMPTY = "."
    _LAKE = "~"
    # Orthogonal single-step moves: up, down, left, right.
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, env_id="Stratego-v0", board_size=10, **kwargs):
        """Create the environment and immediately start a game.

        Extra keyword arguments are accepted for API compatibility and
        ignored.
        """
        self.env_id = env_id
        self.board_size = board_size
        self.turn = 0
        self.done = False
        self.players = [0, 1]
        self.symbols = ["A", "B"]
        self.board = []
        # owners[i][j]: 0 or 1 for a piece, None for an empty square or lake.
        self.owners = []
        self.reset()

    # ---------------------- BOARD SETUP ----------------------
    def _generate_board(self):
        """Build a fresh board and store the matching ``self.owners`` grid.

        Each player draws ``max(6, n * n // 6)`` piece codes at random (with
        repetition) and lays them out row by row from their own edge. Pieces
        that do not fit in the player's ``n // 2`` rows are dropped, so both
        sides are symmetric. On odd-sized boards the middle row stays empty.
        Boards of size 8 or more also get two mirrored 2x2 lakes.

        Returns:
            The board as a list of ``n`` rows, each a list of ``n`` cell
            strings.
        """
        n = self.board_size
        board = [[self._EMPTY for _ in range(n)] for _ in range(n)]
        owners = [[None for _ in range(n)] for _ in range(n)]

        # Typical Stratego piece codes.
        piece_types = ["BM", "FL", "MN", "SG", "LT", "CP", "MJ", "SP", "GN", "CL"]

        # Pieces per player grow with the board size.
        num_pieces = max(6, n * n // 6)

        # Random choice with repetition.
        p0_pieces = [random.choice(piece_types) for _ in range(num_pieces)]
        p1_pieces = [random.choice(piece_types) for _ in range(num_pieces)]

        # Each side may use at most n // 2 rows. For odd n this keeps the
        # middle row neutral instead of giving it to player 1 only.
        rows_per_side = n // 2

        # Top half: player 0, filled from row 0 downwards.
        for i, piece in enumerate(p0_pieces):
            row, col = divmod(i, n)
            if row < rows_per_side:
                board[row][col] = piece
                owners[row][col] = 0

        # Bottom half: player 1, filled from the last row upwards.
        for i, piece in enumerate(p1_pieces):
            row_offset, col = divmod(i, n)
            if row_offset < rows_per_side:
                row = n - 1 - row_offset
                board[row][col] = piece
                owners[row][col] = 1

        # Add lakes (~) when the board is large enough.
        if n >= 8:
            for i in range(n // 3, n // 3 + 2):
                for j in range(n // 3, n // 3 + 2):
                    for r, c in ((i, j), (n - i - 1, n - j - 1)):
                        board[r][c] = self._LAKE
                        owners[r][c] = None

        self.owners = owners
        return board

    # ---------------------- API METHODS ----------------------
    def reset(self, num_players=2):
        """Start a new game and return the first observation.

        ``num_players`` is accepted for API compatibility; the game is always
        two-player.
        """
        self.turn = 0
        self.done = False
        self.board = self._generate_board()
        return self.get_observation()

    def get_observation(self):
        """Return ``(player, text)`` for the player whose turn it is.

        The text shows the board and a sample of that player's legal moves.
        """
        player = self.turn % 2
        board_text = "\n".join(" ".join(row) for row in self.board)
        legal_moves = self._get_legal_moves(player)
        obs = (
            f"Player {player} ({self.symbols[player]}) turn.\n"
            f"Board:\n{board_text}\n\n"
            f"Legal moves:\n{', '.join(legal_moves)}"
        )
        return player, obs

    def step(self, action):
        """Apply ``action`` for the current player and pass the turn.

        An invalid action is ignored and the turn simply passes. Once the
        game is over, further actions are ignored and the turn does not
        advance.

        Returns:
            ``(done, info)``; ``info`` is always an empty dict.
        """
        if self.done:
            return self.done, {}

        # Invalid moves are deliberately not penalised: the turn just passes.
        self._apply_move(action)

        # The game ends as soon as either player has no pieces left.
        if self._count_pieces(0) == 0 or self._count_pieces(1) == 0:
            self.done = True

        self.turn += 1
        return self.done, {}

    def close(self):
        """Return the final ``(rewards, info)`` pair.

        NOTE(review): rewards are always 0 for both players, even when the
        game ended with a winner. Confirm whether callers expect win/loss
        rewards here.
        """
        rewards = {0: 0, 1: 0}
        info = {"board_size": self.board_size}
        return rewards, info

    # ---------------------- MOVE LOGIC ----------------------
    def _count_pieces(self, player):
        """Return how many pieces ``player`` still has on the board."""
        return sum(owner == player for row in self.owners for owner in row)

    def _can_enter(self, player, i, j):
        """Return True if a piece of ``player`` may move onto square (i, j).

        The square must be on the board, not a lake, and not occupied by one
        of the player's own pieces. Empty squares and enemy pieces qualify.
        """
        n = self.board_size
        if not (0 <= i < n and 0 <= j < n):
            return False
        if self.board[i][j] == self._LAKE:
            return False
        return self.owners[i][j] != player

    def _get_legal_moves(self, player, limit=10):
        """Return up to ``limit`` legal moves for ``player`` in random order.

        Each move is the string ``"<src> <dst>"``. The source is one of the
        player's pieces and the destination is an orthogonally adjacent
        square that is empty or holds an enemy piece (a capture).
        """
        n = self.board_size
        moves = []
        for i in range(n):
            for j in range(n):
                if self.owners[i][j] != player:
                    continue  # empty, lake, or the opponent's piece
                for di, dj in self._DIRECTIONS:
                    ni, nj = i + di, j + dj
                    if self._can_enter(player, ni, nj):
                        moves.append(
                            f"{self._pos_to_label(i, j)} {self._pos_to_label(ni, nj)}"
                        )
        random.shuffle(moves)
        return moves[:limit]

    def _apply_move(self, action):
        """Validate and execute ``action`` for the player whose turn it is.

        Returns:
            True if the move was legal and has been applied; False otherwise,
            in which case the board is left untouched.
        """
        parts = action.strip().split()
        if len(parts) != 2:
            return False
        src_label, dst_label = parts
        si, sj = self._label_to_pos(src_label)
        di, dj = self._label_to_pos(dst_label)

        n = self.board_size
        player = self.turn % 2
        if not (0 <= si < n and 0 <= sj < n):
            return False
        # Only the current player's own pieces may be moved.
        if self.owners[si][sj] != player:
            return False
        # Pieces move exactly one square horizontally or vertically.
        if abs(si - di) + abs(sj - dj) != 1:
            return False
        if not self._can_enter(player, di, dj):
            return False

        # Moving onto an enemy piece replaces it (simplified capture).
        self.board[di][dj] = self.board[si][sj]
        self.owners[di][dj] = player
        self.board[si][sj] = self._EMPTY
        self.owners[si][sj] = None
        return True

    # ---------------------- UTILS ----------------------
    def _pos_to_label(self, i, j):
        """Convert ``(row, col)`` to a label such as ``"C7"``."""
        return f"{chr(65 + i)}{j}"

    def _label_to_pos(self, label):
        """Convert a label such as ``"C7"`` (row letter is case-insensitive)
        to ``(row, col)``.

        Returns ``(-1, -1)`` when the label cannot be parsed.
        """
        try:
            row = ord(label[0].upper()) - 65
            col = int(label[1:])
        except (IndexError, TypeError, ValueError):
            return -1, -1
        return row, col