christopher commited on
Commit
c007670
·
1 Parent(s): 7baac56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -27
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
5
  import chess
@@ -7,6 +7,7 @@ from datasets import load_dataset
7
  import pickle
8
  import time
9
  from pyroaring import BitMap
 
10
 
11
  def board_to_tokens(board):
12
  return [(board.piece_at(sq).symbol(), chess.square_name(sq)) for sq in chess.SQUARES if board.piece_at(sq)]
@@ -49,33 +50,56 @@ def read_root(request: Request):
49
 
50
 
51
 
52
- @app.post("/search")
53
- async def search(data: dict):
54
- start = time.time()
55
- board = chess.Board(data['fen'])
56
- query_tokens = board_to_tokens(board)
57
- matches = query_positions(index, metadata, query_tokens)
58
 
59
- seen_puzzles = {}
60
- for pos_id, (puzzle_row, move_idx) in matches:
61
- if puzzle_row not in seen_puzzles:
62
- seen_puzzles[puzzle_row] = (pos_id, move_idx)
63
 
64
- results = []
65
- for puzzle_row, (pos_id, move_idx) in seen_puzzles.items():
66
- row = dset[puzzle_row]
67
- positions = get_puzzle_positions(row['FEN'], row['Moves'])
68
- matched_board = positions[move_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- results.append({
71
- "PuzzleId": row['PuzzleId'],
72
- "FEN": matched_board.fen(),
73
- "Moves": row['Moves'],
74
- "Rating": row['Rating'],
75
- "Popularity": row['Popularity'],
76
- "Themes": row['Themes'],
77
- "MatchedMove": move_idx
78
- })
 
 
 
79
 
80
- elapsed_ms = (time.time() - start) * 1000
81
- return {"count": len(results), "results": results, "time_ms": elapsed_ms}
 
1
  from fastapi import FastAPI, Request
2
+ from fastapi.responses import HTMLResponse, StreamingResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
5
  import chess
 
7
  import pickle
8
  import time
9
  from pyroaring import BitMap
10
+ import json
11
 
12
  def board_to_tokens(board):
13
  return [(board.piece_at(sq).symbol(), chess.square_name(sq)) for sq in chess.SQUARES if board.piece_at(sq)]
 
50
 
51
 
52
 
53
+ # @app.post("/search")
54
+ # async def search(data: dict):
55
+ # start = time.time()
56
+ # board = chess.Board(data['fen'])
57
+ # query_tokens = board_to_tokens(board)
58
+ # matches = query_positions(index, metadata, query_tokens)
59
 
60
+ # seen_puzzles = {}
61
+ # for pos_id, (puzzle_row, move_idx) in matches:
62
+ # if puzzle_row not in seen_puzzles:
63
+ # seen_puzzles[puzzle_row] = (pos_id, move_idx)
64
 
65
+ # results = []
66
+ # for puzzle_row, (pos_id, move_idx) in seen_puzzles.items():
67
+ # row = dset[puzzle_row]
68
+ # positions = get_puzzle_positions(row['FEN'], row['Moves'])
69
+ # matched_board = positions[move_idx]
70
+
71
+ # results.append({
72
+ # "PuzzleId": row['PuzzleId'],
73
+ # "FEN": matched_board.fen(),
74
+ # "Moves": row['Moves'],
75
+ # "Rating": row['Rating'],
76
+ # "Popularity": row['Popularity'],
77
+ # "Themes": row['Themes'],
78
+ # "MatchedMove": move_idx
79
+ # })
80
+
81
+ # elapsed_ms = (time.time() - start) * 1000
82
+ # return {"count": len(results), "results": results, "time_ms": elapsed_ms}
83
+
84
+
85
+ @app.post("/search")
86
+ async def search(data: dict):
87
+ async def generate():
88
+ board = chess.Board(data['fen'])
89
+ query_tokens = board_to_tokens(board)
90
+ matches = query_positions(index, metadata, query_tokens)
91
 
92
+ seen_puzzles = {}
93
+ for pos_id, (puzzle_row, move_idx) in matches:
94
+ if puzzle_row not in seen_puzzles:
95
+ seen_puzzles[puzzle_row] = (pos_id, move_idx)
96
+
97
+ for puzzle_row, (pos_id, move_idx) in seen_puzzles.items():
98
+ row = dset[puzzle_row]
99
+ positions = get_puzzle_positions(row['FEN'], row['Moves'])
100
+ matched_board = positions[move_idx]
101
+
102
+ result = {"PuzzleId": row['PuzzleId'], "FEN": matched_board.fen(), ...}
103
+ yield json.dumps(result) + "\n"
104
 
105
+ return StreamingResponse(generate(), media_type="application/x-ndjson")