Update app.py
Browse files
app.py
CHANGED
|
@@ -38,7 +38,7 @@ class Stats:
|
|
| 38 |
self.names = self.loadnames()
|
| 39 |
|
| 40 |
# Build index
|
| 41 |
-
self.vectors, self.data, self.embeddings = self.index()
|
| 42 |
|
| 43 |
def loadcolumns(self):
|
| 44 |
"""
|
|
@@ -122,12 +122,12 @@ class Stats:
|
|
| 122 |
# Build data dictionary
|
| 123 |
vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
|
| 124 |
data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
|
|
|
|
| 125 |
|
| 126 |
embeddings = Embeddings({"transform": Stats.transform})
|
| 127 |
-
|
| 128 |
embeddings.index((uid, vectors[uid], None) for uid in vectors)
|
| 129 |
|
| 130 |
-
return vectors, data, embeddings
|
| 131 |
|
| 132 |
def metrics(self, name):
|
| 133 |
"""
|
|
@@ -155,13 +155,14 @@ class Stats:
|
|
| 155 |
|
| 156 |
return range(1871, datetime.datetime.today().year), 1950, None
|
| 157 |
|
| 158 |
-
def search(self, name=None, year=None, row=None, limit=10):
|
| 159 |
"""
|
| 160 |
Runs an embeddings search. This method takes either a player-year or stats row as input.
|
| 161 |
|
| 162 |
Args:
|
| 163 |
name: player name to search
|
| 164 |
year: year to search
|
|
|
|
| 165 |
row: row of stats to search
|
| 166 |
limit: max results to return
|
| 167 |
|
|
@@ -179,13 +180,17 @@ class Stats:
|
|
| 179 |
|
| 180 |
results, ids = [], set()
|
| 181 |
if query is not None:
|
| 182 |
-
|
|
|
|
| 183 |
# Only add unique players
|
| 184 |
if uid[4:] not in ids:
|
| 185 |
result = self.data[uid].copy()
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
if len(ids) >= limit:
|
| 191 |
break
|
|
@@ -467,6 +472,9 @@ class Application:
|
|
| 467 |
# Player name
|
| 468 |
name = self.name(stats.names, params.get("name"))
|
| 469 |
|
|
|
|
|
|
|
|
|
|
| 470 |
# Player metrics
|
| 471 |
active, best, metrics = stats.metrics(name)
|
| 472 |
|
|
@@ -478,7 +486,7 @@ class Application:
|
|
| 478 |
self.chart(category, metrics)
|
| 479 |
|
| 480 |
# Run search
|
| 481 |
-
results = stats.search(name, year)
|
| 482 |
|
| 483 |
# Display results
|
| 484 |
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
|
|
@@ -495,6 +503,10 @@ class Application:
|
|
| 495 |
st.markdown("Find players with similar statistics.")
|
| 496 |
|
| 497 |
stats, category = None, self.category("Batting", "searchcategory")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
with st.form("search"):
|
| 499 |
if category == "Batting":
|
| 500 |
stats, columns = self.batting, self.batting.columns[:-6]
|
|
@@ -507,7 +519,7 @@ class Application:
|
|
| 507 |
submitted = st.form_submit_button("Search")
|
| 508 |
if submitted:
|
| 509 |
# Run search
|
| 510 |
-
results = stats.search(row=inputs.to_dict(orient="records")[0])
|
| 511 |
|
| 512 |
# Display table
|
| 513 |
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
|
|
@@ -521,15 +533,16 @@ class Application:
|
|
| 521 |
"""
|
| 522 |
|
| 523 |
# Get parameters
|
| 524 |
-
params = {x: st.query_params.get(x) for x in ["category", "name", "year"]}
|
| 525 |
|
| 526 |
# Sync parameters with session state
|
| 527 |
-
if all(x in st.session_state for x in ["category", "name", "year"]):
|
| 528 |
# Copy session year if category and name are unchanged
|
| 529 |
params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None
|
| 530 |
|
| 531 |
# Copy category and name from session state
|
| 532 |
params["category"] = st.session_state["category"]
|
|
|
|
| 533 |
params["name"] = st.session_state["name"]
|
| 534 |
|
| 535 |
return params
|
|
@@ -555,6 +568,20 @@ class Application:
|
|
| 555 |
# Radio box component
|
| 556 |
return st.radio("Stat", categories, index=default, horizontal=True, key=key)
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
def name(self, names, name):
|
| 559 |
"""
|
| 560 |
Builds name input widget.
|
|
@@ -670,4 +697,4 @@ if __name__ == "__main__":
|
|
| 670 |
|
| 671 |
# Create and run application
|
| 672 |
app = create()
|
| 673 |
-
app.run()
|
|
|
|
| 38 |
self.names = self.loadnames()
|
| 39 |
|
| 40 |
# Build index
|
| 41 |
+
self.vectors, self.data, self.maxyear, self.embeddings = self.index()
|
| 42 |
|
| 43 |
def loadcolumns(self):
|
| 44 |
"""
|
|
|
|
| 122 |
# Build data dictionary
|
| 123 |
vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
|
| 124 |
data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
|
| 125 |
+
maxyear = max(row["yearID"] for _, row in self.stats.iterrows())
|
| 126 |
|
| 127 |
embeddings = Embeddings({"transform": Stats.transform})
|
|
|
|
| 128 |
embeddings.index((uid, vectors[uid], None) for uid in vectors)
|
| 129 |
|
| 130 |
+
return vectors, data, maxyear, embeddings
|
| 131 |
|
| 132 |
def metrics(self, name):
|
| 133 |
"""
|
|
|
|
| 155 |
|
| 156 |
return range(1871, datetime.datetime.today().year), 1950, None
|
| 157 |
|
| 158 |
+
def search(self, name=None, year=None, window=None, row=None, limit=10):
|
| 159 |
"""
|
| 160 |
Runs an embeddings search. This method takes either a player-year or stats row as input.
|
| 161 |
|
| 162 |
Args:
|
| 163 |
name: player name to search
|
| 164 |
year: year to search
|
| 165 |
+
window: limit to window recent seasons
|
| 166 |
row: row of stats to search
|
| 167 |
limit: max results to return
|
| 168 |
|
|
|
|
| 180 |
|
| 181 |
results, ids = [], set()
|
| 182 |
if query is not None:
|
| 183 |
+
candidates = limit * 100 if window else limit * 5
|
| 184 |
+
for uid, _ in self.embeddings.search(query, candidates):
|
| 185 |
# Only add unique players
|
| 186 |
if uid[4:] not in ids:
|
| 187 |
result = self.data[uid].copy()
|
| 188 |
+
|
| 189 |
+
# Add first player if this is a player comparison. Limit results to window, if necessary
|
| 190 |
+
if (not ids and not row) or not window or result["yearID"] > self.maxyear - window:
|
| 191 |
+
result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
|
| 192 |
+
results.append(result)
|
| 193 |
+
ids.add(uid[4:])
|
| 194 |
|
| 195 |
if len(ids) >= limit:
|
| 196 |
break
|
|
|
|
| 472 |
# Player name
|
| 473 |
name = self.name(stats.names, params.get("name"))
|
| 474 |
|
| 475 |
+
# Limit player-year comparisons using this window
|
| 476 |
+
window = self.window(None, "window")
|
| 477 |
+
|
| 478 |
# Player metrics
|
| 479 |
active, best, metrics = stats.metrics(name)
|
| 480 |
|
|
|
|
| 486 |
self.chart(category, metrics)
|
| 487 |
|
| 488 |
# Run search
|
| 489 |
+
results = stats.search(name, year, window)
|
| 490 |
|
| 491 |
# Display results
|
| 492 |
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
|
|
|
|
| 503 |
st.markdown("Find players with similar statistics.")
|
| 504 |
|
| 505 |
stats, category = None, self.category("Batting", "searchcategory")
|
| 506 |
+
|
| 507 |
+
# Limit player-year comparisons using this window
|
| 508 |
+
window = self.window(None, "searchwindow")
|
| 509 |
+
|
| 510 |
with st.form("search"):
|
| 511 |
if category == "Batting":
|
| 512 |
stats, columns = self.batting, self.batting.columns[:-6]
|
|
|
|
| 519 |
submitted = st.form_submit_button("Search")
|
| 520 |
if submitted:
|
| 521 |
# Run search
|
| 522 |
+
results = stats.search(window=window, row=inputs.to_dict(orient="records")[0])
|
| 523 |
|
| 524 |
# Display table
|
| 525 |
self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])
|
|
|
|
| 533 |
"""
|
| 534 |
|
| 535 |
# Get parameters
|
| 536 |
+
params = {x: st.query_params.get(x) for x in ["category", "name", "window", "year"]}
|
| 537 |
|
| 538 |
# Sync parameters with session state
|
| 539 |
+
if all(x in st.session_state for x in ["category", "name", "window", "year"]):
|
| 540 |
# Copy session year if category and name are unchanged
|
| 541 |
params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None
|
| 542 |
|
| 543 |
# Copy category and name from session state
|
| 544 |
params["category"] = st.session_state["category"]
|
| 545 |
+
params["window"] = st.session_state["window"]
|
| 546 |
params["name"] = st.session_state["name"]
|
| 547 |
|
| 548 |
return params
|
|
|
|
| 568 |
# Radio box component
|
| 569 |
return st.radio("Stat", categories, index=default, horizontal=True, key=key)
|
| 570 |
|
| 571 |
+
def window(self, window, key):
|
| 572 |
+
"""
|
| 573 |
+
Limit results to last N seasons.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
window: limit to window seasons
|
| 577 |
+
key: widget key
|
| 578 |
+
|
| 579 |
+
Returns:
|
| 580 |
+
window component
|
| 581 |
+
"""
|
| 582 |
+
|
| 583 |
+
return st.number_input("Limit to last N seasons", value=window, step=1, min_value=1, max_value=100, key=key)
|
| 584 |
+
|
| 585 |
def name(self, names, name):
|
| 586 |
"""
|
| 587 |
Builds name input widget.
|
|
|
|
| 697 |
|
| 698 |
# Create and run application
|
| 699 |
app = create()
|
| 700 |
+
app.run()
|