File size: 1,747 Bytes
620876b
0907a69
620876b
9b0851e
5395b71
0c4536a
3d973f8
 
4ee699d
07c5183
3d973f8
dbb3e77
3d973f8
c2da793
0907a69
620876b
0907a69
620876b
 
5376fe6
9b0851e
009c52e
 
 
3d973f8
 
 
 
 
009c52e
3d973f8
 
 
 
009c52e
3d973f8
 
 
 
009c52e
3d973f8
009c52e
3d973f8
 
 
 
 
009c52e
69db3b2
3d973f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chess_search.sparse.bitmap import BitmapIndex
from chess_search.utils import position_to_tokens, replay_moves

# Lichess puzzle dataset (train split), loaded once at import time.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# Map PuzzleId -> row position in ``dset`` so search results can fetch rows by id.
# Reading just the PuzzleId column avoids materializing every row (all columns)
# as a Python dict, which iterating the Dataset row by row would do.
puzzle_lookup = {puzzle_id: i for i, puzzle_id in enumerate(dset["PuzzleId"])}
# Prebuilt bitmap index loaded from the "index" path.
# NOTE(review): assumed to be built from this same dataset so every puzzle id it
# resolves is present in ``puzzle_lookup`` — TODO confirm.
idx = BitmapIndex.load("index")

# ASGI application: static assets are served from ./static under the /static
# URL prefix, and HTML pages are rendered from Jinja2 templates in ./templates.
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

@app.get("/")
def read_root(request: Request):
    """Render the ``index.html`` template as the response for the root URL."""
    page = "index.html"
    return templates.TemplateResponse(name=page, request=request)

@app.post("/search")
async def search(data: dict):
    """Find puzzles whose move sequence passes through the position ``data["fen"]``.

    The FEN is tokenized with ``position_to_tokens`` and queried against the
    bitmap index. Each hit is resolved to a ``(puzzle_id, move_idx)`` pair. Only
    the first pair reported for each puzzle is kept. For that pair, the puzzle's
    moves are replayed from its stored FEN so the response can include the board
    at the matched move.

    Args:
        data: Request body; must contain a ``"fen"`` key holding a FEN string.

    Returns:
        A dict with ``count`` (number of puzzles returned), ``results`` (one
        dict per puzzle with its id, matched-board FEN, moves, rating,
        popularity, themes and ``MatchedMove``), and ``time_ms`` (elapsed
        handling time in milliseconds).

    Raises:
        HTTPException: 400 if ``"fen"`` is missing, is not a string, or is not
            a valid FEN.
    """
    # perf_counter is monotonic, so the measured duration cannot jump or go
    # negative if the system clock is adjusted mid-request (time.time() can).
    start = time.perf_counter()

    # Reject bad client input with a 400 instead of letting KeyError/ValueError
    # surface as a 500 Internal Server Error.
    fen = data.get("fen")
    if not isinstance(fen, str):
        raise HTTPException(
            status_code=400, detail="Request body must include a 'fen' string."
        )
    try:
        board = chess.Board(fen)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {err}") from err

    # NOTE(review): the index lookup and move replay below are synchronous and
    # run directly on the event loop because this is an ``async def``; a plain
    # ``def`` would let FastAPI run it in a worker thread instead.
    tokens = position_to_tokens(board)
    matches = idx.query(tokens)
    resolved = idx.resolve(matches)

    # Keep only the first (puzzle_id, move_idx) pair per puzzle; dicts preserve
    # insertion order, so results follow the order the resolver produced.
    first_match = {}
    for puzzle_id, move_idx in resolved:
        first_match.setdefault(puzzle_id, move_idx)

    results = []
    for puzzle_id, move_idx in first_match.items():
        row = dset[puzzle_lookup[puzzle_id]]
        # NOTE(review): assumes replay_moves returns one board per position,
        # starting from the puzzle FEN, so move_idx indexes into it — TODO confirm.
        boards = replay_moves(row["FEN"], row["Moves"].split())
        matched_board = boards[move_idx]
        results.append({
            "PuzzleId": row["PuzzleId"],
            "FEN": matched_board.fen(),
            "Moves": row["Moves"],
            "Rating": row["Rating"],
            "Popularity": row["Popularity"],
            "Themes": row["Themes"],
            "MatchedMove": move_idx,
        })

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}