File size: 3,766 Bytes
620876b
c007670
0907a69
620876b
9b0851e
5395b71
46bf00b
0c4536a
218d7fb
c007670
5395b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2da793
9157c4e
4ee699d
 
 
 
 
 
 
 
 
 
 
07c5183
d9f1d22
c2da793
0907a69
620876b
0907a69
620876b
 
 
9b0851e
009c52e
 
 
 
 
 
9612e8f
009c52e
 
 
 
9612e8f
009c52e
 
 
 
 
c007670
009c52e
 
 
 
 
 
 
 
 
69db3b2
009c52e
 
69db3b2
009c52e
 
 
 
 
 
69db3b2
009c52e
 
 
 
69db3b2
009c52e
 
 
69db3b2
009c52e
 
 
 
 
 
 
 
5395b71
009c52e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import pickle
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pyroaring import BitMap

def board_to_tokens(board):
    """Describe `board` as a list of ``(piece symbol, square name)`` tokens.

    Squares are visited in ``chess.SQUARES`` order and empty squares are
    skipped, so the result has one entry per piece on the board.
    """
    tokens = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tokens.append((piece.symbol(), chess.square_name(square)))
    return tokens
    
def get_puzzle_positions(fen, moves_uci):
    """Replay a puzzle's moves and snapshot the board after each one.

    Args:
        fen: FEN string of the position the moves start from.
        moves_uci: Whitespace-separated UCI moves; must contain at least one.

    Returns:
        A list of board copies where entry ``i`` is the position reached
        after playing move ``i`` (so entry 0 follows the first move).
    """
    board = chess.Board(fen)
    uci_moves = moves_uci.split()

    # The first move is applied unconditionally, so an empty move string
    # fails here with IndexError rather than yielding an empty list.
    board.push_uci(uci_moves[0])
    snapshots = [board.copy()]

    for uci in uci_moves[1:]:
        board.push_uci(uci)
        snapshots.append(board.copy())

    return snapshots

def load_index(path='puzzles_position_index.pkl'):
    """Load the pickled search index stored at `path`.

    The pickle must hold a mapping with the keys ``'index'`` and
    ``'metadata'``; both values are returned as an ``(index, metadata)`` pair.

    NOTE: unpickling executes arbitrary code, so only point this at files
    from a trusted source.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    index_part = payload['index']
    metadata_part = payload['metadata']
    return index_part, metadata_part

def query_positions(index, metadata, query_tokens):
    """Return every indexed position that contains all of `query_tokens`.

    Args:
        index: Mapping from token to a set-like collection of position ids
            (the values must support ``.copy()`` and in-place ``&=``).
        metadata: Lookup from position id to that position's metadata.
        query_tokens: Sequence of tokens that must all be present.

    Returns:
        A list of ``(pos_id, metadata[pos_id])`` pairs. The list is empty
        when `query_tokens` is empty or when any token is missing from
        `index`; the return type is a list in every case.
    """
    # No tokens means there is nothing to intersect; previously this raised
    # IndexError on query_tokens[0].
    if not query_tokens:
        return []

    # One unknown token makes the whole intersection empty, so bail out
    # before copying anything.
    if any(token not in index for token in query_tokens):
        return []

    # Copy the first posting set so the in-place intersections below never
    # mutate the index itself.
    result = index[query_tokens[0]].copy()
    for token in query_tokens[1:]:
        result &= index[token]

    return [(pos_id, metadata[pos_id]) for pos_id in result]

# Module-level setup: each statement below runs once, when the module is imported.
# Puzzle rows come from the "train" split of the Lichess/chess-puzzles dataset.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# Token index and per-position metadata, unpickled from load_index()'s default path.
index, metadata = load_index()
app = FastAPI()
# Files in the "static" directory are served under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are looked up in the "templates" directory.
templates = Jinja2Templates(directory="templates")

@app.get("/")
def read_root(request: Request):
    """Serve the landing page by rendering the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

@app.post("/search")
async def search(data: dict):
    """Find puzzles that contain every piece placement of the submitted position.

    Args:
        data: JSON request body; must contain a ``'fen'`` string.

    Returns:
        A dict with ``count`` (number of puzzles), ``results`` (one entry per
        distinct puzzle, using the first matching position found for it) and
        ``time_ms`` (handler duration in milliseconds).

    Raises:
        HTTPException: 400 when ``'fen'`` is missing, is not a string, or
            cannot be parsed as a FEN.
    """
    # NOTE(review): this coroutine never awaits, so the lookup below runs
    # synchronously inside the event loop -- confirm that is acceptable.
    # perf_counter is monotonic, unlike time.time(), so the reported duration
    # cannot be skewed by system clock adjustments.
    start = time.perf_counter()

    fen = data.get('fen')
    if not isinstance(fen, str):
        raise HTTPException(status_code=400, detail="Request body must contain a 'fen' string")
    try:
        board = chess.Board(fen)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {err}") from err

    query_tokens = board_to_tokens(board)
    # A board without pieces produces no tokens, so there is nothing to
    # intersect; answer with zero matches instead of querying the index.
    matches = query_positions(index, metadata, query_tokens) if query_tokens else []

    # Keep only the first matching position of each puzzle; each metadata
    # entry unpacks to (puzzle_row, move_idx).
    first_match = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first_match.setdefault(puzzle_row, move_idx)

    results = []
    for puzzle_row, move_idx in first_match.items():
        row = dset[puzzle_row]
        positions = get_puzzle_positions(row['FEN'], row['Moves'])
        # Report the FEN of the position that matched, not the puzzle's
        # starting FEN.
        matched_board = positions[move_idx]

        results.append({
            "PuzzleId": row['PuzzleId'],
            "FEN": matched_board.fen(),
            "Moves": row['Moves'],
            "Rating": row['Rating'],
            "Popularity": row['Popularity'],
            "Themes": row['Themes'],
            "MatchedMove": move_idx,
        })

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}

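# Disabled alternative to the /search handler above: streams each match as one
# NDJSON line instead of returning a single JSON document. Kept for reference.
# @app.post("/search")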
# async def search(data: dict):
#     def generate():
#         board = chess.Board(data['fen'])
#         query_tokens = board_to_tokens(board)
#         matches = query_positions(index, metadata, query_tokens)
        
#         seen_puzzles = set()
#         for pos_id, (puzzle_row, move_idx) in matches:
#             if puzzle_row in seen_puzzles: continue
#             seen_puzzles.add(puzzle_row)
            
#             row = dset[puzzle_row]
#             positions = get_puzzle_positions(row['FEN'], row['Moves'])
#             matched_board = positions[move_idx]
            
#             result = {"PuzzleId": row['PuzzleId'],
#                       "FEN": matched_board.fen(),
#                       "Moves": row['Moves'],
#                       "Rating": row['Rating'],
#                       "Popularity": row['Popularity'],
#                       "Themes": row['Themes'],
#                       "MatchedMove": move_idx}
#             yield json.dumps(result) + "\n"
    
#     return StreamingResponse(generate(), media_type="application/x-ndjson")