davidmezzetti committed on
Commit
4d67ead
·
verified ·
1 Parent(s): 4e23bb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -13
app.py CHANGED
@@ -38,7 +38,7 @@ class Stats:
38
  self.names = self.loadnames()
39
 
40
  # Build index
41
- self.vectors, self.data, self.embeddings = self.index()
42
 
43
  def loadcolumns(self):
44
  """
@@ -122,12 +122,12 @@ class Stats:
122
  # Build data dictionary
123
  vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
124
  data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
 
125
 
126
  embeddings = Embeddings({"transform": Stats.transform})
127
-
128
  embeddings.index((uid, vectors[uid], None) for uid in vectors)
129
 
130
- return vectors, data, embeddings
131
 
132
  def metrics(self, name):
133
  """
@@ -155,13 +155,14 @@ class Stats:
155
 
156
  return range(1871, datetime.datetime.today().year), 1950, None
157
 
158
- def search(self, name=None, year=None, row=None, limit=10):
159
  """
160
  Runs an embeddings search. This method takes either a player-year or stats row as input.
161
 
162
  Args:
163
  name: player name to search
164
  year: year to search
 
165
  row: row of stats to search
166
  limit: max results to return
167
 
@@ -179,13 +180,17 @@ class Stats:
179
 
180
  results, ids = [], set()
181
  if query is not None:
182
- for uid, _ in self.embeddings.search(query, limit * 5):
 
183
  # Only add unique players
184
  if uid[4:] not in ids:
185
  result = self.data[uid].copy()
186
- result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
187
- results.append(result)
188
- ids.add(uid[4:])
 
 
 
189
 
190
  if len(ids) >= limit:
191
  break
@@ -467,6 +472,9 @@ class Application:
467
  # Player name
468
  name = self.name(stats.names, params.get("name"))
469
 
 
 
 
470
  # Player metrics
471
  active, best, metrics = stats.metrics(name)
472
 
@@ -478,7 +486,7 @@ class Application:
478
  self.chart(category, metrics)
479
 
480
  # Run search
481
- results = stats.search(name, year)
482
 
483
  # Display results
484
  self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
@@ -495,6 +503,10 @@ class Application:
495
  st.markdown("Find players with similar statistics.")
496
 
497
  stats, category = None, self.category("Batting", "searchcategory")
 
 
 
 
498
  with st.form("search"):
499
  if category == "Batting":
500
  stats, columns = self.batting, self.batting.columns[:-6]
@@ -507,7 +519,7 @@ class Application:
507
  submitted = st.form_submit_button("Search")
508
  if submitted:
509
  # Run search
510
- results = stats.search(row=inputs.to_dict(orient="records")[0])
511
 
512
  # Display table
513
  self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
@@ -521,15 +533,16 @@ class Application:
521
  """
522
 
523
  # Get parameters
524
- params = {x: st.query_params.get(x) for x in ["category", "name", "year"]}
525
 
526
  # Sync parameters with session state
527
- if all(x in st.session_state for x in ["category", "name", "year"]):
528
  # Copy session year if category and name are unchanged
529
  params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None
530
 
531
  # Copy category and name from session state
532
  params["category"] = st.session_state["category"]
 
533
  params["name"] = st.session_state["name"]
534
 
535
  return params
@@ -555,6 +568,20 @@ class Application:
555
  # Radio box component
556
  return st.radio("Stat", categories, index=default, horizontal=True, key=key)
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  def name(self, names, name):
559
  """
560
  Builds name input widget.
@@ -670,4 +697,4 @@ if __name__ == "__main__":
670
 
671
  # Create and run application
672
  app = create()
673
- app.run()
 
38
  self.names = self.loadnames()
39
 
40
  # Build index
41
+ self.vectors, self.data, self.maxyear, self.embeddings = self.index()
42
 
43
  def loadcolumns(self):
44
  """
 
122
  # Build data dictionary
123
  vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
124
  data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
125
+ maxyear = max(row["yearID"] for _, row in self.stats.iterrows())
126
 
127
  embeddings = Embeddings({"transform": Stats.transform})
 
128
  embeddings.index((uid, vectors[uid], None) for uid in vectors)
129
 
130
+ return vectors, data, maxyear, embeddings
131
 
132
  def metrics(self, name):
133
  """
 
155
 
156
  return range(1871, datetime.datetime.today().year), 1950, None
157
 
158
+ def search(self, name=None, year=None, window=None, row=None, limit=10):
159
  """
160
  Runs an embeddings search. This method takes either a player-year or stats row as input.
161
 
162
  Args:
163
  name: player name to search
164
  year: year to search
165
+ window: limit to window recent seasons
166
  row: row of stats to search
167
  limit: max results to return
168
 
 
180
 
181
  results, ids = [], set()
182
  if query is not None:
183
+ candidates = limit * 100 if window else limit * 5
184
+ for uid, _ in self.embeddings.search(query, candidates):
185
  # Only add unique players
186
  if uid[4:] not in ids:
187
  result = self.data[uid].copy()
188
+
189
+ # Add first player if this is a player comparison. Limit results to window, if necessary
190
+ if (not ids and not row) or not window or result["yearID"] > self.maxyear - window:
191
+ result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
192
+ results.append(result)
193
+ ids.add(uid[4:])
194
 
195
  if len(ids) >= limit:
196
  break
 
472
  # Player name
473
  name = self.name(stats.names, params.get("name"))
474
 
475
+ # Limit player-year comparisons using this window
476
+ window = self.window(None, "window")
477
+
478
  # Player metrics
479
  active, best, metrics = stats.metrics(name)
480
 
 
486
  self.chart(category, metrics)
487
 
488
  # Run search
489
+ results = stats.search(name, year, window)
490
 
491
  # Display results
492
  self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
 
503
  st.markdown("Find players with similar statistics.")
504
 
505
  stats, category = None, self.category("Batting", "searchcategory")
506
+
507
+ # Limit player-year comparisons using this window
508
+ window = self.window(None, "searchwindow")
509
+
510
  with st.form("search"):
511
  if category == "Batting":
512
  stats, columns = self.batting, self.batting.columns[:-6]
 
519
  submitted = st.form_submit_button("Search")
520
  if submitted:
521
  # Run search
522
+ results = stats.search(window=window, row=inputs.to_dict(orient="records")[0])
523
 
524
  # Display table
525
  self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
 
533
  """
534
 
535
  # Get parameters
536
+ params = {x: st.query_params.get(x) for x in ["category", "name", "window", "year"]}
537
 
538
  # Sync parameters with session state
539
+ if all(x in st.session_state for x in ["category", "name", "window", "year"]):
540
  # Copy session year if category and name are unchanged
541
  params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None
542
 
543
  # Copy category and name from session state
544
  params["category"] = st.session_state["category"]
545
+ params["window"] = st.session_state["window"]
546
  params["name"] = st.session_state["name"]
547
 
548
  return params
 
568
  # Radio box component
569
  return st.radio("Stat", categories, index=default, horizontal=True, key=key)
570
 
571
+ def window(self, window, key):
572
+ """
573
+ Limit results to last N seasons.
574
+
575
+ Args:
576
+ window: limit to window seasons
577
+ key: widget key
578
+
579
+ Returns:
580
+ window component
581
+ """
582
+
583
+ return st.number_input("Limit to last N seasons", value=window, step=1, min_value=1, max_value=100, key=key)
584
+
585
  def name(self, names, name):
586
  """
587
  Builds name input widget.
 
697
 
698
  # Create and run application
699
  app = create()
700
+ app.run()