| | import os |
| | import io |
| | import torch |
| | import requests |
| | import chess.pgn |
| | import numpy as np |
| | from data_objects.game import Game |
| | from encoder.model import Encoder |
| |
|
| | |
def generate_alternative_pgns(game):
    """Build one continuation PGN per legal move in *game*'s final position.

    Each continuation copies *game*'s headers (with ``Result`` reset to
    ``"*"`` because the game is being extended) and its mainline moves, then
    appends a single legal move of the final position.

    NOTE(review): only the moves are copied, not the mainline comments, so
    any clock annotations stored in *game*'s comments are absent from the
    generated PGNs.

    Args:
        game: a parsed ``chess.pgn.Game``, or ``None`` if parsing failed.

    Returns:
        ``(result_pgns, move_sans)`` -- parallel lists holding the exported
        PGN text of each continuation and the SAN of the move it appends.
        Both lists are empty when *game* is falsy or the final position has
        no legal moves.
    """
    if not game:
        print("couldn't read game")
        # Same 2-tuple shape as the normal return, so callers can always
        # unpack ``pgns, sans = generate_alternative_pgns(game)``.
        return [], []

    # Replay the mainline to reach the position we branch from.
    board = game.board()
    moves = list(game.mainline_moves())
    for move in moves:
        board.push(move)

    legal_moves = list(board.legal_moves)

    result_pgns = []
    move_sans = []
    for legal_move in legal_moves:
        new_game = chess.pgn.Game()
        new_game.headers.update(game.headers)
        # The continuation is unfinished, whatever the source game's result.
        new_game.headers["Result"] = "*"

        node = new_game
        for move in moves:
            node = node.add_variation(move)
        node.add_variation(legal_move)

        pgn_buffer = io.StringIO()
        new_game.accept(chess.pgn.FileExporter(pgn_buffer))

        result_pgns.append(pgn_buffer.getvalue())
        move_sans.append(board.san(legal_move))

    return result_pgns, move_sans
| |
|
| | def process_game(game, prediction_mode = False): |
| | def create_position_planes(board: chess.Board, positions_seen: set, cur_player: chess.Color) -> np.ndarray: |
| |
|
| | def bb_to_plane(bb: int, player: chess.Color) -> np.ndarray: |
| | binary = format(bb, '064b') |
| | h_flipped = np.fliplr(np.array([int(binary[i]) for i in range(64)], dtype=np.float32).reshape(8, 8)) |
| | if player: |
| | return h_flipped |
| | else: |
| | return np.flip(h_flipped) |
| | |
| | planes = np.zeros((13, 8, 8), dtype=np.float32) |
| | |
| | piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING] |
| | |
| | |
| | for i, piece_type in enumerate(piece_types): |
| | bb = board.pieces_mask(piece_type, chess.WHITE) |
| | planes[i] = bb_to_plane(bb, cur_player) |
| | |
| | |
| | for i, piece_type in enumerate(piece_types): |
| | bb = board.pieces_mask(piece_type, chess.BLACK) |
| | planes[i + 6] = bb_to_plane(bb, cur_player) |
| | |
| | |
| | current_position = board.fen().split(' ')[0] |
| | if list(positions_seen).count(current_position) > 1: |
| | planes[12] = 1.0 |
| | |
| | return planes |
| |
|
| | board = chess.Board() |
| | positions_seen = set() |
| | positions_seen.add(board.fen().split(' ')[0]) |
| | |
| | white_moves = [] |
| | black_moves = [] |
| | |
| | node = game |
| | while node.next(): |
| | node = node.next() |
| | move = node.move |
| | assert(move is not None) |
| | cur_player = board.turn |
| |
|
| | current_planes = create_position_planes(board, positions_seen, cur_player) |
| | |
| | board.push(move) |
| | |
| | positions_seen.add(board.fen().split(' ')[0]) |
| | |
| | next_planes = create_position_planes(board, positions_seen, cur_player) |
| | assert(not (current_planes==next_planes).all()) |
| | |
| | |
| | move_planes = np.zeros((34, 8, 8), dtype=np.float32) |
| | |
| | |
| | move_planes[0:13] = current_planes |
| | |
| | |
| | move_planes[13:26] = next_planes |
| | |
| | |
| | move_planes[26] = float(board.has_queenside_castling_rights(chess.WHITE)) |
| | move_planes[27] = float(board.has_kingside_castling_rights(chess.WHITE)) |
| | move_planes[28] = float(board.has_queenside_castling_rights(chess.BLACK)) |
| | move_planes[29] = float(board.has_kingside_castling_rights(chess.BLACK)) |
| | |
| | |
| | move_planes[30] = 1 if board.turn is chess.WHITE else 0 |
| | |
| | |
| | move_planes[31] = board.halfmove_clock / 100.0 |
| | |
| | |
| | |
| | clock_info = node.comment.strip('{}[] ').split()[1] if node.comment else "0:00:30" |
| | try: |
| | minutes, seconds = map(int, clock_info.split(':')[1:]) |
| | total_seconds = minutes * 60 + seconds |
| | move_planes[32] = min(1.0, total_seconds / 180.0) |
| | except: |
| | move_planes[32] = 0.5 |
| | |
| | |
| | move_planes[33] = 1.0 |
| | |
| | if board.turn: |
| | black_moves.append(move_planes) |
| | else: |
| | white_moves.append(move_planes) |
| | |
| | if (not prediction_mode) and (len(white_moves) < 10 or len(black_moves) < 10): |
| | return None |
| | |
| | white_array = np.stack(white_moves, axis=0) |
| | black_array = [] if not black_moves else np.stack(black_moves, axis=0) |
| | |
| | return white_array, black_array |
| |
|
| |
|
class EndpointHandler():
    """Inference endpoint around the style ``Encoder`` model.

    ``__call__`` dispatches on ``data["endpoint_num"]``:
    0 -> say_hi, 1 -> create_user_embedding, 2 -> ai_move.
    """

    # Stockfish evaluation service consulted by ai_move.
    STOCKFISH_EVAL_URL = "http://16.16.211.183/stockfish_eval"
    # Seconds to wait for the service. requests has no default timeout, so a
    # stalled service would otherwise block the request forever.
    STOCKFISH_TIMEOUT = 10
    # Largest eval loss versus Stockfish's best move that ai_move accepts, in
    # the service's "value" units (presumably centipawns from White's point
    # of view -- TODO confirm).
    EVAL_TOLERANCE = 120

    def __init__(self, model_dir):
        """Load the ``6_3.pt`` checkpoint from *model_dir* and set up dispatch."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(os.path.join(model_dir, "6_3.pt"), self.device, weights_only=True)
        self.model = Encoder(self.device)
        self.model.load_state_dict(checkpoint['model_state'])
        self.model = self.model.to(self.device)
        self.model.eval()
        self.d = {
            0: self.say_hi,
            1: self.create_user_embedding,
            2: self.ai_move,
        }

    def say_hi(self, _data):
        """Health-check endpoint; ignores its input."""
        print('entering test endpoint')
        print('exiting test endpoint')
        return {"reply": "hello from inference api!!"}

    def create_user_embedding(self, data):
        """Compute a unit-length style embedding for a player from their games.

        Expected keys in *data*: ``username``, ``pgn_content`` (PGN text
        holding one or more games) and ``games_per_player`` (maximum number
        of encoded games to use).

        Games are read in order. For each game, the side *username* played is
        encoded with process_game, and games that fail to encode are skipped.
        The first ``games_per_player`` encoded games are passed through
        ``Game(...).random_partial()``, embedded, averaged and L2-normalised.

        Returns:
            ``{"reply": [float, ...]}``, or ``None`` if no game could be encoded.

        Raises:
            ValueError: a game's White/Black headers do not match *username*.
        """
        print('entering create_username endpoint')
        username = data["username"]
        pgn_content = data["pgn_content"]
        games_per_player = data["games_per_player"]

        # One stream for all games: read_game() consumes one game per call.
        # Re-wrapping pgn_content on each iteration re-read the first game
        # forever.
        pgn_stream = io.StringIO(pgn_content)
        player_arrays = []
        # Only the first games_per_player encoded games are used, so stop
        # reading once we have that many.
        while len(player_arrays) < games_per_player:
            game = chess.pgn.read_game(pgn_stream)
            if game is None:
                print("breaking main loop")
                break
            white = game.headers.get("White")
            black = game.headers.get("Black")
            if white == username:
                side = 0
            elif black == username:
                side = 1
            else:
                raise ValueError(
                    f"{username!r} did not play in game {white!r} vs {black!r}"
                )
            try:
                arrs = process_game(game)
            except Exception as exc:  # best effort: skip games that fail to encode
                print("skipped", exc)
                continue
            if arrs is None:
                print("skipped")
                continue
            # process_game returns (white_array, black_array).
            player_arrays.append(arrs[side])
        if not player_arrays:
            return None

        inputs = np.array([Game(g).random_partial() for g in player_arrays])
        num_games = len(player_arrays)

        tensor = torch.tensor(inputs).float().to(self.device)
        with torch.no_grad():
            embeds = self.model(tensor).view((1, num_games, -1))
            centroid = torch.mean(embeds, dim=1, keepdim=True)
            centroid = centroid / torch.norm(centroid, dim=2, keepdim=True)
        final_embeds = centroid.cpu().squeeze(1)[0].numpy().tolist()

        print('exiting create_username endpoint')
        return {"reply": final_embeds}

    def _stockfish_eval(self, fen):
        """POST *fen* to the Stockfish service and return the raw response."""
        return requests.post(
            self.STOCKFISH_EVAL_URL, json={"fen": fen}, timeout=self.STOCKFISH_TIMEOUT
        )

    def ai_move(self, data):
        """Pick a move for *color* that resembles the player's style.

        Expected keys in *data*: ``pgn_string`` (the game so far), ``color``
        (``"white"``; anything else is treated as black) and
        ``player_centroid`` (an embedding as returned by
        create_user_embedding).

        Every legal continuation is encoded and embedded. The candidates are
        ranked by dot product of their per-row normalised embedding with the
        centroid (cosine similarity when the centroid is unit length). The
        Stockfish service then checks the candidates in rank order. The first
        one losing less than EVAL_TOLERANCE versus Stockfish's best move wins.

        Returns:
            ``{"reply": san}``. Fallbacks:
            - top style move if the initial service call fails or any
              exception occurs;
            - Stockfish's best move if a later call fails or no candidate is
              within tolerance.

        Raises:
            ValueError: *pgn_string* cannot be parsed, or the final position
                has no legal moves.
        """
        print('entering ai_move endpoint')
        pgn_string = data["pgn_string"]
        color = data["color"]
        player_centroid = np.asarray(data["player_centroid"])

        game = chess.pgn.read_game(io.StringIO(pgn_string))
        if game is None:
            raise ValueError("could not parse pgn_string")
        alternative_pgns, move_sans = generate_alternative_pgns(game)
        if not alternative_pgns:
            raise ValueError("no legal moves in the final position")

        # process_game returns (white_array, black_array); pick the AI's side.
        side = 0 if color == "white" else 1
        inputs = [
            process_game(chess.pgn.read_game(io.StringIO(alt_pgn)), True)[side]
            for alt_pgn in alternative_pgns
        ]

        tensor = torch.tensor(np.array(inputs)).float().to(self.device)
        with torch.no_grad():
            embeds = self.model(tensor)
            # Normalise each candidate separately. Dividing the whole batch
            # by a single norm (as before) is not a per-candidate cosine.
            embeds = torch.nn.functional.normalize(embeds, dim=1)

        similarities = embeds.cpu().numpy() @ player_centroid
        result = move_sans[int(np.argmax(similarities))]
        # Candidate indices, most style-similar first.
        ordered_moves = np.argsort(similarities)[::-1].tolist()

        try:
            board = game.board()
            for move in game.mainline_moves():
                board.push(move)

            response = self._stockfish_eval(board.fen())
            if not response.ok:
                print(response.text)
                print('exiting ai_move endpoint status code before move')
                return {"reply": result}
            payload = response.json()
            best_eval = payload["value"]
            best_move = board.san(chess.Move.from_uci(payload["best"]))

            for idx in ordered_moves:
                test_board = board.copy()
                test_board.push(board.parse_san(move_sans[idx]))
                response = self._stockfish_eval(test_board.fen())
                if not response.ok:
                    print('exiting ai_move endpoint status code after move')
                    return {"reply": best_move}
                move_eval = response.json()["value"]
                # Eval given up versus the best move, from the AI's side
                # (evals treated as White-relative, as before).
                if color == "white":
                    loss = best_eval - move_eval
                else:
                    loss = move_eval - best_eval
                if loss < self.EVAL_TOLERANCE:
                    print('exiting ai_move endpoint nice found!')
                    return {"reply": move_sans[idx]}
            print('exiting ai_move endpoint all moves are shit!')
            return {"reply": best_move}

        except Exception as e:
            # Best effort: any failure reaching or parsing the service falls
            # back to the pure style pick.
            print('error sending to lichess', e)
            print('exiting ai_move endpoint due to exception')
            return {"reply": result}

    def __call__(self, data):
        """Dispatch *data* (optionally wrapped in ``"inputs"``) by ``endpoint_num``."""
        data = data.get("inputs", data)
        return self.d[data["endpoint_num"]](data)