Spaces:
Running
Running
Commit
·
5395b71
1
Parent(s):
50d0d7d
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,27 @@ from fastapi.responses import HTMLResponse
|
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
| 4 |
from fastapi.templating import Jinja2Templates
|
| 5 |
import chess
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
app = FastAPI()
|
| 8 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
@@ -12,19 +33,29 @@ templates = Jinja2Templates(directory="templates")
|
|
| 12 |
def read_root(request: Request):
|
| 13 |
return templates.TemplateResponse("index.html", {"request": request})
|
| 14 |
|
|
|
|
| 15 |
@app.post("/search")
|
| 16 |
async def search(data: dict):
|
| 17 |
board = chess.Board(data['fen'])
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
| 4 |
from fastapi.templating import Jinja2Templates
|
| 5 |
import chess
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
|
| 8 |
+
dset = load_dataset("Lichess/chess-puzzles", split="train")
|
| 9 |
+
index, metadata = load_index('chess_index.pkl')
|
| 10 |
+
|
| 11 |
+
def board_to_tokens(board):
|
| 12 |
+
return [(board.piece_at(sq).symbol(), chess.square_name(sq)) for sq in chess.SQUARES if board.piece_at(sq)]
|
| 13 |
+
|
| 14 |
+
def get_puzzle_positions(fen, moves_uci):
|
| 15 |
+
positions = []
|
| 16 |
+
|
| 17 |
+
board = chess.Board(fen)
|
| 18 |
+
|
| 19 |
+
board.push_uci(moves_uci.split()[0])
|
| 20 |
+
positions.append(board.copy())
|
| 21 |
+
|
| 22 |
+
for move_uci in moves_uci.split()[1:]:
|
| 23 |
+
board.push_uci(move_uci)
|
| 24 |
+
positions.append(board.copy())
|
| 25 |
+
|
| 26 |
+
return positions
|
| 27 |
|
| 28 |
app = FastAPI()
|
| 29 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
| 33 |
def read_root(request: Request):
|
| 34 |
return templates.TemplateResponse("index.html", {"request": request})
|
| 35 |
|
| 36 |
+
|
| 37 |
@app.post("/search")
|
| 38 |
async def search(data: dict):
|
| 39 |
board = chess.Board(data['fen'])
|
| 40 |
+
query_tokens = board_to_tokens(board)
|
| 41 |
+
matches = query_positions(index, metadata, query_tokens)
|
| 42 |
+
|
| 43 |
+
results = []
|
| 44 |
+
for pos_id, (puzzle_row, move_idx) in matches[:100]:
|
| 45 |
+
row = dset[puzzle_row]
|
| 46 |
+
positions = get_puzzle_positions(row['FEN'], row['Moves'])
|
| 47 |
+
matched_board = positions[move_idx]
|
| 48 |
+
|
| 49 |
+
results.append({
|
| 50 |
+
"PuzzleId": row['PuzzleId'],
|
| 51 |
+
"FEN": matched_board.fen(),
|
| 52 |
+
"Moves": row['Moves'],
|
| 53 |
+
"Rating": row['Rating'],
|
| 54 |
+
"Popularity": row['Popularity'],
|
| 55 |
+
"Themes": row['Themes'].split(),
|
| 56 |
+
"PuzzleUrl": f"https://lichess.org/training/{row['PuzzleId']}",
|
| 57 |
+
"GameUrl": row['GameUrl'],
|
| 58 |
+
"MatchedMove": move_idx
|
| 59 |
+
})
|
| 60 |
+
|
| 61 |
+
return {"count": len(matches), "results": results}
|