Spaces:
Running
Running
File size: 3,766 Bytes
620876b c007670 0907a69 620876b 9b0851e 5395b71 46bf00b 0c4536a 218d7fb c007670 5395b71 c2da793 9157c4e 4ee699d 07c5183 d9f1d22 c2da793 0907a69 620876b 0907a69 620876b 9b0851e 009c52e 9612e8f 009c52e 9612e8f 009c52e c007670 009c52e 69db3b2 009c52e 69db3b2 009c52e 69db3b2 009c52e 69db3b2 009c52e 69db3b2 009c52e 5395b71 009c52e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import chess
from datasets import load_dataset
import pickle
import time
from pyroaring import BitMap
import json
def board_to_tokens(board):
    """Describe a board as a list of (piece symbol, square name) tokens.

    Squares are visited in ``chess.SQUARES`` order and empty squares are
    skipped, so every occupied square contributes exactly one tuple of
    ``(piece.symbol(), chess.square_name(square))``.
    """
    tokens = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tokens.append((piece.symbol(), chess.square_name(square)))
    return tokens
def get_puzzle_positions(fen, moves_uci):
    """Replay a puzzle's moves from its starting FEN and snapshot each resulting board.

    Args:
        fen: FEN of the position the move list starts from.
        moves_uci: space-separated UCI moves, applied in order.

    Returns:
        A list with one independent board copy per move: element ``i`` is the
        board after the first ``i + 1`` moves. The starting FEN position itself
        is not included.

    Raises:
        IndexError: if ``moves_uci`` contains no moves.
        Errors raised by python-chess for a bad FEN or move propagate unchanged.
    """
    board = chess.Board(fen)
    uci_moves = moves_uci.split()
    # The first move is pushed unconditionally, so an empty move list fails here.
    board.push_uci(uci_moves[0])
    snapshots = [board.copy()]
    for uci in uci_moves[1:]:
        board.push_uci(uci)
        # Copy, because `board` keeps mutating on every later push.
        snapshots.append(board.copy())
    return snapshots
def load_index(path='puzzles_position_index.pkl'):
    """Load the position index and its metadata from a pickle file.

    The file must hold a mapping with ``'index'`` and ``'metadata'`` keys;
    any other keys are ignored. A missing key raises KeyError.

    NOTE: pickle can execute arbitrary code while loading, so only point this at
    a file this application produced itself, never at user-supplied data.

    Returns:
        ``(index, metadata)`` exactly as stored in the file.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    index = payload['index']
    metadata = payload['metadata']
    return index, metadata
def query_positions(index, metadata, query_tokens):
    """Return ``(pos_id, metadata[pos_id])`` for every position containing all query tokens.

    Args:
        index: mapping from token to a set-like collection of position ids
            (supports ``.copy()`` and in-place ``&=``). It is never mutated.
        metadata: mapping from position id to its metadata entry.
        query_tokens: tokens that must all be present in a matching position.

    Returns:
        A list of ``(pos_id, metadata[pos_id])`` pairs in the iteration order of
        the intersected id collection. Always a list. The result is empty when
        any token is absent from the index, or when ``query_tokens`` is empty
        (an empty query is treated as "no match" rather than "match everything").
    """
    if not query_tokens:
        # Without this guard, query_tokens[0] would raise IndexError.
        return []
    first = query_tokens[0]
    if first not in index:
        return []
    # Copy so the in-place intersections below don't mutate the shared index.
    result = index[first].copy()
    for token in query_tokens[1:]:
        if token not in index:
            return []
        result &= index[token]
        if not result:
            # Further intersections cannot add ids back.
            break
    return [(pos_id, metadata[pos_id]) for pos_id in result]
# Module-level startup: everything below runs once at import time, before any request.
# Puzzle rows are later fetched from `dset` by integer row number in /search.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# `index` maps a (symbol, square) token to a collection of position ids;
# `metadata` maps a position id to (puzzle_row, move_idx). Both come from the pickle file.
index, metadata = load_index()
app = FastAPI()
# Assets under ./static are served at /static; pages are rendered from ./templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/")
def read_root(request: Request):
    """Render the search page from the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/search")
def search(data: dict):
    """Find puzzles that pass through a position containing all pieces of the given FEN.

    Every piece on the requested board becomes a ``(symbol, square)`` token. A
    stored position matches when it contains all of those tokens; it may hold
    extra pieces. Only the first matching position of each puzzle is reported.

    Declared as a plain ``def`` rather than ``async def`` because all of the
    work here (index intersection, dataset row reads, move replay) is blocking.
    FastAPI runs sync endpoints in its threadpool, so a slow search no longer
    stalls the event loop for every other request.

    Args:
        data: JSON body; must contain a ``"fen"`` string.

    Returns:
        ``{"count": int, "results": [...], "time_ms": float}``. Each result
        entry carries the puzzle's fields plus the matched board's FEN and
        move index.

    Raises:
        HTTPException: 400 when ``fen`` is missing, not a string, or not a
            FEN that python-chess can parse.
    """
    from fastapi import HTTPException

    # perf_counter is monotonic and meant for measuring intervals; time.time is not.
    start = time.perf_counter()

    fen = data.get('fen')
    if not isinstance(fen, str):
        raise HTTPException(status_code=400, detail="Request body must include a 'fen' string.")
    try:
        board = chess.Board(fen)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {err}") from err

    query_tokens = board_to_tokens(board)
    # An empty board yields no tokens; treat it as "no matches" instead of
    # letting the intersection code index into an empty list.
    matches = query_positions(index, metadata, query_tokens) if query_tokens else []

    # Keep the first matching move index per puzzle (dicts preserve insertion order).
    first_move_idx = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first_move_idx.setdefault(puzzle_row, move_idx)

    results = [_puzzle_result(puzzle_row, move_idx) for puzzle_row, move_idx in first_move_idx.items()]

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}


def _puzzle_result(puzzle_row, move_idx):
    """Build the response entry for one puzzle, showing the board after ``move_idx``."""
    row = dset[puzzle_row]
    positions = get_puzzle_positions(row['FEN'], row['Moves'])
    matched_board = positions[move_idx]
    return {
        "PuzzleId": row['PuzzleId'],
        "FEN": matched_board.fen(),
        "Moves": row['Moves'],
        "Rating": row['Rating'],
        "Popularity": row['Popularity'],
        "Themes": row['Themes'],
        "MatchedMove": move_idx,
    }
# @app.post("/search")
# async def search(data: dict):
# def generate():
# board = chess.Board(data['fen'])
# query_tokens = board_to_tokens(board)
# matches = query_positions(index, metadata, query_tokens)
# seen_puzzles = set()
# for pos_id, (puzzle_row, move_idx) in matches:
# if puzzle_row in seen_puzzles: continue
# seen_puzzles.add(puzzle_row)
# row = dset[puzzle_row]
# positions = get_puzzle_positions(row['FEN'], row['Moves'])
# matched_board = positions[move_idx]
# result = {"PuzzleId": row['PuzzleId'],
# "FEN": matched_board.fen(),
# "Moves": row['Moves'],
# "Rating": row['Rating'],
# "Popularity": row['Popularity'],
# "Themes": row['Themes'],
# "MatchedMove": move_idx}
# yield json.dumps(result) + "\n"
# return StreamingResponse(generate(), media_type="application/x-ndjson") |