Spaces:
Running
Running
File size: 1,747 Bytes
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chess_search.sparse.bitmap import BitmapIndex
from chess_search.utils import position_to_tokens, replay_moves
# Lichess puzzle dataset (train split); individual rows are fetched by index.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# Map PuzzleId -> row index in `dset`. Reading only the "PuzzleId" column
# avoids decoding every column of every row, which is what iterating the
# dataset row-by-row does. On duplicate ids the last row wins, as before.
puzzle_lookup = {puzzle_id: i for i, puzzle_id in enumerate(dset["PuzzleId"])}
# Prebuilt bitmap index over puzzle positions, loaded from "index"
# (presumably resolved against the current working directory — confirm).
idx = BitmapIndex.load("index")

app = FastAPI()
# Static assets are served from ./static; HTML pages render from ./templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/")
def read_root(request: Request):
    """Serve the landing page by rendering the ``index.html`` template."""
    page_name = "index.html"
    return templates.TemplateResponse(request=request, name=page_name)
@app.post("/search")
async def search(data: dict):
    """Return the puzzles whose move sequence reaches the position in ``data["fen"]``.

    The FEN is tokenized with ``position_to_tokens`` and queried against the
    bitmap index; the hits are resolved to ``(puzzle_id, move_idx)`` pairs.
    Only the first pair reported for each puzzle is kept. That puzzle's moves
    are then replayed with ``replay_moves`` and the board at index
    ``move_idx`` of the replayed sequence is returned as the matched FEN.

    Args:
        data: JSON request body; must contain a string ``"fen"``.

    Returns:
        A dict with ``count`` (number of puzzles), ``results`` (one dict per
        puzzle: id, matched FEN, moves, rating, popularity, themes and the
        matched move index) and ``time_ms`` (handling time in milliseconds).

    Raises:
        HTTPException: status 422 if ``"fen"`` is missing, not a string, or
            not a valid FEN.
    """
    # NOTE(review): this handler is `async` but does only synchronous work
    # (index query, dataset reads, move replay), which blocks the event loop
    # while it runs. Declaring it `def` would make FastAPI run it in a
    # threadpool — confirm BitmapIndex is safe for concurrent use first.

    # perf_counter is monotonic, so the reported duration is not affected by
    # system clock adjustments the way time.time() is.
    start = time.perf_counter()

    fen = data.get("fen")
    # chess.Board(None) silently builds an *empty* board, and other
    # non-string values fail inside python-chess; reject them up front so the
    # client gets a 4xx instead of a bogus search or a 500.
    if not isinstance(fen, str):
        raise HTTPException(status_code=422, detail="'fen' must be a FEN string")
    try:
        board = chess.Board(fen)
    except ValueError as err:
        # A malformed FEN is a client error, not a server error.
        raise HTTPException(status_code=422, detail=f"Invalid FEN: {err}") from err

    tokens = position_to_tokens(board)
    matches = idx.query(tokens)
    resolved = idx.resolve(matches)

    # Keep the first move index reported for each puzzle; dict insertion
    # order preserves the order in which puzzles were first resolved.
    first_hit = {}
    for puzzle_id, move_idx in resolved:
        first_hit.setdefault(puzzle_id, move_idx)

    results = []
    for puzzle_id, move_idx in first_hit.items():
        row = dset[puzzle_lookup[puzzle_id]]
        boards = replay_moves(row["FEN"], row["Moves"].split())
        matched_board = boards[move_idx]
        results.append({
            "PuzzleId": row["PuzzleId"],
            "FEN": matched_board.fen(),
            "Moves": row["Moves"],
            "Rating": row["Rating"],
            "Popularity": row["Popularity"],
            "Themes": row["Themes"],
            "MatchedMove": move_idx,
        })

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}