Madras1 commited on
Commit
87ac2de
·
verified ·
1 Parent(s): e2a985b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import os
7
+ import time
8
+
9
+ # --- CONFIGURAÇÕES ---
10
+ BOARD_SIZE = 8
11
+ DEVICE = torch.device("cpu")
12
+ MODEL_PATH = "checkers_master_final.pth" # Certifique-se de que este arquivo está no Space!
13
+
14
+ # --- DEFINIÇÃO DAS CLASSES (Rede Neural e Jogo) ---
15
+ # A Berta copiou a lógica exata do seu script para garantir que funcione igual.
16
+
17
+ class Checkers:
18
+ def get_initial_board(self):
19
+ board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=np.int8)
20
+ for r in range(3):
21
+ for c in range(BOARD_SIZE):
22
+ if (r + c) % 2 == 1: board[r, c] = -1
23
+ for r in range(5, BOARD_SIZE):
24
+ for c in range(BOARD_SIZE):
25
+ if (r + c) % 2 == 1: board[r, c] = 1
26
+ return board
27
+
28
+ def get_valid_moves(self, board, player):
29
+ jumps = self._get_all_jumps(board, player)
30
+ if jumps: return jumps
31
+ moves = []
32
+ for r in range(BOARD_SIZE):
33
+ for c in range(BOARD_SIZE):
34
+ if board[r, c] * player > 0: moves.extend(self._get_simple_moves(board, r, c))
35
+ return moves
36
+
37
+ def _get_simple_moves(self, board, r, c):
38
+ moves = []; piece = board[r, c]; player = np.sign(piece)
39
+ directions = [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
40
+ if abs(piece) == 2: directions.extend([(1, -1), (1, 1)] if player == 1 else [(-1, -1), (-1, 1)])
41
+ for dr, dc in directions:
42
+ nr, nc = r + dr, c + dc
43
+ if 0 <= nr < BOARD_SIZE and 0 <= nc < BOARD_SIZE and board[nr, nc] == 0: moves.append(((r, c), (nr, nc)))
44
+ return moves
45
+
46
+ def _get_all_jumps(self, board, player):
47
+ all_jumps = []
48
+ for r in range(BOARD_SIZE):
49
+ for c in range(BOARD_SIZE):
50
+ if board[r, c] * player > 0:
51
+ jumps = self._find_jump_sequences(np.copy(board), r, c)
52
+ if jumps: all_jumps.extend(jumps)
53
+ if not all_jumps: return []
54
+ max_len = max(len(j) for j in all_jumps)
55
+ return [j for j in all_jumps if len(j) == max_len]
56
+
57
+ def _find_jump_sequences(self, board, r, c, path=[]):
58
+ piece = board[r, c]; player = np.sign(piece)
59
+ if piece == 0: return []
60
+ directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)] if abs(piece) == 2 else \
61
+ [(-1, -1), (-1, 1)] if player == 1 else [(1, -1), (1, 1)]
62
+ found_jumps = []
63
+ for dr, dc in directions:
64
+ mid_r, mid_c = r + dr, c + dc; end_r, end_c = r + 2*dr, c + 2*dc
65
+ if 0 <= end_r < BOARD_SIZE and 0 <= end_c < BOARD_SIZE and \
66
+ board[mid_r, mid_c] * player < 0 and board[end_r, end_c] == 0:
67
+ move = ((r, c), (end_r, end_c))
68
+ new_board = np.copy(board); new_board[end_r, end_c] = piece; new_board[r, c] = 0; new_board[mid_r, mid_c] = 0
69
+ next_jumps = self._find_jump_sequences(new_board, end_r, end_c, path + [move])
70
+ if next_jumps: found_jumps.extend(next_jumps)
71
+ else: found_jumps.append(path + [move])
72
+ return found_jumps
73
+
74
+ def apply_move(self, board, move):
75
+ b_ = np.copy(board)
76
+ is_jump_chain = isinstance(move, list) or (isinstance(move, tuple) and isinstance(move[0], tuple) and isinstance(move[0][0], tuple))
77
+ sub_moves = move if is_jump_chain else [move]
78
+ for (r1, c1), (r2, c2) in sub_moves:
79
+ piece = b_[r1, c1]
80
+ if piece == 0: continue
81
+ b_[r2, c2] = piece; b_[r1, c1] = 0
82
+ if abs(r1 - r2) == 2: b_[(r1+r2)//2, (c1+c2)//2] = 0
83
+ r_final, c_final = sub_moves[-1][1]; p_final = b_[r_final, c_final]
84
+ if p_final == 1 and r_final == 0: b_[r_final, c_final] = 2
85
+ if p_final == -1 and r_final == BOARD_SIZE-1: b_[r_final, c_final] = -2
86
+ return b_
87
+
88
+ def check_game_over(self, board, player):
89
+ if not self.get_valid_moves(board, player): return -1
90
+ if not np.any(np.sign(board) == -player): return 1
91
+ return None
92
+
93
+ def state_to_tensor(board, player):
94
+ tensor = np.zeros((5, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
95
+ tensor[0, board == player] = 1; tensor[1, board == player*2] = 1
96
+ tensor[2, board == -player] = 1; tensor[3, board == -player*2] = 1
97
+ if player == 1: tensor[4,:,:] = 1.0
98
+ return torch.from_numpy(tensor).unsqueeze(0).to(DEVICE)
99
+
100
+ class PolicyValueNetwork(nn.Module):
101
+ def __init__(self):
102
+ super().__init__()
103
+ num_channels = 64
104
+ self.body = nn.Sequential(nn.Conv2d(5, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
105
+ nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU(),
106
+ nn.Conv2d(num_channels, num_channels, 3, padding=1), nn.BatchNorm2d(num_channels), nn.ReLU())
107
+ self.policy_head = nn.Sequential(nn.Conv2d(num_channels, 4, 1), nn.BatchNorm2d(4), nn.ReLU(), nn.Flatten(),
108
+ nn.Linear(4 * BOARD_SIZE * BOARD_SIZE, BOARD_SIZE * BOARD_SIZE))
109
+ self.value_head = nn.Sequential(nn.Conv2d(num_channels, 2, 1), nn.BatchNorm2d(2), nn.ReLU(), nn.Flatten(),
110
+ nn.Linear(2 * BOARD_SIZE * BOARD_SIZE, 64), nn.ReLU(),
111
+ nn.Linear(64, 1), nn.Tanh())
112
+ def forward(self, x):
113
+ x = self.body(x); return self.policy_head(x), self.value_head(x)
114
+
115
+ class MCTSNode:
116
+ def __init__(self, parent=None, prior=0.0):
117
+ self.parent = parent; self.prior = prior; self.children = {}; self.visits = 0; self.value_sum = 0.0
118
+ def get_value(self): return self.value_sum / self.visits if self.visits > 0 else 0.0
119
+
120
+ class MCTS:
121
+ def __init__(self, game, model, sims=100, c_puct=1.5):
122
+ self.game, self.model, self.sims, self.c_puct = game, model, sims, c_puct
123
+ def run(self, board, player):
124
+ root = MCTSNode()
125
+ self._expand_and_evaluate(root, board, player)
126
+ for _ in range(self.sims):
127
+ node, search_board, search_player = root, np.copy(board), player
128
+ search_path = [root]
129
+ while node.children:
130
+ move, node = self._select_child(node)
131
+ search_board = self.game.apply_move(search_board, move); search_player *= -1; search_path.append(node)
132
+ value = self.game.check_game_over(search_board, search_player)
133
+ if value is None and node.visits == 0: value = self._expand_and_evaluate(node, search_board, search_player)
134
+ elif value is None: value = node.get_value()
135
+ for n in reversed(search_path): n.visits += 1; n.value_sum += value; value *= -1
136
+ moves = list(root.children.keys())
137
+ visits = np.array([root.children[m].visits for m in moves])
138
+ return moves, visits / (np.sum(visits) + 1e-9)
139
+ def _select_child(self, node):
140
+ sqrt_total_visits = np.sqrt(node.visits); best_move, max_score = None, -np.inf
141
+ for move, child in node.children.items():
142
+ score = -child.get_value() + self.c_puct * child.prior * sqrt_total_visits / (1 + child.visits)
143
+ if score > max_score: max_score, best_move = score, move
144
+ return best_move, node.children[best_move]
145
+ def _expand_and_evaluate(self, node, board, player):
146
+ valid_moves = self.game.get_valid_moves(board, player)
147
+ if not valid_moves: return -1.0
148
+ with torch.no_grad():
149
+ policy_logits, value_tensor = self.model(state_to_tensor(board, player))
150
+ value = value_tensor.item()
151
+ policy_probs = F.softmax(policy_logits, dim=1).cpu().numpy()[0]
152
+ move_priors = {}; total_prior = 0
153
+ for move in valid_moves:
154
+ if isinstance(move, list): start_pos_tuple = move[0][0]
155
+ else: start_pos_tuple = move[0]
156
+ start_pos_idx = start_pos_tuple[0] * BOARD_SIZE + start_pos_tuple[1]
157
+ prior = policy_probs[start_pos_idx]
158
+ key = tuple(move) if isinstance(move, list) else move
159
+ move_priors[key] = prior; total_prior += prior
160
+ if total_prior > 0:
161
+ for move_key, prior in move_priors.items(): node.children[move_key] = MCTSNode(parent=node, prior=prior / total_prior)
162
+ else:
163
+ for move in valid_moves:
164
+ key = tuple(move) if isinstance(move, list) else move
165
+ node.children[key] = MCTSNode(parent=node, prior=1.0 / len(valid_moves))
166
+ return value
167
+
168
+ # --- INTERFACE DO STREAMLIT ---
169
+
170
+ st.set_page_config(page_title="AlphaCheckerZero", page_icon="♟️")
171
+
172
+ st.title("♟️ AlphaCheckerZero Arena")
173
+ st.write("Gabriel Yogi's Neural Network AI")
174
+
175
+ # 1. Carregar o Modelo (com Cache para ser rápido)
176
+ @st.cache_resource
177
+ def load_model():
178
+ if not os.path.exists(MODEL_PATH):
179
+ return None
180
+ model = PolicyValueNetwork().to(DEVICE)
181
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
182
+ model.eval()
183
+ return model
184
+
185
+ model = load_model()
186
+
187
+ if model is None:
188
+ st.error(f"Arquivo '{MODEL_PATH}' não encontrado. Por favor, faça upload do arquivo .pth para o Space.")
189
+ st.stop()
190
+
191
+ # 2. Inicializar o Estado do Jogo
192
+ if "board" not in st.session_state:
193
+ game = Checkers()
194
+ st.session_state.board = game.get_initial_board()
195
+ st.session_state.player = 1 # Humano começa (1)
196
+ st.session_state.game_over = False
197
+ st.session_state.message = "Sua vez! Você joga com as Brancas (x)."
198
+
199
+ game = Checkers()
200
+ mcts = MCTS(game, model, sims=150) # Sims ajustado para performance na web
201
+
202
+ # Função para desenhar o tabuleiro como texto (simples e funcional)
203
+ def render_board(board):
204
+ chars = {1: 'x', 2: 'X', -1: 'o', -2: 'O', 0: '.'}
205
+ board_str = " 0 1 2 3 4 5 6 7\n"
206
+ board_str += " -----------------\n"
207
+ for r_idx, row in enumerate(board):
208
+ board_str += f"{r_idx} | {' '.join(chars[val] for val in row)} |\n"
209
+ board_str += " -----------------"
210
+ return board_str
211
+
212
+ # Layout principal
213
+ col1, col2 = st.columns([2, 1])
214
+
215
+ with col1:
216
+ st.text_area("Tabuleiro", render_board(st.session_state.board), height=250, disabled=True, key="board_display")
217
+
218
+ with col2:
219
+ st.write("### Status")
220
+ st.info(st.session_state.message)
221
+
222
+ if st.button("Reiniciar Jogo"):
223
+ st.session_state.board = game.get_initial_board()
224
+ st.session_state.player = 1
225
+ st.session_state.game_over = False
226
+ st.session_state.message = "Jogo reiniciado. Sua vez!"
227
+ st.rerun()
228
+
229
+ # Lógica do Jogo
230
+ if not st.session_state.game_over:
231
+ # Verificar fim de jogo antes de qualquer coisa
232
+ result = game.check_game_over(st.session_state.board, st.session_state.player)
233
+ if result is not None:
234
+ st.session_state.game_over = True
235
+ if result == 1: st.session_state.message = "VOCÊ GANHOU! Parabéns Gabriel!"
236
+ elif result == -1: st.session_state.message = "A IA GANHOU. Mais sorte na próxima."
237
+ else: st.session_state.message = "EMPATE."
238
+ st.rerun()
239
+
240
+ # VEZ DO HUMANO (Player 1)
241
+ if st.session_state.player == 1:
242
+ valid_moves = game.get_valid_moves(st.session_state.board, 1)
243
+
244
+ if not valid_moves:
245
+ # Se não tem movimentos e não deu game over acima, algo estranho aconteceu, mas tratamos como derrota
246
+ st.session_state.game_over = True
247
+ st.session_state.message = "Sem movimentos válidos. Você perdeu."
248
+ st.rerun()
249
+
250
+ # Criar lista de strings para o Selectbox
251
+ move_options = [str(m) for m in valid_moves]
252
+ selected_move_str = st.selectbox("Escolha sua jogada:", move_options)
253
+
254
+ if st.button("Jogar"):
255
+ # Encontrar o movimento original baseado na string
256
+ move_idx = move_options.index(selected_move_str)
257
+ move = valid_moves[move_idx]
258
+
259
+ # Aplicar movimento
260
+ st.session_state.board = game.apply_move(st.session_state.board, move)
261
+ st.session_state.player = -1 # Passa a vez para IA
262
+ st.session_state.message = "A IA está pensando..."
263
+ st.rerun()
264
+
265
+ # VEZ DA IA (Player -1)
266
+ else:
267
+ with st.spinner("A AlphaCheckerZero está pensando..."):
268
+ # Pequeno delay para a interface atualizar e mostrar a mensagem
269
+ time.sleep(0.5)
270
+
271
+ valid_moves, policy = mcts.run(np.copy(st.session_state.board), -1)
272
+
273
+ if not valid_moves:
274
+ st.session_state.game_over = True
275
+ st.session_state.message = "A IA não tem movimentos. Você venceu!"
276
+ st.rerun()
277
+
278
+ move = valid_moves[np.argmax(policy)]
279
+
280
+ st.session_state.board = game.apply_move(st.session_state.board, move)
281
+ st.session_state.player = 1 # Devolve a vez para Humano
282
+ st.session_state.message = f"IA jogou: {move}. Sua vez!"
283
+ st.rerun()
284
+
285
+ else:
286
+ st.success(st.session_state.message)