christopher commited on
Commit
69db3b2
·
1 Parent(s): 9612e8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -26
app.py CHANGED
@@ -48,33 +48,60 @@ templates = Jinja2Templates(directory="templates")
48
  def read_root(request: Request):
49
  return templates.TemplateResponse("index.html", {"request": request})
50
 
51
- @app.post("/search")
52
- async def search(data: dict):
53
- start = time.time()
54
- board = chess.Board(data['fen'])
55
- query_tokens = board_to_tokens(board)
56
- matches = query_positions(index, metadata, query_tokens)
57
 
58
- seen_puzzles = {}
59
- for pos_id, (puzzle_row, move_idx) in matches:
60
- if puzzle_row not in seen_puzzles:
61
- seen_puzzles[puzzle_row] = (pos_id, move_idx)
62
 
63
- results = []
64
- for puzzle_row, (pos_id, move_idx) in seen_puzzles.items():
65
- row = dset[puzzle_row]
66
- positions = get_puzzle_positions(row['FEN'], row['Moves'])
67
- matched_board = positions[move_idx]
68
 
69
- results.append({
70
- "PuzzleId": row['PuzzleId'],
71
- "FEN": matched_board.fen(),
72
- "Moves": row['Moves'],
73
- "Rating": row['Rating'],
74
- "Popularity": row['Popularity'],
75
- "Themes": row['Themes'],
76
- "MatchedMove": move_idx
77
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- elapsed_ms = (time.time() - start) * 1000
80
- return {"count": len(results), "results": results, "time_ms": elapsed_ms}
 
48
  def read_root(request: Request):
49
  return templates.TemplateResponse("index.html", {"request": request})
50
 
51
+ # @app.post("/search")
52
+ # async def search(data: dict):
53
+ # start = time.time()
54
+ # board = chess.Board(data['fen'])
55
+ # query_tokens = board_to_tokens(board)
56
+ # matches = query_positions(index, metadata, query_tokens)
57
 
58
+ # seen_puzzles = {}
59
+ # for pos_id, (puzzle_row, move_idx) in matches:
60
+ # if puzzle_row not in seen_puzzles:
61
+ # seen_puzzles[puzzle_row] = (pos_id, move_idx)
62
 
63
+ # results = []
64
+ # for puzzle_row, (pos_id, move_idx) in seen_puzzles.items():
65
+ # row = dset[puzzle_row]
66
+ # positions = get_puzzle_positions(row['FEN'], row['Moves'])
67
+ # matched_board = positions[move_idx]
68
 
69
+ # results.append({
70
+ # "PuzzleId": row['PuzzleId'],
71
+ # "FEN": matched_board.fen(),
72
+ # "Moves": row['Moves'],
73
+ # "Rating": row['Rating'],
74
+ # "Popularity": row['Popularity'],
75
+ # "Themes": row['Themes'],
76
+ # "MatchedMove": move_idx
77
+ # })
78
+
79
+ # elapsed_ms = (time.time() - start) * 1000
80
+ # return {"count": len(results), "results": results, "time_ms": elapsed_ms}
81
+
82
+ @app.post("/search")
83
+ async def search(data: dict):
84
+ def generate():
85
+ board = chess.Board(data['fen'])
86
+ query_tokens = board_to_tokens(board)
87
+ matches = query_positions(index, metadata, query_tokens)
88
+
89
+ seen_puzzles = set()
90
+ for pos_id, (puzzle_row, move_idx) in matches:
91
+ if puzzle_row in seen_puzzles: continue
92
+ seen_puzzles.add(puzzle_row)
93
+
94
+ row = dset[puzzle_row]
95
+ positions = get_puzzle_positions(row['FEN'], row['Moves'])
96
+ matched_board = positions[move_idx]
97
+
98
+ result = {"PuzzleId": row['PuzzleId'],
99
+ "FEN": matched_board.fen(),
100
+ "Moves": row['Moves'],
101
+ "Rating": row['Rating'],
102
+ "Popularity": row['Popularity'],
103
+ "Themes": row['Themes'],
104
+ "MatchedMove": move_idx}
105
+ yield json.dumps(result) + "\n"
106
 
107
+ return StreamingResponse(generate(), media_type="application/x-ndjson")