Spaces:
Sleeping
Sleeping
Hatmanstack Claude Opus 4.5 commited on
Commit ·
6424951
1
Parent(s): a63f84a
Refactor app with security fixes, error handling, and type safety
Browse files- Fix SQL injection: use parameterized queries in src/database/queries.py
- Fix XSS: add HTML escaping in src/utils/html.py
- Re-enable CORS/XSRF protection in devcontainer.json
- Add Pydantic models for validation in src/models/player.py
- Add session state management with safe defaults
- Cache ML model loading with @st .cache_resource
- Fix N+1 query with batch IN clause
- Add loop guards to prevent infinite loops
- Add comprehensive test suite
- Add pyproject.toml with mypy strict mode and ruff config
- Move compile_model.py to scripts/ with main() guard
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- .devcontainer/devcontainer.json +1 -1
- app.py +15 -5
- compile_model.py +0 -73
- pages/1_home_team.py +128 -84
- pages/2_play_game.py +179 -99
- pyproject.toml +104 -0
- requirements-dev.txt +6 -0
- requirements.txt +6 -4
- scripts/compile_model.py +243 -0
- src/__init__.py +1 -0
- src/config.py +93 -0
- src/database/__init__.py +23 -0
- src/database/connection.py +111 -0
- src/database/queries.py +127 -0
- src/ml/__init__.py +15 -0
- src/ml/model.py +114 -0
- src/models/__init__.py +5 -0
- src/models/player.py +144 -0
- src/state/__init__.py +5 -0
- src/state/session.py +160 -0
- src/utils/__init__.py +5 -0
- src/utils/html.py +108 -0
- src/validation/__init__.py +9 -0
- src/validation/inputs.py +111 -0
- tests/__init__.py +1 -0
- tests/conftest.py +123 -0
- tests/test_database.py +220 -0
- tests/test_ml.py +147 -0
- tests/test_models.py +140 -0
- tests/test_validation.py +130 -0
.devcontainer/devcontainer.json
CHANGED
|
@@ -19,7 +19,7 @@
|
|
| 19 |
},
|
| 20 |
"updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
|
| 21 |
"postAttachCommand": {
|
| 22 |
-
"server": "streamlit run app.py
|
| 23 |
},
|
| 24 |
"portsAttributes": {
|
| 25 |
"8501": {
|
|
|
|
| 19 |
},
|
| 20 |
"updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
|
| 21 |
"postAttachCommand": {
|
| 22 |
+
"server": "streamlit run app.py"
|
| 23 |
},
|
| 24 |
"portsAttributes": {
|
| 25 |
"8501": {
|
app.py
CHANGED
|
@@ -1,13 +1,23 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import snowflake.connector
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
| 7 |
on_page_load()
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
|
|
|
| 1 |
+
"""NBA Team Builder Application - Entry Point."""
|
| 2 |
+
|
| 3 |
import streamlit as st
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
from src.utils.html import safe_heading, safe_paragraph
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def on_page_load() -> None:
|
| 9 |
+
"""Configure page settings."""
|
| 10 |
st.set_page_config(layout="wide")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
on_page_load()
|
| 14 |
|
| 15 |
+
safe_heading("NBA", level=1, color="steelblue")
|
| 16 |
|
| 17 |
+
safe_paragraph(
|
| 18 |
+
"A Simple app to test your skill in building a Team based on "
|
| 19 |
+
"career stats to compete with a Computer",
|
| 20 |
+
color="white",
|
| 21 |
+
)
|
| 22 |
|
| 23 |
|
compile_model.py
DELETED
|
@@ -1,73 +0,0 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
-
import numpy as np
|
| 3 |
-
from tensorflow import keras
|
| 4 |
-
from tensorflow.keras import layers
|
| 5 |
-
from tensorflow.keras.losses import BinaryCrossentropy
|
| 6 |
-
from sklearn.model_selection import train_test_split
|
| 7 |
-
from sklearn.model_selection import RandomizedSearchCV
|
| 8 |
-
from scikeras.wrappers import KerasClassifier
|
| 9 |
-
|
| 10 |
-
def create_stats(roster, schedule):
|
| 11 |
-
home_stats = []
|
| 12 |
-
away_stats = []
|
| 13 |
-
S = []
|
| 14 |
-
|
| 15 |
-
# Loading Relavent Columns from f-test
|
| 16 |
-
cols = ['TEAM','PTS/G', 'ORB', 'DRB', 'AST', 'STL', 'BLK', 'TOV', '3P%', 'FT%','2P']
|
| 17 |
-
new_roster = roster[cols]
|
| 18 |
-
for i in schedule['Home/Neutral']:
|
| 19 |
-
home_stats.append((new_roster[new_roster['TEAM'] == i]).values.tolist())
|
| 20 |
-
for i in schedule['Visitor/Neutral']:
|
| 21 |
-
away_stats.append((new_roster.loc[new_roster['TEAM'] == i]).values.tolist())
|
| 22 |
-
for i in range(len(home_stats)):
|
| 23 |
-
arr = []
|
| 24 |
-
for j in range(len(home_stats[i])):
|
| 25 |
-
del home_stats[i][j][0]
|
| 26 |
-
arr += home_stats[i][j]
|
| 27 |
-
for j in range(len(away_stats[i])):
|
| 28 |
-
del away_stats[i][j][0]
|
| 29 |
-
arr += away_stats[i][j]
|
| 30 |
-
|
| 31 |
-
# Create numpy array with all the players on the Home Team's Stats followed by the Away Team's stats
|
| 32 |
-
S.append(np.nan_to_num(np.array(arr), copy=False))
|
| 33 |
-
return S
|
| 34 |
-
|
| 35 |
-
roster = pd.read_csv('player_stats.txt', delimiter=',')
|
| 36 |
-
schedule = pd.read_csv('schedule.txt', delimiter=',')
|
| 37 |
-
|
| 38 |
-
# Create winning condition to train on
|
| 39 |
-
schedule['winner'] = schedule.apply(lambda x: 0 if x['PTS'] > x['PTS.1'] else 1, axis=1)
|
| 40 |
-
|
| 41 |
-
X = np.array(create_stats(roster, schedule))
|
| 42 |
-
y = np.array(schedule['winner'])
|
| 43 |
-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
| 44 |
-
|
| 45 |
-
def create_model(optimizer='rmsprop', init='glorot_uniform'):
|
| 46 |
-
inputs = keras.Input(shape=(100,))
|
| 47 |
-
dense = layers.Dense(50, activation="relu")
|
| 48 |
-
x = dense(inputs)
|
| 49 |
-
x = layers.Dense(64, activation="relu")(x)
|
| 50 |
-
outputs = layers.Dense(1, activation='sigmoid')(x)
|
| 51 |
-
model = keras.Model(inputs=inputs, outputs=outputs, name="nba_model")
|
| 52 |
-
model.compile(loss=BinaryCrossentropy(from_logits=False), optimizer=optimizer, metrics=["accuracy"])
|
| 53 |
-
|
| 54 |
-
return model
|
| 55 |
-
|
| 56 |
-
model = KerasClassifier(model=create_model, verbose=0, init='glorot_uniform')
|
| 57 |
-
|
| 58 |
-
optimizer = ['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam']
|
| 59 |
-
init = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']
|
| 60 |
-
epochs = [500, 1000, 1500]
|
| 61 |
-
batches = [50, 100, 200]
|
| 62 |
-
param_grid = dict(optimizer=optimizer, epochs=epochs, batch_size=batches, init=init)
|
| 63 |
-
|
| 64 |
-
random_search = RandomizedSearchCV(estimator=model, param_distributions=param_grid, n_iter=100, verbose=3)
|
| 65 |
-
random_search_result = random_search.fit(X_train, y_train)
|
| 66 |
-
best_model = random_search_result.best_estimator_
|
| 67 |
-
|
| 68 |
-
best_model.model_.save('winner.keras')
|
| 69 |
-
best_parameters = random_search_result.best_params_
|
| 70 |
-
print("Best parameters: ", best_parameters)
|
| 71 |
-
|
| 72 |
-
test_accuracy = random_search_result.best_estimator_.score(X_test, y_test)
|
| 73 |
-
print("Test accuracy: ", test_accuracy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/1_home_team.py
CHANGED
|
@@ -1,111 +1,155 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
| 7 |
on_page_load()
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
col1, col2, col3 = st.columns(3)
|
| 10 |
|
| 11 |
with col2:
|
| 12 |
-
|
| 13 |
-
player_add = st.text_input(
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
home_team_df = find_home_team()
|
| 55 |
|
| 56 |
-
|
| 57 |
if not home_team_df.empty:
|
| 58 |
-
name_list = home_team_df[
|
| 59 |
-
player_search
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
holder = saved_players + player_selected
|
|
|
|
| 64 |
if len(player_selected) > len(saved_players):
|
| 65 |
-
for
|
| 66 |
-
if
|
| 67 |
-
st.session_state.home_team.append(
|
| 68 |
elif len(player_selected) < len(saved_players):
|
| 69 |
-
for
|
| 70 |
-
if
|
| 71 |
-
st.session_state.home_team.remove(
|
| 72 |
st.rerun()
|
| 73 |
|
| 74 |
-
|
|
|
|
| 75 |
with col1:
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with col2:
|
| 78 |
-
if st.button(
|
| 79 |
save_state()
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
st.dataframe(home_team_df)
|
| 84 |
-
|
|
|
|
| 85 |
col1, col2, col3, col4, col5 = st.columns(5)
|
|
|
|
| 86 |
with col3:
|
| 87 |
-
|
| 88 |
difficulty = st.radio(
|
| 89 |
-
label="Difficulty",
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
st.session_state.away_stats = [
|
| 97 |
-
st.session_state.radio_index =
|
| 98 |
-
elif difficulty == 'All-Stars':
|
| 99 |
-
st.session_state.away_stats = [1250, 600, 400, 100]
|
| 100 |
-
st.session_state.radio_index = 2
|
| 101 |
-
elif difficulty == 'Dream Team':
|
| 102 |
-
st.session_state.away_stats = [1450, 700, 500, 120]
|
| 103 |
-
st.session_state.radio_index = 3
|
| 104 |
else:
|
| 105 |
st.write("You didn't select a difficulty.")
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
| 1 |
+
"""Home team builder page."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
import pandas as pd
|
| 6 |
+
import streamlit as st
|
| 7 |
+
|
| 8 |
+
from src.config import DIFFICULTY_PRESETS, PLAYER_COLUMNS
|
| 9 |
+
from src.database.connection import DatabaseConnectionError, QueryExecutionError, get_connection
|
| 10 |
+
from src.database.queries import get_players_by_full_names, search_player_by_name
|
| 11 |
+
from src.state.session import init_session_state
|
| 12 |
+
from src.utils.html import safe_heading, safe_paragraph
|
| 13 |
+
from src.validation.inputs import validate_search_term
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("streamlit_nba")
|
| 16 |
|
| 17 |
+
|
| 18 |
+
def on_page_load() -> None:
|
| 19 |
+
"""Configure page settings."""
|
| 20 |
st.set_page_config(layout="wide")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
on_page_load()
|
| 24 |
|
| 25 |
+
# Initialize session state before any access
|
| 26 |
+
init_session_state()
|
| 27 |
+
|
| 28 |
col1, col2, col3 = st.columns(3)
|
| 29 |
|
| 30 |
with col2:
|
| 31 |
+
safe_heading("Build Your Team", level=1, color="steelblue")
|
| 32 |
+
player_add = st.text_input("Who're you picking?", "James")
|
| 33 |
+
|
| 34 |
+
safe_paragraph(
|
| 35 |
+
"Search for a player to populate the dropdown menu then pick and "
|
| 36 |
+
"save your team before searching for another player.",
|
| 37 |
+
color="steelblue",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def find_player(search_term: str) -> list[str]:
|
| 42 |
+
"""Search for players by name with validation and error handling.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
search_term: User-provided search term
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
List of matching player full names
|
| 49 |
+
"""
|
| 50 |
+
# Validate input
|
| 51 |
+
validated_term = validate_search_term(search_term)
|
| 52 |
+
if validated_term is None:
|
| 53 |
+
st.warning("Invalid search term. Please use only letters, numbers, and basic punctuation.")
|
| 54 |
+
return []
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
with get_connection() as conn:
|
| 58 |
+
results = search_player_by_name(conn, validated_term)
|
| 59 |
+
return [player[0] for player in results]
|
| 60 |
+
except DatabaseConnectionError as e:
|
| 61 |
+
st.error(f"Could not connect to database. Please try again later.")
|
| 62 |
+
logger.error(f"Database connection error: {e}")
|
| 63 |
+
return []
|
| 64 |
+
except QueryExecutionError as e:
|
| 65 |
+
st.error("Error searching for players. Please try again.")
|
| 66 |
+
logger.error(f"Query error: {e}")
|
| 67 |
+
return []
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def find_home_team() -> pd.DataFrame:
|
| 71 |
+
"""Load home team data from database using batch query.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
DataFrame with home team player data
|
| 75 |
+
"""
|
| 76 |
+
team_names: list[str] = st.session_state.get("home_team", [])
|
| 77 |
+
if not team_names:
|
| 78 |
+
return pd.DataFrame(columns=PLAYER_COLUMNS)
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
with get_connection() as conn:
|
| 82 |
+
# Single batch query instead of N+1 queries
|
| 83 |
+
df = get_players_by_full_names(conn, team_names)
|
| 84 |
+
st.session_state.home_team_df = df
|
| 85 |
+
return df
|
| 86 |
+
except DatabaseConnectionError as e:
|
| 87 |
+
st.error("Could not connect to database. Please try again later.")
|
| 88 |
+
logger.error(f"Database connection error: {e}")
|
| 89 |
+
return pd.DataFrame(columns=PLAYER_COLUMNS)
|
| 90 |
+
except QueryExecutionError as e:
|
| 91 |
+
st.error("Error loading team data. Please try again.")
|
| 92 |
+
logger.error(f"Query error: {e}")
|
| 93 |
+
return pd.DataFrame(columns=PLAYER_COLUMNS)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Load data
|
| 97 |
+
player_search = find_player(player_add)
|
| 98 |
home_team_df = find_home_team()
|
| 99 |
|
| 100 |
+
# Combine search results with current team
|
| 101 |
if not home_team_df.empty:
|
| 102 |
+
name_list = home_team_df["FULL_NAME"].tolist()
|
| 103 |
+
player_search = player_search + [n for n in name_list if n not in player_search]
|
| 104 |
|
| 105 |
+
|
| 106 |
+
def save_state() -> None:
|
| 107 |
+
"""Save the selected players to session state."""
|
| 108 |
+
saved_players = home_team_df["FULL_NAME"].tolist() if not home_team_df.empty else []
|
| 109 |
holder = saved_players + player_selected
|
| 110 |
+
|
| 111 |
if len(player_selected) > len(saved_players):
|
| 112 |
+
for player in holder:
|
| 113 |
+
if player not in st.session_state.home_team:
|
| 114 |
+
st.session_state.home_team.append(player)
|
| 115 |
elif len(player_selected) < len(saved_players):
|
| 116 |
+
for player in saved_players:
|
| 117 |
+
if player not in player_selected:
|
| 118 |
+
st.session_state.home_team.remove(player)
|
| 119 |
st.rerun()
|
| 120 |
|
| 121 |
+
|
| 122 |
+
col1, col2 = st.columns([7, 1])
|
| 123 |
with col1:
|
| 124 |
+
default_selection = home_team_df["FULL_NAME"].tolist() if not home_team_df.empty else []
|
| 125 |
+
player_selected = st.multiselect(
|
| 126 |
+
"Search Results:",
|
| 127 |
+
player_search,
|
| 128 |
+
default_selection,
|
| 129 |
+
label_visibility="collapsed",
|
| 130 |
+
)
|
| 131 |
with col2:
|
| 132 |
+
if st.button("Save Team"):
|
| 133 |
save_state()
|
| 134 |
|
| 135 |
+
safe_heading("Preview", level=1, color="steelblue")
|
| 136 |
+
|
| 137 |
st.dataframe(home_team_df)
|
| 138 |
+
|
| 139 |
+
radio_index: int = st.session_state.get("radio_index", 0)
|
| 140 |
col1, col2, col3, col4, col5 = st.columns(5)
|
| 141 |
+
|
| 142 |
with col3:
|
| 143 |
+
safe_heading("Away Team", level=3, color="steelblue")
|
| 144 |
difficulty = st.radio(
|
| 145 |
+
label="Difficulty",
|
| 146 |
+
index=radio_index,
|
| 147 |
+
options=list(DIFFICULTY_PRESETS.keys()),
|
| 148 |
+
label_visibility="collapsed",
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
if difficulty and difficulty in DIFFICULTY_PRESETS:
|
| 152 |
+
st.session_state.away_stats = list(DIFFICULTY_PRESETS[difficulty])
|
| 153 |
+
st.session_state.radio_index = list(DIFFICULTY_PRESETS.keys()).index(difficulty)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
else:
|
| 155 |
st.write("You didn't select a difficulty.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/2_play_game.py
CHANGED
|
@@ -1,114 +1,194 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
import numpy as np
|
| 5 |
-
import tensorflow as tf
|
| 6 |
import random
|
| 7 |
-
from tensorflow.keras.models import load_model
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
| 12 |
on_page_load()
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
teams_good = True
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
away_data = pd.DataFrame()
|
| 45 |
teams_good = False
|
| 46 |
-
|
|
|
|
| 47 |
else:
|
| 48 |
-
away_data = find_away_team()
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
if
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
col1, col2, col3 = st.columns(3)
|
| 107 |
with col2:
|
| 108 |
-
|
| 109 |
-
|
|
|
|
| 110 |
st.dataframe(away_data)
|
| 111 |
|
| 112 |
if st.button("Play New Team"):
|
| 113 |
-
|
| 114 |
-
|
|
|
|
| 1 |
+
"""Game play page with prediction and scoring."""
|
| 2 |
+
|
| 3 |
+
import logging
|
|
|
|
|
|
|
| 4 |
import random
|
|
|
|
| 5 |
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import streamlit as st
|
| 8 |
|
| 9 |
+
from src.config import (
|
| 10 |
+
DEFAULT_LOSER_SCORE,
|
| 11 |
+
DEFAULT_WINNER_SCORE,
|
| 12 |
+
LOSER_SCORE_RANGE,
|
| 13 |
+
MAX_QUERY_ATTEMPTS,
|
| 14 |
+
STAT_COLUMNS,
|
| 15 |
+
TEAM_SIZE,
|
| 16 |
+
WINNER_SCORE_RANGE,
|
| 17 |
+
)
|
| 18 |
+
from src.database.connection import DatabaseConnectionError, QueryExecutionError, get_connection
|
| 19 |
+
from src.database.queries import get_away_team_by_stats
|
| 20 |
+
from src.ml.model import ModelLoadError, analyze_team_stats, predict_winner
|
| 21 |
+
from src.state.session import get_away_stats, get_home_team_df, init_session_state
|
| 22 |
+
from src.utils.html import safe_heading
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("streamlit_nba")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def on_page_load() -> None:
|
| 28 |
+
"""Configure page settings."""
|
| 29 |
st.set_page_config(layout="wide")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
on_page_load()
|
| 33 |
|
| 34 |
+
# Initialize session state BEFORE any access
|
| 35 |
+
init_session_state()
|
| 36 |
+
|
| 37 |
+
# Get stats safely with fallback
|
| 38 |
+
stats = get_away_stats()
|
| 39 |
teams_good = True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def find_away_team(stat_thresholds: list[int]) -> pd.DataFrame:
|
| 43 |
+
"""Generate away team based on difficulty stats.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
stat_thresholds: List of [pts, reb, ast, stl] thresholds
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
DataFrame with away team data, or empty DataFrame on error
|
| 50 |
+
"""
|
| 51 |
+
try:
|
| 52 |
+
with get_connection() as conn:
|
| 53 |
+
return get_away_team_by_stats(
|
| 54 |
+
conn,
|
| 55 |
+
pts_threshold=stat_thresholds[0],
|
| 56 |
+
reb_threshold=stat_thresholds[1],
|
| 57 |
+
ast_threshold=stat_thresholds[2],
|
| 58 |
+
stl_threshold=stat_thresholds[3],
|
| 59 |
+
max_attempts=MAX_QUERY_ATTEMPTS,
|
| 60 |
+
)
|
| 61 |
+
except DatabaseConnectionError as e:
|
| 62 |
+
st.error("Could not connect to database. Please try again later.")
|
| 63 |
+
logger.error(f"Database connection error: {e}")
|
| 64 |
+
return pd.DataFrame()
|
| 65 |
+
except QueryExecutionError as e:
|
| 66 |
+
st.error("Could not generate away team. Please try again.")
|
| 67 |
+
logger.error(f"Query error: {e}")
|
| 68 |
+
return pd.DataFrame()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_score_board(final_score: int) -> list[int]:
|
| 72 |
+
"""Generate quarter-by-quarter scores that sum to final score.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
final_score: Total game score
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
List of [Q1, Q2, Q3, Q4, Final] scores
|
| 79 |
+
"""
|
| 80 |
+
quarter_score = final_score // 4
|
| 81 |
+
scores = [
|
| 82 |
+
quarter_score + random.randint(-7, 7),
|
| 83 |
+
quarter_score + random.randint(-3, 3),
|
| 84 |
+
quarter_score + random.randint(-8, 8),
|
| 85 |
+
]
|
| 86 |
+
# Q4 makes up the difference to hit exact final
|
| 87 |
+
scores.append(final_score - sum(scores))
|
| 88 |
+
scores.append(final_score)
|
| 89 |
+
return scores
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def generate_game_scores() -> tuple[int, int]:
|
| 93 |
+
"""Generate winner and loser scores with loop guard.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Tuple of (winner_score, loser_score)
|
| 97 |
+
"""
|
| 98 |
+
for _ in range(MAX_QUERY_ATTEMPTS):
|
| 99 |
+
winner_score = random.randint(*WINNER_SCORE_RANGE)
|
| 100 |
+
loser_score = random.randint(*LOSER_SCORE_RANGE)
|
| 101 |
+
if winner_score > loser_score:
|
| 102 |
+
return winner_score, loser_score
|
| 103 |
+
|
| 104 |
+
# Fallback to guaranteed valid scores
|
| 105 |
+
logger.warning("Score generation fell back to defaults")
|
| 106 |
+
return DEFAULT_WINNER_SCORE, DEFAULT_LOSER_SCORE
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Check if home team is valid
|
| 110 |
+
home_team_df = get_home_team_df()
|
| 111 |
+
|
| 112 |
+
if home_team_df.empty or home_team_df.shape[0] != TEAM_SIZE:
|
| 113 |
+
safe_heading(
|
| 114 |
+
f"Your Team Doesn't Have {TEAM_SIZE} Players",
|
| 115 |
+
level=3,
|
| 116 |
+
color="red",
|
| 117 |
+
)
|
| 118 |
away_data = pd.DataFrame()
|
| 119 |
teams_good = False
|
| 120 |
+
winner_label = ""
|
| 121 |
+
box_score = pd.DataFrame()
|
| 122 |
else:
|
| 123 |
+
away_data = find_away_team(stats)
|
| 124 |
+
if away_data.empty:
|
| 125 |
+
teams_good = False
|
| 126 |
+
winner_label = ""
|
| 127 |
+
box_score = pd.DataFrame()
|
| 128 |
+
|
| 129 |
+
# Run prediction if both teams are valid
|
| 130 |
+
if teams_good and not away_data.empty:
|
| 131 |
+
try:
|
| 132 |
+
# Extract stats for ML model
|
| 133 |
+
home_stats = home_team_df[STAT_COLUMNS].values.tolist()
|
| 134 |
+
away_stats_data = away_data[STAT_COLUMNS].values.tolist()
|
| 135 |
+
|
| 136 |
+
# Prepare data and predict
|
| 137 |
+
_, _, combined = analyze_team_stats(home_stats, away_stats_data)
|
| 138 |
+
probability, prediction = predict_winner(combined)
|
| 139 |
+
|
| 140 |
+
# Generate scores
|
| 141 |
+
winner_score, loser_score = generate_game_scores()
|
| 142 |
+
|
| 143 |
+
# Build scoreboard based on prediction
|
| 144 |
+
if prediction == 1:
|
| 145 |
+
score_data = [
|
| 146 |
+
get_score_board(winner_score),
|
| 147 |
+
get_score_board(loser_score),
|
| 148 |
+
]
|
| 149 |
+
winner_label = "Winner"
|
| 150 |
+
else:
|
| 151 |
+
score_data = [
|
| 152 |
+
get_score_board(loser_score),
|
| 153 |
+
get_score_board(winner_score),
|
| 154 |
+
]
|
| 155 |
+
winner_label = "Loser"
|
| 156 |
+
|
| 157 |
+
box_score = pd.DataFrame(
|
| 158 |
+
score_data,
|
| 159 |
+
columns=["1", "2", "3", "4", "Final"],
|
| 160 |
+
index=["Home Team", "Away Team"],
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
logger.info(f"Prediction: {probability:.4f}")
|
| 164 |
+
|
| 165 |
+
except ModelLoadError as e:
|
| 166 |
+
st.error("Could not load prediction model. Please contact support.")
|
| 167 |
+
logger.error(f"Model load error: {e}")
|
| 168 |
+
teams_good = False
|
| 169 |
+
winner_label = ""
|
| 170 |
+
box_score = pd.DataFrame()
|
| 171 |
+
except ValueError as e:
|
| 172 |
+
st.error("Error processing team stats. Please try again.")
|
| 173 |
+
logger.error(f"Stats processing error: {e}")
|
| 174 |
+
teams_good = False
|
| 175 |
+
winner_label = ""
|
| 176 |
+
box_score = pd.DataFrame()
|
| 177 |
+
|
| 178 |
+
# Display results
|
| 179 |
+
safe_heading("Home Team", level=1, color="steelblue")
|
| 180 |
+
st.dataframe(home_team_df)
|
| 181 |
+
|
| 182 |
+
if teams_good and winner_label:
|
| 183 |
+
logger.info("Teams Good")
|
| 184 |
+
safe_heading(winner_label, level=3, color="steelblue")
|
| 185 |
col1, col2, col3 = st.columns(3)
|
| 186 |
with col2:
|
| 187 |
+
st.dataframe(box_score)
|
| 188 |
+
|
| 189 |
+
safe_heading("Away Team", level=1, color="steelblue")
|
| 190 |
st.dataframe(away_data)
|
| 191 |
|
| 192 |
if st.button("Play New Team"):
|
| 193 |
+
logger.info("New Team requested")
|
| 194 |
+
st.rerun()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "streamlit-nba"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "NBA team builder and game prediction Streamlit application"
|
| 5 |
+
requires-python = ">=3.11"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"streamlit>=1.28.0",
|
| 8 |
+
"snowflake-connector-python>=3.5.0",
|
| 9 |
+
"tensorflow>=2.15.0",
|
| 10 |
+
"numpy>=1.24.0",
|
| 11 |
+
"pandas>=2.0.0",
|
| 12 |
+
"pydantic>=2.5.0",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[project.optional-dependencies]
|
| 16 |
+
dev = [
|
| 17 |
+
"pytest>=7.4.0",
|
| 18 |
+
"pytest-cov>=4.1.0",
|
| 19 |
+
"mypy>=1.7.0",
|
| 20 |
+
"ruff>=0.1.6",
|
| 21 |
+
"pandas-stubs>=2.0.0",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[tool.mypy]
|
| 25 |
+
python_version = "3.11"
|
| 26 |
+
strict = true
|
| 27 |
+
warn_return_any = true
|
| 28 |
+
warn_unused_configs = true
|
| 29 |
+
disallow_untyped_defs = true
|
| 30 |
+
disallow_incomplete_defs = true
|
| 31 |
+
check_untyped_defs = true
|
| 32 |
+
disallow_untyped_decorators = true
|
| 33 |
+
no_implicit_optional = true
|
| 34 |
+
warn_redundant_casts = true
|
| 35 |
+
warn_unused_ignores = true
|
| 36 |
+
warn_no_return = true
|
| 37 |
+
warn_unreachable = true
|
| 38 |
+
|
| 39 |
+
# Third-party library ignores
|
| 40 |
+
[[tool.mypy.overrides]]
|
| 41 |
+
module = [
|
| 42 |
+
"streamlit.*",
|
| 43 |
+
"snowflake.*",
|
| 44 |
+
"tensorflow.*",
|
| 45 |
+
"keras.*",
|
| 46 |
+
"sklearn.*",
|
| 47 |
+
"scikeras.*",
|
| 48 |
+
]
|
| 49 |
+
ignore_missing_imports = true
|
| 50 |
+
|
| 51 |
+
[tool.ruff]
|
| 52 |
+
target-version = "py311"
|
| 53 |
+
line-length = 88
|
| 54 |
+
|
| 55 |
+
[tool.ruff.lint]
|
| 56 |
+
select = [
|
| 57 |
+
"E", # pycodestyle errors
|
| 58 |
+
"W", # pycodestyle warnings
|
| 59 |
+
"F", # pyflakes
|
| 60 |
+
"I", # isort
|
| 61 |
+
"B", # flake8-bugbear
|
| 62 |
+
"C4", # flake8-comprehensions
|
| 63 |
+
"UP", # pyupgrade
|
| 64 |
+
"ARG", # flake8-unused-arguments
|
| 65 |
+
"SIM", # flake8-simplify
|
| 66 |
+
"TCH", # flake8-type-checking
|
| 67 |
+
"PTH", # flake8-use-pathlib
|
| 68 |
+
"PL", # pylint
|
| 69 |
+
"RUF", # ruff-specific
|
| 70 |
+
"S", # flake8-bandit (security)
|
| 71 |
+
]
|
| 72 |
+
ignore = [
|
| 73 |
+
"S101", # assert used (ok in tests)
|
| 74 |
+
"PLR0913", # too many arguments
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
[tool.ruff.lint.per-file-ignores]
|
| 78 |
+
"tests/*" = ["S101", "ARG001", "PLR2004"]
|
| 79 |
+
|
| 80 |
+
[tool.pytest.ini_options]
|
| 81 |
+
testpaths = ["tests"]
|
| 82 |
+
python_files = ["test_*.py"]
|
| 83 |
+
python_functions = ["test_*"]
|
| 84 |
+
addopts = [
|
| 85 |
+
"-v",
|
| 86 |
+
"--tb=short",
|
| 87 |
+
"--strict-markers",
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
[tool.coverage.run]
|
| 91 |
+
source = ["src"]
|
| 92 |
+
branch = true
|
| 93 |
+
omit = [
|
| 94 |
+
"tests/*",
|
| 95 |
+
"scripts/*",
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
[tool.coverage.report]
|
| 99 |
+
exclude_lines = [
|
| 100 |
+
"pragma: no cover",
|
| 101 |
+
"if TYPE_CHECKING:",
|
| 102 |
+
"raise NotImplementedError",
|
| 103 |
+
]
|
| 104 |
+
fail_under = 80
|
requirements-dev.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-r requirements.txt
|
| 2 |
+
pytest>=7.4.0
|
| 3 |
+
pytest-cov>=4.1.0
|
| 4 |
+
mypy>=1.7.0
|
| 5 |
+
ruff>=0.1.6
|
| 6 |
+
pandas-stubs>=2.0.0
|
requirements.txt
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.28.0
|
| 2 |
+
snowflake-connector-python>=3.5.0
|
| 3 |
+
tensorflow>=2.15.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
pandas>=2.0.0
|
| 6 |
+
pydantic>=2.5.0
|
scripts/compile_model.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""NBA game winner prediction model training script.
|
| 3 |
+
|
| 4 |
+
This script trains a neural network to predict game winners based on
|
| 5 |
+
team statistics. It uses RandomizedSearchCV to find optimal hyperparameters.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/compile_model.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
from scikeras.wrappers import KerasClassifier
|
| 17 |
+
from sklearn.model_selection import RandomizedSearchCV, train_test_split
|
| 18 |
+
from tensorflow import keras
|
| 19 |
+
from tensorflow.keras import layers
|
| 20 |
+
from tensorflow.keras.losses import BinaryCrossentropy
|
| 21 |
+
|
| 22 |
+
# Configure logging
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.INFO,
|
| 25 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 26 |
+
)
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
# Data file paths
|
| 30 |
+
ROSTER_FILE = Path("player_stats.txt")
|
| 31 |
+
SCHEDULE_FILE = Path("schedule.txt")
|
| 32 |
+
OUTPUT_MODEL = Path("winner.keras")
|
| 33 |
+
|
| 34 |
+
# Feature columns from roster data
|
| 35 |
+
FEATURE_COLS: list[str] = [
|
| 36 |
+
"TEAM",
|
| 37 |
+
"PTS/G",
|
| 38 |
+
"ORB",
|
| 39 |
+
"DRB",
|
| 40 |
+
"AST",
|
| 41 |
+
"STL",
|
| 42 |
+
"BLK",
|
| 43 |
+
"TOV",
|
| 44 |
+
"3P%",
|
| 45 |
+
"FT%",
|
| 46 |
+
"2P",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
# Hyperparameter search space
|
| 50 |
+
OPTIMIZERS: list[str] = [
|
| 51 |
+
"SGD",
|
| 52 |
+
"RMSprop",
|
| 53 |
+
"Adagrad",
|
| 54 |
+
"Adadelta",
|
| 55 |
+
"Adam",
|
| 56 |
+
"Adamax",
|
| 57 |
+
"Nadam",
|
| 58 |
+
]
|
| 59 |
+
INITIALIZERS: list[str] = [
|
| 60 |
+
"uniform",
|
| 61 |
+
"lecun_uniform",
|
| 62 |
+
"normal",
|
| 63 |
+
"zero",
|
| 64 |
+
"glorot_normal",
|
| 65 |
+
"glorot_uniform",
|
| 66 |
+
"he_normal",
|
| 67 |
+
"he_uniform",
|
| 68 |
+
]
|
| 69 |
+
EPOCHS: list[int] = [500, 1000, 1500]
|
| 70 |
+
BATCH_SIZES: list[int] = [50, 100, 200]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def create_stats(
|
| 74 |
+
roster: pd.DataFrame, schedule: pd.DataFrame
|
| 75 |
+
) -> list[np.ndarray]:
|
| 76 |
+
"""Create feature arrays from roster and schedule data.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
roster: DataFrame with player statistics
|
| 80 |
+
schedule: DataFrame with game schedule and scores
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
List of numpy arrays, one per game with combined team stats
|
| 84 |
+
"""
|
| 85 |
+
home_stats: list[list] = []
|
| 86 |
+
away_stats: list[list] = []
|
| 87 |
+
features: list[np.ndarray] = []
|
| 88 |
+
|
| 89 |
+
new_roster = roster[FEATURE_COLS]
|
| 90 |
+
|
| 91 |
+
# Get stats for each team in each game
|
| 92 |
+
for team in schedule["Home/Neutral"]:
|
| 93 |
+
home_stats.append(new_roster[new_roster["TEAM"] == team].values.tolist())
|
| 94 |
+
|
| 95 |
+
for team in schedule["Visitor/Neutral"]:
|
| 96 |
+
away_stats.append(new_roster[new_roster["TEAM"] == team].values.tolist())
|
| 97 |
+
|
| 98 |
+
# Combine home and away stats for each game
|
| 99 |
+
for i in range(len(home_stats)):
|
| 100 |
+
arr: list[float] = []
|
| 101 |
+
|
| 102 |
+
for j in range(len(home_stats[i])):
|
| 103 |
+
del home_stats[i][j][0] # Remove team name
|
| 104 |
+
arr.extend(home_stats[i][j])
|
| 105 |
+
|
| 106 |
+
for j in range(len(away_stats[i])):
|
| 107 |
+
del away_stats[i][j][0] # Remove team name
|
| 108 |
+
arr.extend(away_stats[i][j])
|
| 109 |
+
|
| 110 |
+
# Handle NaN values
|
| 111 |
+
features.append(np.nan_to_num(np.array(arr), copy=False))
|
| 112 |
+
|
| 113 |
+
return features
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def create_model(
|
| 117 |
+
optimizer: str = "rmsprop", init: str = "glorot_uniform"
|
| 118 |
+
) -> keras.Model:
|
| 119 |
+
"""Create the neural network model architecture.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
optimizer: Optimizer name
|
| 123 |
+
init: Weight initializer name
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Compiled Keras model
|
| 127 |
+
"""
|
| 128 |
+
inputs = keras.Input(shape=(100,))
|
| 129 |
+
x = layers.Dense(50, activation="relu", kernel_initializer=init)(inputs)
|
| 130 |
+
x = layers.Dense(64, activation="relu", kernel_initializer=init)(x)
|
| 131 |
+
outputs = layers.Dense(1, activation="sigmoid")(x)
|
| 132 |
+
|
| 133 |
+
model = keras.Model(inputs=inputs, outputs=outputs, name="nba_model")
|
| 134 |
+
model.compile(
|
| 135 |
+
loss=BinaryCrossentropy(from_logits=False),
|
| 136 |
+
optimizer=optimizer,
|
| 137 |
+
metrics=["accuracy"],
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
return model
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def train_model(
|
| 144 |
+
x_train: np.ndarray,
|
| 145 |
+
y_train: np.ndarray,
|
| 146 |
+
x_test: np.ndarray,
|
| 147 |
+
y_test: np.ndarray,
|
| 148 |
+
n_iterations: int = 100,
|
| 149 |
+
) -> tuple[keras.Model, dict, float]:
|
| 150 |
+
"""Train model with hyperparameter search.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
x_train: Training features
|
| 154 |
+
y_train: Training labels
|
| 155 |
+
x_test: Test features
|
| 156 |
+
y_test: Test labels
|
| 157 |
+
n_iterations: Number of random search iterations
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Tuple of (best_model, best_params, test_accuracy)
|
| 161 |
+
"""
|
| 162 |
+
model = KerasClassifier(
|
| 163 |
+
model=create_model,
|
| 164 |
+
verbose=0,
|
| 165 |
+
init="glorot_uniform",
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
param_grid = {
|
| 169 |
+
"optimizer": OPTIMIZERS,
|
| 170 |
+
"epochs": EPOCHS,
|
| 171 |
+
"batch_size": BATCH_SIZES,
|
| 172 |
+
"init": INITIALIZERS,
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
logger.info(f"Starting randomized search with {n_iterations} iterations")
|
| 176 |
+
|
| 177 |
+
random_search = RandomizedSearchCV(
|
| 178 |
+
estimator=model,
|
| 179 |
+
param_distributions=param_grid,
|
| 180 |
+
n_iter=n_iterations,
|
| 181 |
+
verbose=3,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
random_search_result = random_search.fit(x_train, y_train)
|
| 185 |
+
|
| 186 |
+
best_model = random_search_result.best_estimator_
|
| 187 |
+
best_params = random_search_result.best_params_
|
| 188 |
+
test_accuracy = best_model.score(x_test, y_test)
|
| 189 |
+
|
| 190 |
+
return best_model.model_, best_params, test_accuracy
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def main() -> None:
|
| 194 |
+
"""Main training pipeline."""
|
| 195 |
+
logger.info("Loading data files")
|
| 196 |
+
|
| 197 |
+
if not ROSTER_FILE.exists():
|
| 198 |
+
logger.error(f"Roster file not found: {ROSTER_FILE}")
|
| 199 |
+
raise FileNotFoundError(f"Missing {ROSTER_FILE}")
|
| 200 |
+
|
| 201 |
+
if not SCHEDULE_FILE.exists():
|
| 202 |
+
logger.error(f"Schedule file not found: {SCHEDULE_FILE}")
|
| 203 |
+
raise FileNotFoundError(f"Missing {SCHEDULE_FILE}")
|
| 204 |
+
|
| 205 |
+
roster = pd.read_csv(ROSTER_FILE, delimiter=",")
|
| 206 |
+
schedule = pd.read_csv(SCHEDULE_FILE, delimiter=",")
|
| 207 |
+
|
| 208 |
+
logger.info(f"Loaded {len(roster)} players and {len(schedule)} games")
|
| 209 |
+
|
| 210 |
+
# Create target variable: 0 = home wins, 1 = away wins
|
| 211 |
+
schedule["winner"] = schedule.apply(
|
| 212 |
+
lambda x: 0 if x["PTS"] > x["PTS.1"] else 1, axis=1
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Create feature arrays
|
| 216 |
+
logger.info("Creating feature arrays")
|
| 217 |
+
X = np.array(create_stats(roster, schedule))
|
| 218 |
+
y = np.array(schedule["winner"])
|
| 219 |
+
|
| 220 |
+
logger.info(f"Feature shape: {X.shape}, Target shape: {y.shape}")
|
| 221 |
+
|
| 222 |
+
# Split data
|
| 223 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 224 |
+
X, y, test_size=0.2, random_state=42
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
logger.info(f"Train size: {len(X_train)}, Test size: {len(X_test)}")
|
| 228 |
+
|
| 229 |
+
# Train model
|
| 230 |
+
best_model, best_params, test_accuracy = train_model(
|
| 231 |
+
X_train, y_train, X_test, y_test
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Save model
|
| 235 |
+
logger.info(f"Saving model to {OUTPUT_MODEL}")
|
| 236 |
+
best_model.save(OUTPUT_MODEL)
|
| 237 |
+
|
| 238 |
+
logger.info(f"Best parameters: {best_params}")
|
| 239 |
+
logger.info(f"Test accuracy: {test_accuracy:.4f}")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""NBA Streamlit application source package."""
|
src/config.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Application configuration, constants, and logging setup."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Final
|
| 5 |
+
|
| 6 |
+
# Database column names for player data
|
| 7 |
+
PLAYER_COLUMNS: Final[list[str]] = [
|
| 8 |
+
"FULL_NAME",
|
| 9 |
+
"AST",
|
| 10 |
+
"BLK",
|
| 11 |
+
"DREB",
|
| 12 |
+
"FG3A",
|
| 13 |
+
"FG3M",
|
| 14 |
+
"FG3_PCT",
|
| 15 |
+
"FGA",
|
| 16 |
+
"FGM",
|
| 17 |
+
"FG_PCT",
|
| 18 |
+
"FTA",
|
| 19 |
+
"FTM",
|
| 20 |
+
"FT_PCT",
|
| 21 |
+
"GP",
|
| 22 |
+
"GS",
|
| 23 |
+
"MIN",
|
| 24 |
+
"OREB",
|
| 25 |
+
"PF",
|
| 26 |
+
"PTS",
|
| 27 |
+
"REB",
|
| 28 |
+
"STL",
|
| 29 |
+
"TOV",
|
| 30 |
+
"FIRST_NAME",
|
| 31 |
+
"LAST_NAME",
|
| 32 |
+
"FULL_NAME_LOWER",
|
| 33 |
+
"FIRST_NAME_LOWER",
|
| 34 |
+
"LAST_NAME_LOWER",
|
| 35 |
+
"IS_ACTIVE",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
# Columns used for ML model features
|
| 39 |
+
STAT_COLUMNS: Final[list[str]] = [
|
| 40 |
+
"PTS",
|
| 41 |
+
"OREB",
|
| 42 |
+
"DREB",
|
| 43 |
+
"AST",
|
| 44 |
+
"STL",
|
| 45 |
+
"BLK",
|
| 46 |
+
"TOV",
|
| 47 |
+
"FG3_PCT",
|
| 48 |
+
"FT_PCT",
|
| 49 |
+
"FGM",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
# Game configuration
|
| 53 |
+
TEAM_SIZE: Final[int] = 5
|
| 54 |
+
MAX_QUERY_ATTEMPTS: Final[int] = 10
|
| 55 |
+
|
| 56 |
+
# Difficulty presets: (PTS, REB, AST, STL)
|
| 57 |
+
DIFFICULTY_PRESETS: Final[dict[str, tuple[int, int, int, int]]] = {
|
| 58 |
+
"Regular": (850, 400, 200, 60),
|
| 59 |
+
"93' Bulls": (1050, 500, 300, 80),
|
| 60 |
+
"All-Stars": (1250, 600, 400, 100),
|
| 61 |
+
"Dream Team": (1450, 700, 500, 120),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# Score ranges for game simulation
|
| 65 |
+
WINNER_SCORE_RANGE: Final[tuple[int, int]] = (90, 130)
|
| 66 |
+
LOSER_SCORE_RANGE: Final[tuple[int, int]] = (80, 120)
|
| 67 |
+
|
| 68 |
+
# Default fallback scores when generation fails
|
| 69 |
+
DEFAULT_WINNER_SCORE: Final[int] = 100
|
| 70 |
+
DEFAULT_LOSER_SCORE: Final[int] = 90
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def setup_logging(level: int = logging.INFO) -> logging.Logger:
|
| 74 |
+
"""Configure and return the application logger.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
level: Logging level (default: INFO)
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Configured logger instance
|
| 81 |
+
"""
|
| 82 |
+
logging.basicConfig(
|
| 83 |
+
level=level,
|
| 84 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 85 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 86 |
+
)
|
| 87 |
+
logger = logging.getLogger("streamlit_nba")
|
| 88 |
+
logger.setLevel(level)
|
| 89 |
+
return logger
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Module-level logger instance
|
| 93 |
+
logger: Final[logging.Logger] = setup_logging()
|
src/database/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database module for connection management and queries."""
|
| 2 |
+
|
| 3 |
+
from src.database.connection import (
|
| 4 |
+
get_connection,
|
| 5 |
+
DatabaseConnectionError,
|
| 6 |
+
QueryExecutionError,
|
| 7 |
+
)
|
| 8 |
+
from src.database.queries import (
|
| 9 |
+
search_player_by_name,
|
| 10 |
+
get_player_by_full_name,
|
| 11 |
+
get_players_by_full_names,
|
| 12 |
+
get_away_team_by_stats,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"get_connection",
|
| 17 |
+
"DatabaseConnectionError",
|
| 18 |
+
"QueryExecutionError",
|
| 19 |
+
"search_player_by_name",
|
| 20 |
+
"get_player_by_full_name",
|
| 21 |
+
"get_players_by_full_names",
|
| 22 |
+
"get_away_team_by_stats",
|
| 23 |
+
]
|
src/database/connection.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database connection management with error handling."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from typing import Generator
|
| 6 |
+
|
| 7 |
+
import snowflake.connector
|
| 8 |
+
import streamlit as st
|
| 9 |
+
from snowflake.connector import SnowflakeConnection
|
| 10 |
+
from snowflake.connector.errors import DatabaseError, ProgrammingError
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("streamlit_nba")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DatabaseConnectionError(Exception):
|
| 16 |
+
"""Raised when database connection fails."""
|
| 17 |
+
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class QueryExecutionError(Exception):
|
| 22 |
+
"""Raised when query execution fails."""
|
| 23 |
+
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@st.cache_resource
|
| 28 |
+
def _get_connection_pool() -> SnowflakeConnection:
|
| 29 |
+
"""Create and cache a Snowflake connection.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Cached Snowflake connection
|
| 33 |
+
|
| 34 |
+
Raises:
|
| 35 |
+
DatabaseConnectionError: If connection cannot be established
|
| 36 |
+
"""
|
| 37 |
+
try:
|
| 38 |
+
return snowflake.connector.connect(**st.secrets["snowflake"])
|
| 39 |
+
except DatabaseError as e:
|
| 40 |
+
logger.error(f"Failed to connect to database: {e}")
|
| 41 |
+
raise DatabaseConnectionError(f"Could not connect to database: {e}") from e
|
| 42 |
+
except KeyError as e:
|
| 43 |
+
logger.error("Snowflake credentials not found in secrets")
|
| 44 |
+
raise DatabaseConnectionError(
|
| 45 |
+
"Database credentials not configured. Please check st.secrets."
|
| 46 |
+
) from e
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@contextmanager
|
| 50 |
+
def get_connection() -> Generator[SnowflakeConnection, None, None]:
|
| 51 |
+
"""Context manager for database connections with error handling.
|
| 52 |
+
|
| 53 |
+
Yields:
|
| 54 |
+
Active Snowflake connection
|
| 55 |
+
|
| 56 |
+
Raises:
|
| 57 |
+
DatabaseConnectionError: If connection fails
|
| 58 |
+
|
| 59 |
+
Example:
|
| 60 |
+
with get_connection() as conn:
|
| 61 |
+
# use connection
|
| 62 |
+
"""
|
| 63 |
+
try:
|
| 64 |
+
conn = snowflake.connector.connect(**st.secrets["snowflake"])
|
| 65 |
+
yield conn
|
| 66 |
+
except DatabaseError as e:
|
| 67 |
+
logger.error(f"Database connection error: {e}")
|
| 68 |
+
raise DatabaseConnectionError(f"Database connection failed: {e}") from e
|
| 69 |
+
except KeyError as e:
|
| 70 |
+
logger.error("Snowflake credentials not found in secrets")
|
| 71 |
+
raise DatabaseConnectionError(
|
| 72 |
+
"Database credentials not configured. Please check st.secrets."
|
| 73 |
+
) from e
|
| 74 |
+
finally:
|
| 75 |
+
try:
|
| 76 |
+
conn.close()
|
| 77 |
+
except Exception:
|
| 78 |
+
pass # Connection may already be closed
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def execute_query(
|
| 82 |
+
conn: SnowflakeConnection,
|
| 83 |
+
query: str,
|
| 84 |
+
params: tuple | list | None = None,
|
| 85 |
+
) -> list[tuple]:
|
| 86 |
+
"""Execute a parameterized query safely.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
conn: Active database connection
|
| 90 |
+
query: SQL query with %s placeholders
|
| 91 |
+
params: Query parameters (optional)
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
List of result tuples
|
| 95 |
+
|
| 96 |
+
Raises:
|
| 97 |
+
QueryExecutionError: If query execution fails
|
| 98 |
+
"""
|
| 99 |
+
try:
|
| 100 |
+
with conn.cursor() as cur:
|
| 101 |
+
if params:
|
| 102 |
+
cur.execute(query, params)
|
| 103 |
+
else:
|
| 104 |
+
cur.execute(query)
|
| 105 |
+
return cur.fetchall()
|
| 106 |
+
except ProgrammingError as e:
|
| 107 |
+
logger.error(f"Query execution error: {e}")
|
| 108 |
+
raise QueryExecutionError(f"Query failed: {e}") from e
|
| 109 |
+
except DatabaseError as e:
|
| 110 |
+
logger.error(f"Database error during query: {e}")
|
| 111 |
+
raise QueryExecutionError(f"Database error: {e}") from e
|
src/database/queries.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Parameterized database queries for player data."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from snowflake.connector import SnowflakeConnection
|
| 8 |
+
|
| 9 |
+
from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
|
| 10 |
+
from src.database.connection import QueryExecutionError, execute_query
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("streamlit_nba")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def search_player_by_name(conn: SnowflakeConnection, name: str) -> list[tuple[str]]:
|
| 16 |
+
"""Search for players by name (first, last, or full name).
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
conn: Active database connection
|
| 20 |
+
name: Search term (case-insensitive)
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
List of tuples containing matching full names
|
| 24 |
+
"""
|
| 25 |
+
name_lower = name.lower().strip()
|
| 26 |
+
query = """
|
| 27 |
+
SELECT full_name FROM NBA
|
| 28 |
+
WHERE full_name_lower = %s
|
| 29 |
+
OR first_name_lower = %s
|
| 30 |
+
OR last_name_lower = %s
|
| 31 |
+
"""
|
| 32 |
+
return execute_query(conn, query, (name_lower, name_lower, name_lower))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_player_by_full_name(
|
| 36 |
+
conn: SnowflakeConnection, full_name: str
|
| 37 |
+
) -> tuple[Any, ...] | None:
|
| 38 |
+
"""Get a single player's full record by exact name match.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
conn: Active database connection
|
| 42 |
+
full_name: Exact full name of player
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Player data tuple or None if not found
|
| 46 |
+
"""
|
| 47 |
+
query = "SELECT * FROM NBA WHERE FULL_NAME = %s"
|
| 48 |
+
results = execute_query(conn, query, (full_name,))
|
| 49 |
+
return results[0] if results else None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_players_by_full_names(
|
| 53 |
+
conn: SnowflakeConnection, names: list[str]
|
| 54 |
+
) -> pd.DataFrame:
|
| 55 |
+
"""Get multiple players' records in a single batch query.
|
| 56 |
+
|
| 57 |
+
This fixes the N+1 query problem by using a single IN clause
|
| 58 |
+
instead of multiple individual queries.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
conn: Active database connection
|
| 62 |
+
names: List of exact full names
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
DataFrame with player data
|
| 66 |
+
"""
|
| 67 |
+
if not names:
|
| 68 |
+
return pd.DataFrame(columns=PLAYER_COLUMNS)
|
| 69 |
+
|
| 70 |
+
# Build parameterized IN clause
|
| 71 |
+
placeholders = ", ".join(["%s"] * len(names))
|
| 72 |
+
query = f"SELECT * FROM NBA WHERE FULL_NAME IN ({placeholders})"
|
| 73 |
+
|
| 74 |
+
results = execute_query(conn, query, tuple(names))
|
| 75 |
+
return pd.DataFrame(results, columns=PLAYER_COLUMNS)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_away_team_by_stats(
|
| 79 |
+
conn: SnowflakeConnection,
|
| 80 |
+
pts_threshold: int,
|
| 81 |
+
reb_threshold: int,
|
| 82 |
+
ast_threshold: int,
|
| 83 |
+
stl_threshold: int,
|
| 84 |
+
max_attempts: int = MAX_QUERY_ATTEMPTS,
|
| 85 |
+
) -> pd.DataFrame:
|
| 86 |
+
"""Get a random away team based on stat thresholds.
|
| 87 |
+
|
| 88 |
+
Uses UNION with SAMPLE to get diverse players meeting stat criteria.
|
| 89 |
+
Includes a max_attempts guard to prevent infinite loops.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
conn: Active database connection
|
| 93 |
+
pts_threshold: Minimum career points
|
| 94 |
+
reb_threshold: Minimum career rebounds
|
| 95 |
+
ast_threshold: Minimum career assists
|
| 96 |
+
stl_threshold: Minimum career steals
|
| 97 |
+
max_attempts: Maximum query attempts before raising error
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
DataFrame with 5 players
|
| 101 |
+
|
| 102 |
+
Raises:
|
| 103 |
+
QueryExecutionError: If unable to get 5 players within max_attempts
|
| 104 |
+
"""
|
| 105 |
+
query = """
|
| 106 |
+
SELECT * FROM (SELECT * FROM NBA WHERE PTS > %s) SAMPLE (2 ROWS)
|
| 107 |
+
UNION
|
| 108 |
+
SELECT * FROM (SELECT * FROM NBA WHERE REB > %s) SAMPLE (1 ROWS)
|
| 109 |
+
UNION
|
| 110 |
+
SELECT * FROM (SELECT * FROM NBA WHERE AST > %s) SAMPLE (1 ROWS)
|
| 111 |
+
UNION
|
| 112 |
+
SELECT * FROM (SELECT * FROM NBA WHERE STL > %s) SAMPLE (1 ROWS)
|
| 113 |
+
"""
|
| 114 |
+
params = (pts_threshold, reb_threshold, ast_threshold, stl_threshold)
|
| 115 |
+
|
| 116 |
+
for attempt in range(max_attempts):
|
| 117 |
+
results = execute_query(conn, query, params)
|
| 118 |
+
if len(results) == 5:
|
| 119 |
+
logger.info(f"Got away team on attempt {attempt + 1}")
|
| 120 |
+
return pd.DataFrame(results, columns=PLAYER_COLUMNS)
|
| 121 |
+
logger.debug(f"Attempt {attempt + 1}: got {len(results)} players, need 5")
|
| 122 |
+
|
| 123 |
+
# Fallback: if we can't get exactly 5, raise an error
|
| 124 |
+
raise QueryExecutionError(
|
| 125 |
+
f"Could not generate away team with 5 players after {max_attempts} attempts. "
|
| 126 |
+
f"Last attempt returned {len(results)} players."
|
| 127 |
+
)
|
src/ml/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Machine learning module for game prediction."""
|
| 2 |
+
|
| 3 |
+
from src.ml.model import (
|
| 4 |
+
ModelLoadError,
|
| 5 |
+
analyze_team_stats,
|
| 6 |
+
get_winner_model,
|
| 7 |
+
predict_winner,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"ModelLoadError",
|
| 12 |
+
"analyze_team_stats",
|
| 13 |
+
"get_winner_model",
|
| 14 |
+
"predict_winner",
|
| 15 |
+
]
|
src/ml/model.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Machine learning model loading and prediction with caching."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from tensorflow.keras.models import Model, load_model
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("streamlit_nba")
|
| 11 |
+
|
| 12 |
+
# Default model path
|
| 13 |
+
DEFAULT_MODEL_PATH = Path("winner.keras")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ModelLoadError(Exception):
|
| 17 |
+
"""Raised when model loading fails."""
|
| 18 |
+
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@st.cache_resource
|
| 23 |
+
def get_winner_model(model_path: str | Path = DEFAULT_MODEL_PATH) -> Model:
|
| 24 |
+
"""Load and cache the winner prediction model.
|
| 25 |
+
|
| 26 |
+
Uses Streamlit's cache_resource to ensure model is only loaded once
|
| 27 |
+
per session, improving performance significantly.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
model_path: Path to the Keras model file
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Loaded Keras model
|
| 34 |
+
|
| 35 |
+
Raises:
|
| 36 |
+
ModelLoadError: If model cannot be loaded
|
| 37 |
+
"""
|
| 38 |
+
path = Path(model_path)
|
| 39 |
+
if not path.exists():
|
| 40 |
+
logger.error(f"Model file not found: {path}")
|
| 41 |
+
raise ModelLoadError(f"Model file not found: {path}")
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
logger.info(f"Loading model from {path}")
|
| 45 |
+
model = load_model(str(path))
|
| 46 |
+
logger.info("Model loaded successfully")
|
| 47 |
+
return model
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.error(f"Failed to load model: {e}")
|
| 50 |
+
raise ModelLoadError(f"Failed to load model: {e}") from e
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def predict_winner(combined_stats: np.ndarray) -> tuple[float, int]:
|
| 54 |
+
"""Predict game winner from combined team stats.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
combined_stats: Numpy array of shape (1, 100) containing
|
| 58 |
+
home team stats followed by away team stats
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Tuple of (probability, prediction) where:
|
| 62 |
+
- probability: Float between 0-1 (sigmoid output)
|
| 63 |
+
- prediction: 0 (away wins) or 1 (home wins)
|
| 64 |
+
|
| 65 |
+
Raises:
|
| 66 |
+
ModelLoadError: If model cannot be loaded
|
| 67 |
+
ValueError: If input shape is invalid
|
| 68 |
+
"""
|
| 69 |
+
if combined_stats.shape != (1, 100):
|
| 70 |
+
raise ValueError(
|
| 71 |
+
f"Expected input shape (1, 100), got {combined_stats.shape}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
model = get_winner_model()
|
| 75 |
+
sigmoid_output = model.predict(combined_stats, verbose=0)
|
| 76 |
+
probability = float(sigmoid_output[0][0])
|
| 77 |
+
prediction = int(np.round(probability))
|
| 78 |
+
|
| 79 |
+
logger.info(f"Prediction: probability={probability:.4f}, winner={prediction}")
|
| 80 |
+
return probability, prediction
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def analyze_team_stats(
|
| 84 |
+
home_stats: list[list[float]], away_stats: list[list[float]]
|
| 85 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 86 |
+
"""Prepare team stats for model prediction.
|
| 87 |
+
|
| 88 |
+
Flattens per-player stats into team-level arrays suitable for
|
| 89 |
+
the prediction model.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
home_stats: List of stat lists for each home player
|
| 93 |
+
away_stats: List of stat lists for each away player
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Tuple of (home_array, away_array, combined_array) where:
|
| 97 |
+
- home_array: Shape (1, 50) - home team flattened stats
|
| 98 |
+
- away_array: Shape (1, 50) - away team flattened stats
|
| 99 |
+
- combined_array: Shape (1, 100) - both teams for prediction
|
| 100 |
+
"""
|
| 101 |
+
home_flat: list[float] = []
|
| 102 |
+
away_flat: list[float] = []
|
| 103 |
+
|
| 104 |
+
for player_stats in home_stats:
|
| 105 |
+
home_flat.extend(player_stats)
|
| 106 |
+
|
| 107 |
+
for player_stats in away_stats:
|
| 108 |
+
away_flat.extend(player_stats)
|
| 109 |
+
|
| 110 |
+
home_array = np.array(home_flat).reshape(1, -1)
|
| 111 |
+
away_array = np.array(away_flat).reshape(1, -1)
|
| 112 |
+
combined_array = np.array(home_flat + away_flat).reshape(1, -1)
|
| 113 |
+
|
| 114 |
+
return home_array, away_array, combined_array
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for data validation."""
|
| 2 |
+
|
| 3 |
+
from src.models.player import PlayerStats, DifficultySettings
|
| 4 |
+
|
| 5 |
+
__all__ = ["PlayerStats", "DifficultySettings"]
|
src/models/player.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for player and game data."""
|
| 2 |
+
|
| 3 |
+
from typing import ClassVar
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field, field_validator
|
| 6 |
+
|
| 7 |
+
from src.config import DIFFICULTY_PRESETS
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PlayerStats(BaseModel):
|
| 11 |
+
"""Model representing a player's career statistics."""
|
| 12 |
+
|
| 13 |
+
full_name: str = Field(..., min_length=1, max_length=100)
|
| 14 |
+
ast: int = Field(..., ge=0, description="Career assists")
|
| 15 |
+
blk: int = Field(..., ge=0, description="Career blocks")
|
| 16 |
+
dreb: int = Field(..., ge=0, description="Career defensive rebounds")
|
| 17 |
+
fg3a: int = Field(..., ge=0, description="Career 3-point attempts")
|
| 18 |
+
fg3m: int = Field(..., ge=0, description="Career 3-pointers made")
|
| 19 |
+
fg3_pct: float = Field(..., ge=0.0, le=1.0, description="3-point percentage")
|
| 20 |
+
fga: int = Field(..., ge=0, description="Career field goal attempts")
|
| 21 |
+
fgm: int = Field(..., ge=0, description="Career field goals made")
|
| 22 |
+
fg_pct: float = Field(..., ge=0.0, le=1.0, description="Field goal percentage")
|
| 23 |
+
fta: int = Field(..., ge=0, description="Career free throw attempts")
|
| 24 |
+
ftm: int = Field(..., ge=0, description="Career free throws made")
|
| 25 |
+
ft_pct: float = Field(..., ge=0.0, le=1.0, description="Free throw percentage")
|
| 26 |
+
gp: int = Field(..., ge=0, description="Games played")
|
| 27 |
+
gs: int = Field(..., ge=0, description="Games started")
|
| 28 |
+
min: int = Field(..., ge=0, description="Career minutes")
|
| 29 |
+
oreb: int = Field(..., ge=0, description="Career offensive rebounds")
|
| 30 |
+
pf: int = Field(..., ge=0, description="Career personal fouls")
|
| 31 |
+
pts: int = Field(..., ge=0, description="Career points")
|
| 32 |
+
reb: int = Field(..., ge=0, description="Career rebounds")
|
| 33 |
+
stl: int = Field(..., ge=0, description="Career steals")
|
| 34 |
+
tov: int = Field(..., ge=0, description="Career turnovers")
|
| 35 |
+
first_name: str = Field(..., max_length=50)
|
| 36 |
+
last_name: str = Field(..., max_length=50)
|
| 37 |
+
full_name_lower: str = Field(..., max_length=100)
|
| 38 |
+
first_name_lower: str = Field(..., max_length=50)
|
| 39 |
+
last_name_lower: str = Field(..., max_length=50)
|
| 40 |
+
is_active: bool = Field(default=False)
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def from_db_row(cls, row: tuple) -> "PlayerStats":
|
| 44 |
+
"""Create PlayerStats from a database row tuple.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
row: Database row tuple in PLAYER_COLUMNS order
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
PlayerStats instance
|
| 51 |
+
"""
|
| 52 |
+
return cls(
|
| 53 |
+
full_name=row[0],
|
| 54 |
+
ast=row[1],
|
| 55 |
+
blk=row[2],
|
| 56 |
+
dreb=row[3],
|
| 57 |
+
fg3a=row[4],
|
| 58 |
+
fg3m=row[5],
|
| 59 |
+
fg3_pct=row[6] or 0.0,
|
| 60 |
+
fga=row[7],
|
| 61 |
+
fgm=row[8],
|
| 62 |
+
fg_pct=row[9] or 0.0,
|
| 63 |
+
fta=row[10],
|
| 64 |
+
ftm=row[11],
|
| 65 |
+
ft_pct=row[12] or 0.0,
|
| 66 |
+
gp=row[13],
|
| 67 |
+
gs=row[14],
|
| 68 |
+
min=row[15],
|
| 69 |
+
oreb=row[16],
|
| 70 |
+
pf=row[17],
|
| 71 |
+
pts=row[18],
|
| 72 |
+
reb=row[19],
|
| 73 |
+
stl=row[20],
|
| 74 |
+
tov=row[21],
|
| 75 |
+
first_name=row[22],
|
| 76 |
+
last_name=row[23],
|
| 77 |
+
full_name_lower=row[24],
|
| 78 |
+
first_name_lower=row[25],
|
| 79 |
+
last_name_lower=row[26],
|
| 80 |
+
is_active=bool(row[27]) if row[27] is not None else False,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class DifficultySettings(BaseModel):
|
| 85 |
+
"""Model for game difficulty settings."""
|
| 86 |
+
|
| 87 |
+
VALID_PRESETS: ClassVar[set[str]] = set(DIFFICULTY_PRESETS.keys())
|
| 88 |
+
|
| 89 |
+
name: str = Field(..., min_length=1)
|
| 90 |
+
pts_threshold: int = Field(..., ge=0, description="Minimum career points")
|
| 91 |
+
reb_threshold: int = Field(..., ge=0, description="Minimum career rebounds")
|
| 92 |
+
ast_threshold: int = Field(..., ge=0, description="Minimum career assists")
|
| 93 |
+
stl_threshold: int = Field(..., ge=0, description="Minimum career steals")
|
| 94 |
+
|
| 95 |
+
@field_validator("name")
|
| 96 |
+
@classmethod
|
| 97 |
+
def validate_preset_name(cls, v: str) -> str:
|
| 98 |
+
"""Validate that preset name is recognized."""
|
| 99 |
+
if v not in cls.VALID_PRESETS:
|
| 100 |
+
raise ValueError(
|
| 101 |
+
f"Unknown difficulty preset: {v}. "
|
| 102 |
+
f"Valid options: {', '.join(sorted(cls.VALID_PRESETS))}"
|
| 103 |
+
)
|
| 104 |
+
return v
|
| 105 |
+
|
| 106 |
+
@classmethod
|
| 107 |
+
def from_preset(cls, preset_name: str) -> "DifficultySettings":
|
| 108 |
+
"""Create DifficultySettings from a named preset.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
preset_name: Name of difficulty preset (e.g., "Regular", "Dream Team")
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
DifficultySettings instance
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
ValueError: If preset_name is not valid
|
| 118 |
+
"""
|
| 119 |
+
if preset_name not in DIFFICULTY_PRESETS:
|
| 120 |
+
raise ValueError(
|
| 121 |
+
f"Unknown difficulty preset: {preset_name}. "
|
| 122 |
+
f"Valid options: {', '.join(sorted(DIFFICULTY_PRESETS.keys()))}"
|
| 123 |
+
)
|
| 124 |
+
pts, reb, ast, stl = DIFFICULTY_PRESETS[preset_name]
|
| 125 |
+
return cls(
|
| 126 |
+
name=preset_name,
|
| 127 |
+
pts_threshold=pts,
|
| 128 |
+
reb_threshold=reb,
|
| 129 |
+
ast_threshold=ast,
|
| 130 |
+
stl_threshold=stl,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def as_tuple(self) -> tuple[int, int, int, int]:
|
| 134 |
+
"""Return thresholds as tuple for backward compatibility.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Tuple of (pts, reb, ast, stl) thresholds
|
| 138 |
+
"""
|
| 139 |
+
return (
|
| 140 |
+
self.pts_threshold,
|
| 141 |
+
self.reb_threshold,
|
| 142 |
+
self.ast_threshold,
|
| 143 |
+
self.stl_threshold,
|
| 144 |
+
)
|
src/state/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session state management module."""
|
| 2 |
+
|
| 3 |
+
from src.state.session import GameState, init_session_state, get_away_stats
|
| 4 |
+
|
| 5 |
+
__all__ = ["GameState", "init_session_state", "get_away_stats"]
|
src/state/session.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session state management for the Streamlit application."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
from src.config import DIFFICULTY_PRESETS
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("streamlit_nba")
|
| 12 |
+
|
| 13 |
+
# Default difficulty preset
|
| 14 |
+
DEFAULT_DIFFICULTY = "Regular"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class GameState:
|
| 19 |
+
"""Dataclass representing the game session state."""
|
| 20 |
+
|
| 21 |
+
home_team: list[str] = field(default_factory=list)
|
| 22 |
+
away_team: list[str] = field(default_factory=list)
|
| 23 |
+
away_stats: list[int] = field(
|
| 24 |
+
default_factory=lambda: list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY])
|
| 25 |
+
)
|
| 26 |
+
home_team_df: pd.DataFrame = field(default_factory=pd.DataFrame)
|
| 27 |
+
radio_index: int = 0
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def init_session_state() -> None:
|
| 31 |
+
"""Initialize all session state keys with safe defaults.
|
| 32 |
+
|
| 33 |
+
This should be called at the start of each page to ensure
|
| 34 |
+
all required state keys exist before access.
|
| 35 |
+
"""
|
| 36 |
+
defaults = {
|
| 37 |
+
"home_team": [],
|
| 38 |
+
"away_team": [],
|
| 39 |
+
"away_stats": list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY]),
|
| 40 |
+
"home_team_df": pd.DataFrame(),
|
| 41 |
+
"radio_index": 0,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
for key, default_value in defaults.items():
|
| 45 |
+
if key not in st.session_state:
|
| 46 |
+
st.session_state[key] = default_value
|
| 47 |
+
logger.debug(f"Initialized session state: {key}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_away_stats() -> list[int]:
|
| 51 |
+
"""Safely get away team stats from session state.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
List of stat thresholds [pts, reb, ast, stl], or defaults if not set
|
| 55 |
+
"""
|
| 56 |
+
init_session_state() # Ensure state is initialized
|
| 57 |
+
stats = st.session_state.get("away_stats")
|
| 58 |
+
|
| 59 |
+
if stats is None or not isinstance(stats, list) or len(stats) != 4:
|
| 60 |
+
logger.warning("Invalid away_stats in session, using defaults")
|
| 61 |
+
default_stats = list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY])
|
| 62 |
+
st.session_state["away_stats"] = default_stats
|
| 63 |
+
return default_stats
|
| 64 |
+
|
| 65 |
+
return stats
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_home_team_df() -> pd.DataFrame:
|
| 69 |
+
"""Safely get home team DataFrame from session state.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
DataFrame with home team player data, or empty DataFrame if not set
|
| 73 |
+
"""
|
| 74 |
+
init_session_state()
|
| 75 |
+
df = st.session_state.get("home_team_df")
|
| 76 |
+
|
| 77 |
+
if df is None or not isinstance(df, pd.DataFrame):
|
| 78 |
+
logger.warning("Invalid home_team_df in session, using empty DataFrame")
|
| 79 |
+
return pd.DataFrame()
|
| 80 |
+
|
| 81 |
+
return df
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_home_team_names() -> list[str]:
|
| 85 |
+
"""Safely get home team player names from session state.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
List of player names on home team
|
| 89 |
+
"""
|
| 90 |
+
init_session_state()
|
| 91 |
+
team = st.session_state.get("home_team")
|
| 92 |
+
|
| 93 |
+
if team is None or not isinstance(team, list):
|
| 94 |
+
return []
|
| 95 |
+
|
| 96 |
+
return team
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def set_difficulty(preset_name: str) -> None:
|
| 100 |
+
"""Set the difficulty level by preset name.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
preset_name: Name of difficulty preset
|
| 104 |
+
"""
|
| 105 |
+
if preset_name not in DIFFICULTY_PRESETS:
|
| 106 |
+
logger.error(f"Invalid difficulty preset: {preset_name}")
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
index = list(DIFFICULTY_PRESETS.keys()).index(preset_name)
|
| 110 |
+
st.session_state["away_stats"] = list(DIFFICULTY_PRESETS[preset_name])
|
| 111 |
+
st.session_state["radio_index"] = index
|
| 112 |
+
logger.info(f"Set difficulty to {preset_name}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def add_player_to_team(player_name: str) -> bool:
|
| 116 |
+
"""Add a player to the home team.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
player_name: Full name of player to add
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
True if added, False if already on team or team is full
|
| 123 |
+
"""
|
| 124 |
+
init_session_state()
|
| 125 |
+
team = st.session_state.get("home_team", [])
|
| 126 |
+
|
| 127 |
+
if len(team) >= 5:
|
| 128 |
+
logger.warning("Cannot add player: team is full")
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
if player_name in team:
|
| 132 |
+
logger.debug(f"Player {player_name} already on team")
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
team.append(player_name)
|
| 136 |
+
st.session_state["home_team"] = team
|
| 137 |
+
logger.info(f"Added {player_name} to team")
|
| 138 |
+
return True
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def remove_player_from_team(player_name: str) -> bool:
|
| 142 |
+
"""Remove a player from the home team.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
player_name: Full name of player to remove
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
True if removed, False if not on team
|
| 149 |
+
"""
|
| 150 |
+
init_session_state()
|
| 151 |
+
team = st.session_state.get("home_team", [])
|
| 152 |
+
|
| 153 |
+
if player_name not in team:
|
| 154 |
+
logger.debug(f"Player {player_name} not on team")
|
| 155 |
+
return False
|
| 156 |
+
|
| 157 |
+
team.remove(player_name)
|
| 158 |
+
st.session_state["home_team"] = team
|
| 159 |
+
logger.info(f"Removed {player_name} from team")
|
| 160 |
+
return True
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for HTML sanitization and other helpers."""
|
| 2 |
+
|
| 3 |
+
from src.utils.html import safe_heading, safe_paragraph
|
| 4 |
+
|
| 5 |
+
__all__ = ["safe_heading", "safe_paragraph"]
|
src/utils/html.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HTML sanitization utilities for XSS protection."""
|
| 2 |
+
|
| 3 |
+
import html
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
import streamlit as st
|
| 7 |
+
|
| 8 |
+
# Valid heading levels
|
| 9 |
+
HeadingLevel = Literal[1, 2, 3, 4, 5, 6]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def escape_html(text: str) -> str:
|
| 13 |
+
"""Escape HTML special characters to prevent XSS.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
text: Raw text that may contain HTML
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Escaped text safe for HTML insertion
|
| 20 |
+
"""
|
| 21 |
+
return html.escape(str(text))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def safe_heading(
|
| 25 |
+
text: str,
|
| 26 |
+
level: HeadingLevel = 1,
|
| 27 |
+
color: str = "steelblue",
|
| 28 |
+
align: str = "center",
|
| 29 |
+
) -> None:
|
| 30 |
+
"""Render a heading with escaped text to prevent XSS.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
text: Heading text (will be escaped)
|
| 34 |
+
level: Heading level 1-6
|
| 35 |
+
color: CSS color value
|
| 36 |
+
align: CSS text-align value
|
| 37 |
+
"""
|
| 38 |
+
# Escape all user-provided values
|
| 39 |
+
safe_text = escape_html(text)
|
| 40 |
+
safe_color = escape_html(color)
|
| 41 |
+
safe_align = escape_html(align)
|
| 42 |
+
|
| 43 |
+
st.markdown(
|
| 44 |
+
f"<h{level} style='text-align: {safe_align}; color: {safe_color};'>"
|
| 45 |
+
f"{safe_text}</h{level}>",
|
| 46 |
+
unsafe_allow_html=True,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def safe_paragraph(
|
| 51 |
+
text: str,
|
| 52 |
+
color: str = "white",
|
| 53 |
+
align: str = "center",
|
| 54 |
+
) -> None:
|
| 55 |
+
"""Render a paragraph with escaped text to prevent XSS.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
text: Paragraph text (will be escaped)
|
| 59 |
+
color: CSS color value
|
| 60 |
+
align: CSS text-align value
|
| 61 |
+
"""
|
| 62 |
+
safe_text = escape_html(text)
|
| 63 |
+
safe_color = escape_html(color)
|
| 64 |
+
safe_align = escape_html(align)
|
| 65 |
+
|
| 66 |
+
st.markdown(
|
| 67 |
+
f"<p style='text-align: {safe_align}; color: {safe_color};'>"
|
| 68 |
+
f"{safe_text}</p>",
|
| 69 |
+
unsafe_allow_html=True,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def safe_styled_text(
|
| 74 |
+
text: str,
|
| 75 |
+
tag: str = "span",
|
| 76 |
+
color: str | None = None,
|
| 77 |
+
align: str | None = None,
|
| 78 |
+
**styles: str,
|
| 79 |
+
) -> str:
|
| 80 |
+
"""Generate HTML string with escaped text and validated styles.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
text: Text content (will be escaped)
|
| 84 |
+
tag: HTML tag to use
|
| 85 |
+
color: Optional CSS color
|
| 86 |
+
align: Optional CSS text-align
|
| 87 |
+
**styles: Additional CSS properties
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
Safe HTML string
|
| 91 |
+
"""
|
| 92 |
+
safe_text = escape_html(text)
|
| 93 |
+
safe_tag = escape_html(tag)
|
| 94 |
+
|
| 95 |
+
style_parts: list[str] = []
|
| 96 |
+
if color:
|
| 97 |
+
style_parts.append(f"color: {escape_html(color)}")
|
| 98 |
+
if align:
|
| 99 |
+
style_parts.append(f"text-align: {escape_html(align)}")
|
| 100 |
+
for prop, value in styles.items():
|
| 101 |
+
# Convert underscores to hyphens for CSS properties
|
| 102 |
+
css_prop = prop.replace("_", "-")
|
| 103 |
+
style_parts.append(f"{escape_html(css_prop)}: {escape_html(value)}")
|
| 104 |
+
|
| 105 |
+
style_str = "; ".join(style_parts)
|
| 106 |
+
if style_str:
|
| 107 |
+
return f"<{safe_tag} style='{style_str}'>{safe_text}</{safe_tag}>"
|
| 108 |
+
return f"<{safe_tag}>{safe_text}</{safe_tag}>"
|
src/validation/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Input validation module."""
|
| 2 |
+
|
| 3 |
+
from src.validation.inputs import (
|
| 4 |
+
PlayerSearchInput,
|
| 5 |
+
is_valid_search_term,
|
| 6 |
+
validate_search_term,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = ["PlayerSearchInput", "is_valid_search_term", "validate_search_term"]
|
src/validation/inputs.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Input validation for user-provided data."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field, field_validator
|
| 6 |
+
|
| 7 |
+
# Patterns that indicate SQL injection attempts
|
| 8 |
+
SQL_INJECTION_PATTERNS: list[str] = [
|
| 9 |
+
r"['\";]", # Quote characters and semicolons
|
| 10 |
+
r"--", # SQL comment
|
| 11 |
+
r"/\*", # Block comment start
|
| 12 |
+
r"\*/", # Block comment end
|
| 13 |
+
r"\bUNION\b", # UNION keyword
|
| 14 |
+
r"\bSELECT\b", # SELECT keyword
|
| 15 |
+
r"\bINSERT\b", # INSERT keyword
|
| 16 |
+
r"\bUPDATE\b", # UPDATE keyword
|
| 17 |
+
r"\bDELETE\b", # DELETE keyword
|
| 18 |
+
r"\bDROP\b", # DROP keyword
|
| 19 |
+
r"\bEXEC\b", # EXEC keyword
|
| 20 |
+
r"\bOR\s+\d+=\d+", # OR 1=1 pattern
|
| 21 |
+
r"\bAND\s+\d+=\d+", # AND 1=1 pattern
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
# Compiled regex for efficiency
|
| 25 |
+
SQL_INJECTION_REGEX = re.compile(
|
| 26 |
+
"|".join(SQL_INJECTION_PATTERNS), re.IGNORECASE
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PlayerSearchInput(BaseModel):
|
| 31 |
+
"""Validated player search input."""
|
| 32 |
+
|
| 33 |
+
search_term: str = Field(
|
| 34 |
+
...,
|
| 35 |
+
min_length=1,
|
| 36 |
+
max_length=100,
|
| 37 |
+
description="Player name search term",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@field_validator("search_term")
|
| 41 |
+
@classmethod
|
| 42 |
+
def validate_no_sql_injection(cls, v: str) -> str:
|
| 43 |
+
"""Reject inputs containing SQL injection patterns.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
v: Input search term
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Validated search term
|
| 50 |
+
|
| 51 |
+
Raises:
|
| 52 |
+
ValueError: If SQL injection pattern detected
|
| 53 |
+
"""
|
| 54 |
+
if SQL_INJECTION_REGEX.search(v):
|
| 55 |
+
raise ValueError(
|
| 56 |
+
"Invalid characters in search term. "
|
| 57 |
+
"Please use only letters, numbers, spaces, and hyphens."
|
| 58 |
+
)
|
| 59 |
+
return v.strip()
|
| 60 |
+
|
| 61 |
+
@field_validator("search_term")
|
| 62 |
+
@classmethod
|
| 63 |
+
def validate_reasonable_characters(cls, v: str) -> str:
|
| 64 |
+
"""Ensure search term contains only reasonable characters.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
v: Input search term
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Validated search term
|
| 71 |
+
|
| 72 |
+
Raises:
|
| 73 |
+
ValueError: If invalid characters found
|
| 74 |
+
"""
|
| 75 |
+
# Allow letters, numbers, spaces, hyphens, periods, and apostrophes
|
| 76 |
+
# (e.g., "O'Neal", "J.R. Smith")
|
| 77 |
+
if not re.match(r"^[a-zA-Z0-9\s\-.']+$", v):
|
| 78 |
+
raise ValueError(
|
| 79 |
+
"Search term contains invalid characters. "
|
| 80 |
+
"Please use only letters, numbers, spaces, hyphens, "
|
| 81 |
+
"periods, and apostrophes."
|
| 82 |
+
)
|
| 83 |
+
return v
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def validate_search_term(term: str) -> str | None:
|
| 87 |
+
"""Validate a player search term.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
term: Raw search input
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Validated and cleaned search term, or None if invalid
|
| 94 |
+
"""
|
| 95 |
+
try:
|
| 96 |
+
validated = PlayerSearchInput(search_term=term)
|
| 97 |
+
return validated.search_term
|
| 98 |
+
except ValueError:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def is_valid_search_term(term: str) -> bool:
|
| 103 |
+
"""Check if a search term is valid without raising exceptions.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
term: Raw search input
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
True if valid, False otherwise
|
| 110 |
+
"""
|
| 111 |
+
return validate_search_term(term) is not None
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Test suite for NBA Streamlit application."""
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest fixtures for NBA Streamlit application tests."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
from unittest.mock import MagicMock
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@pytest.fixture
|
| 11 |
+
def mock_snowflake_connection() -> MagicMock:
|
| 12 |
+
"""Create a mock Snowflake connection.
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
Mock connection object that simulates Snowflake connection behavior
|
| 16 |
+
"""
|
| 17 |
+
mock_conn = MagicMock()
|
| 18 |
+
mock_cursor = MagicMock()
|
| 19 |
+
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
|
| 20 |
+
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
| 21 |
+
return mock_conn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def sample_player_data() -> list[tuple[Any, ...]]:
|
| 26 |
+
"""Create sample player data matching database schema.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
List of tuples with sample player data
|
| 30 |
+
"""
|
| 31 |
+
return [
|
| 32 |
+
(
|
| 33 |
+
"LeBron James", # FULL_NAME
|
| 34 |
+
10141, # AST
|
| 35 |
+
1107, # BLK
|
| 36 |
+
5972, # DREB
|
| 37 |
+
2891, # FG3A
|
| 38 |
+
1043, # FG3M
|
| 39 |
+
0.361, # FG3_PCT
|
| 40 |
+
24856, # FGA
|
| 41 |
+
12621, # FGM
|
| 42 |
+
0.508, # FG_PCT
|
| 43 |
+
11067, # FTA
|
| 44 |
+
7938, # FTM
|
| 45 |
+
0.717, # FT_PCT
|
| 46 |
+
1421, # GP
|
| 47 |
+
1421, # GS
|
| 48 |
+
54218, # MIN
|
| 49 |
+
1663, # OREB
|
| 50 |
+
2159, # PF
|
| 51 |
+
39223, # PTS
|
| 52 |
+
10988, # REB
|
| 53 |
+
2219, # STL
|
| 54 |
+
5015, # TOV
|
| 55 |
+
"LeBron", # FIRST_NAME
|
| 56 |
+
"James", # LAST_NAME
|
| 57 |
+
"lebron james", # FULL_NAME_LOWER
|
| 58 |
+
"lebron", # FIRST_NAME_LOWER
|
| 59 |
+
"james", # LAST_NAME_LOWER
|
| 60 |
+
True, # IS_ACTIVE
|
| 61 |
+
),
|
| 62 |
+
(
|
| 63 |
+
"Michael Jordan",
|
| 64 |
+
5633,
|
| 65 |
+
893,
|
| 66 |
+
4578,
|
| 67 |
+
1778,
|
| 68 |
+
581,
|
| 69 |
+
0.327,
|
| 70 |
+
24537,
|
| 71 |
+
12192,
|
| 72 |
+
0.497,
|
| 73 |
+
8772,
|
| 74 |
+
7327,
|
| 75 |
+
0.835,
|
| 76 |
+
1072,
|
| 77 |
+
1039,
|
| 78 |
+
41011,
|
| 79 |
+
1463,
|
| 80 |
+
2783,
|
| 81 |
+
32292,
|
| 82 |
+
6672,
|
| 83 |
+
2514,
|
| 84 |
+
2924,
|
| 85 |
+
"Michael",
|
| 86 |
+
"Jordan",
|
| 87 |
+
"michael jordan",
|
| 88 |
+
"michael",
|
| 89 |
+
"jordan",
|
| 90 |
+
False,
|
| 91 |
+
),
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@pytest.fixture
|
| 96 |
+
def sample_player_df(sample_player_data: list[tuple]) -> pd.DataFrame:
|
| 97 |
+
"""Create sample player DataFrame.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
sample_player_data: List of player tuples
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
DataFrame with sample player data
|
| 104 |
+
"""
|
| 105 |
+
from src.config import PLAYER_COLUMNS
|
| 106 |
+
|
| 107 |
+
return pd.DataFrame(sample_player_data, columns=PLAYER_COLUMNS)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@pytest.fixture
|
| 111 |
+
def sample_team_stats() -> list[list[float]]:
|
| 112 |
+
"""Create sample team stats for ML model input.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
List of player stat lists (5 players x 10 stats)
|
| 116 |
+
"""
|
| 117 |
+
return [
|
| 118 |
+
[1500.0, 100.0, 200.0, 300.0, 50.0, 30.0, 100.0, 0.35, 0.80, 500.0],
|
| 119 |
+
[1200.0, 80.0, 180.0, 250.0, 40.0, 25.0, 90.0, 0.38, 0.75, 450.0],
|
| 120 |
+
[1000.0, 60.0, 150.0, 200.0, 35.0, 20.0, 80.0, 0.40, 0.82, 400.0],
|
| 121 |
+
[800.0, 50.0, 120.0, 150.0, 30.0, 15.0, 70.0, 0.33, 0.78, 350.0],
|
| 122 |
+
[600.0, 40.0, 100.0, 100.0, 25.0, 10.0, 60.0, 0.36, 0.85, 300.0],
|
| 123 |
+
]
|
tests/test_database.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for database module."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
|
| 9 |
+
from src.database.connection import QueryExecutionError
|
| 10 |
+
from src.database.queries import (
|
| 11 |
+
get_away_team_by_stats,
|
| 12 |
+
get_player_by_full_name,
|
| 13 |
+
get_players_by_full_names,
|
| 14 |
+
search_player_by_name,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestSearchPlayerByName:
|
| 19 |
+
"""Tests for search_player_by_name function."""
|
| 20 |
+
|
| 21 |
+
def test_uses_parameterized_query(
|
| 22 |
+
self, mock_snowflake_connection: MagicMock
|
| 23 |
+
) -> None:
|
| 24 |
+
"""Verify parameterized queries are used (not string formatting)."""
|
| 25 |
+
mock_cursor = MagicMock()
|
| 26 |
+
mock_cursor.fetchall.return_value = [("LeBron James",)]
|
| 27 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 28 |
+
mock_cursor
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
search_player_by_name(mock_snowflake_connection, "james")
|
| 32 |
+
|
| 33 |
+
# Verify execute was called with params tuple, not string formatting
|
| 34 |
+
mock_cursor.execute.assert_called_once()
|
| 35 |
+
call_args = mock_cursor.execute.call_args
|
| 36 |
+
query = call_args[0][0]
|
| 37 |
+
params = call_args[0][1]
|
| 38 |
+
|
| 39 |
+
# Query should use %s placeholders
|
| 40 |
+
assert "%s" in query
|
| 41 |
+
# Should not contain the actual search term in the query string
|
| 42 |
+
assert "james" not in query.lower()
|
| 43 |
+
# Params should be a tuple with the search term
|
| 44 |
+
assert params == ("james", "james", "james")
|
| 45 |
+
|
| 46 |
+
def test_returns_list_of_tuples(
|
| 47 |
+
self, mock_snowflake_connection: MagicMock
|
| 48 |
+
) -> None:
|
| 49 |
+
"""Test that results are returned as list of tuples."""
|
| 50 |
+
mock_cursor = MagicMock()
|
| 51 |
+
mock_cursor.fetchall.return_value = [
|
| 52 |
+
("LeBron James",),
|
| 53 |
+
("James Harden",),
|
| 54 |
+
]
|
| 55 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 56 |
+
mock_cursor
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
result = search_player_by_name(mock_snowflake_connection, "james")
|
| 60 |
+
|
| 61 |
+
assert result == [("LeBron James",), ("James Harden",)]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TestGetPlayersByFullNames:
|
| 65 |
+
"""Tests for get_players_by_full_names batch query."""
|
| 66 |
+
|
| 67 |
+
def test_single_query_for_multiple_names(
|
| 68 |
+
self, mock_snowflake_connection: MagicMock, sample_player_data: list
|
| 69 |
+
) -> None:
|
| 70 |
+
"""Verify batch query uses single IN clause instead of N queries."""
|
| 71 |
+
mock_cursor = MagicMock()
|
| 72 |
+
mock_cursor.fetchall.return_value = sample_player_data
|
| 73 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 74 |
+
mock_cursor
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
names = ["LeBron James", "Michael Jordan"]
|
| 78 |
+
get_players_by_full_names(mock_snowflake_connection, names)
|
| 79 |
+
|
| 80 |
+
# Should only execute one query
|
| 81 |
+
assert mock_cursor.execute.call_count == 1
|
| 82 |
+
|
| 83 |
+
call_args = mock_cursor.execute.call_args
|
| 84 |
+
query = call_args[0][0]
|
| 85 |
+
params = call_args[0][1]
|
| 86 |
+
|
| 87 |
+
# Query should have IN clause with placeholders
|
| 88 |
+
assert "IN" in query.upper()
|
| 89 |
+
assert "%s" in query
|
| 90 |
+
# Params should be tuple of names
|
| 91 |
+
assert params == ("LeBron James", "Michael Jordan")
|
| 92 |
+
|
| 93 |
+
def test_returns_dataframe(
|
| 94 |
+
self, mock_snowflake_connection: MagicMock, sample_player_data: list
|
| 95 |
+
) -> None:
|
| 96 |
+
"""Test that results are returned as DataFrame."""
|
| 97 |
+
mock_cursor = MagicMock()
|
| 98 |
+
mock_cursor.fetchall.return_value = sample_player_data
|
| 99 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 100 |
+
mock_cursor
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
result = get_players_by_full_names(
|
| 104 |
+
mock_snowflake_connection, ["LeBron James", "Michael Jordan"]
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
assert isinstance(result, pd.DataFrame)
|
| 108 |
+
assert list(result.columns) == PLAYER_COLUMNS
|
| 109 |
+
assert len(result) == 2
|
| 110 |
+
|
| 111 |
+
def test_empty_names_returns_empty_dataframe(
|
| 112 |
+
self, mock_snowflake_connection: MagicMock
|
| 113 |
+
) -> None:
|
| 114 |
+
"""Test that empty input returns empty DataFrame without query."""
|
| 115 |
+
mock_cursor = MagicMock()
|
| 116 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 117 |
+
mock_cursor
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
result = get_players_by_full_names(mock_snowflake_connection, [])
|
| 121 |
+
|
| 122 |
+
assert isinstance(result, pd.DataFrame)
|
| 123 |
+
assert result.empty
|
| 124 |
+
# Should not execute any query
|
| 125 |
+
mock_cursor.execute.assert_not_called()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TestGetAwayTeamByStats:
|
| 129 |
+
"""Tests for get_away_team_by_stats with max_attempts guard."""
|
| 130 |
+
|
| 131 |
+
def test_max_attempts_raises_error(
|
| 132 |
+
self, mock_snowflake_connection: MagicMock
|
| 133 |
+
) -> None:
|
| 134 |
+
"""Test that max_attempts limit prevents infinite loop."""
|
| 135 |
+
mock_cursor = MagicMock()
|
| 136 |
+
# Always return wrong number of players
|
| 137 |
+
mock_cursor.fetchall.return_value = [("Player1",), ("Player2",)]
|
| 138 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 139 |
+
mock_cursor
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
with pytest.raises(QueryExecutionError) as exc_info:
|
| 143 |
+
get_away_team_by_stats(
|
| 144 |
+
mock_snowflake_connection,
|
| 145 |
+
pts_threshold=1000,
|
| 146 |
+
reb_threshold=500,
|
| 147 |
+
ast_threshold=300,
|
| 148 |
+
stl_threshold=100,
|
| 149 |
+
max_attempts=3,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
assert "3 attempts" in str(exc_info.value)
|
| 153 |
+
assert mock_cursor.execute.call_count == 3
|
| 154 |
+
|
| 155 |
+
def test_success_on_first_try(
|
| 156 |
+
self, mock_snowflake_connection: MagicMock, sample_player_data: list
|
| 157 |
+
) -> None:
|
| 158 |
+
"""Test successful query on first attempt."""
|
| 159 |
+
mock_cursor = MagicMock()
|
| 160 |
+
# Return exactly 5 players
|
| 161 |
+
mock_cursor.fetchall.return_value = sample_player_data * 3 # 6 players
|
| 162 |
+
mock_cursor.fetchall.return_value = [
|
| 163 |
+
sample_player_data[0],
|
| 164 |
+
sample_player_data[1],
|
| 165 |
+
sample_player_data[0],
|
| 166 |
+
sample_player_data[1],
|
| 167 |
+
sample_player_data[0],
|
| 168 |
+
]
|
| 169 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 170 |
+
mock_cursor
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
result = get_away_team_by_stats(
|
| 174 |
+
mock_snowflake_connection,
|
| 175 |
+
pts_threshold=1000,
|
| 176 |
+
reb_threshold=500,
|
| 177 |
+
ast_threshold=300,
|
| 178 |
+
stl_threshold=100,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
assert isinstance(result, pd.DataFrame)
|
| 182 |
+
assert len(result) == 5
|
| 183 |
+
# Should only need one query
|
| 184 |
+
assert mock_cursor.execute.call_count == 1
|
| 185 |
+
|
| 186 |
+
def test_uses_parameterized_query(
|
| 187 |
+
self, mock_snowflake_connection: MagicMock, sample_player_data: list
|
| 188 |
+
) -> None:
|
| 189 |
+
"""Verify parameterized queries are used for stat thresholds."""
|
| 190 |
+
mock_cursor = MagicMock()
|
| 191 |
+
mock_cursor.fetchall.return_value = [
|
| 192 |
+
sample_player_data[0],
|
| 193 |
+
sample_player_data[1],
|
| 194 |
+
sample_player_data[0],
|
| 195 |
+
sample_player_data[1],
|
| 196 |
+
sample_player_data[0],
|
| 197 |
+
]
|
| 198 |
+
mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
|
| 199 |
+
mock_cursor
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
get_away_team_by_stats(
|
| 203 |
+
mock_snowflake_connection,
|
| 204 |
+
pts_threshold=1000,
|
| 205 |
+
reb_threshold=500,
|
| 206 |
+
ast_threshold=300,
|
| 207 |
+
stl_threshold=100,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
call_args = mock_cursor.execute.call_args
|
| 211 |
+
query = call_args[0][0]
|
| 212 |
+
params = call_args[0][1]
|
| 213 |
+
|
| 214 |
+
# Query should use %s placeholders
|
| 215 |
+
assert "%s" in query
|
| 216 |
+
# Should not contain actual numbers in query
|
| 217 |
+
assert "1000" not in query
|
| 218 |
+
assert "500" not in query
|
| 219 |
+
# Params should be tuple of thresholds
|
| 220 |
+
assert params == (1000, 500, 300, 100)
|
tests/test_ml.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ML model module."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.ml.model import ModelLoadError, analyze_team_stats, predict_winner
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestAnalyzeTeamStats:
|
| 12 |
+
"""Tests for analyze_team_stats function."""
|
| 13 |
+
|
| 14 |
+
def test_flattens_stats_correctly(
|
| 15 |
+
self, sample_team_stats: list[list[float]]
|
| 16 |
+
) -> None:
|
| 17 |
+
"""Test that team stats are flattened correctly."""
|
| 18 |
+
home_array, away_array, combined = analyze_team_stats(
|
| 19 |
+
sample_team_stats, sample_team_stats
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Each team has 5 players x 10 stats = 50 values
|
| 23 |
+
assert home_array.shape == (1, 50)
|
| 24 |
+
assert away_array.shape == (1, 50)
|
| 25 |
+
# Combined has both teams = 100 values
|
| 26 |
+
assert combined.shape == (1, 100)
|
| 27 |
+
|
| 28 |
+
def test_combined_contains_both_teams(
|
| 29 |
+
self, sample_team_stats: list[list[float]]
|
| 30 |
+
) -> None:
|
| 31 |
+
"""Test that combined array contains both teams' stats."""
|
| 32 |
+
home_stats = [[1.0, 2.0], [3.0, 4.0]] # 2 players, 2 stats each
|
| 33 |
+
away_stats = [[5.0, 6.0], [7.0, 8.0]]
|
| 34 |
+
|
| 35 |
+
home_array, away_array, combined = analyze_team_stats(
|
| 36 |
+
home_stats, away_stats
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Home should be first 4 values, away next 4
|
| 40 |
+
np.testing.assert_array_equal(
|
| 41 |
+
combined[0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestPredictWinner:
|
| 46 |
+
"""Tests for predict_winner function."""
|
| 47 |
+
|
| 48 |
+
@patch("src.ml.model.get_winner_model")
|
| 49 |
+
def test_returns_probability_and_prediction(
|
| 50 |
+
self, mock_get_model: MagicMock
|
| 51 |
+
) -> None:
|
| 52 |
+
"""Test that function returns (probability, prediction) tuple."""
|
| 53 |
+
mock_model = MagicMock()
|
| 54 |
+
mock_model.predict.return_value = np.array([[0.75]])
|
| 55 |
+
mock_get_model.return_value = mock_model
|
| 56 |
+
|
| 57 |
+
stats = np.random.rand(1, 100)
|
| 58 |
+
probability, prediction = predict_winner(stats)
|
| 59 |
+
|
| 60 |
+
assert isinstance(probability, float)
|
| 61 |
+
assert isinstance(prediction, int)
|
| 62 |
+
assert 0.0 <= probability <= 1.0
|
| 63 |
+
assert prediction in (0, 1)
|
| 64 |
+
|
| 65 |
+
@patch("src.ml.model.get_winner_model")
|
| 66 |
+
def test_high_probability_predicts_win(
|
| 67 |
+
self, mock_get_model: MagicMock
|
| 68 |
+
) -> None:
|
| 69 |
+
"""Test that high probability (>0.5) predicts home win (1)."""
|
| 70 |
+
mock_model = MagicMock()
|
| 71 |
+
mock_model.predict.return_value = np.array([[0.8]])
|
| 72 |
+
mock_get_model.return_value = mock_model
|
| 73 |
+
|
| 74 |
+
stats = np.random.rand(1, 100)
|
| 75 |
+
probability, prediction = predict_winner(stats)
|
| 76 |
+
|
| 77 |
+
assert probability == 0.8
|
| 78 |
+
assert prediction == 1
|
| 79 |
+
|
| 80 |
+
@patch("src.ml.model.get_winner_model")
|
| 81 |
+
def test_low_probability_predicts_loss(
|
| 82 |
+
self, mock_get_model: MagicMock
|
| 83 |
+
) -> None:
|
| 84 |
+
"""Test that low probability (<0.5) predicts home loss (0)."""
|
| 85 |
+
mock_model = MagicMock()
|
| 86 |
+
mock_model.predict.return_value = np.array([[0.3]])
|
| 87 |
+
mock_get_model.return_value = mock_model
|
| 88 |
+
|
| 89 |
+
stats = np.random.rand(1, 100)
|
| 90 |
+
probability, prediction = predict_winner(stats)
|
| 91 |
+
|
| 92 |
+
assert probability == 0.3
|
| 93 |
+
assert prediction == 0
|
| 94 |
+
|
| 95 |
+
@patch("src.ml.model.get_winner_model")
|
| 96 |
+
def test_invalid_shape_raises_error(
|
| 97 |
+
self, mock_get_model: MagicMock
|
| 98 |
+
) -> None:
|
| 99 |
+
"""Test that invalid input shape raises ValueError."""
|
| 100 |
+
mock_model = MagicMock()
|
| 101 |
+
mock_get_model.return_value = mock_model
|
| 102 |
+
|
| 103 |
+
# Wrong shape
|
| 104 |
+
stats = np.random.rand(1, 50)
|
| 105 |
+
|
| 106 |
+
with pytest.raises(ValueError) as exc_info:
|
| 107 |
+
predict_winner(stats)
|
| 108 |
+
|
| 109 |
+
assert "Expected input shape (1, 100)" in str(exc_info.value)
|
| 110 |
+
|
| 111 |
+
@patch("src.ml.model.get_winner_model")
|
| 112 |
+
def test_model_called_with_verbose_zero(
|
| 113 |
+
self, mock_get_model: MagicMock
|
| 114 |
+
) -> None:
|
| 115 |
+
"""Test that model.predict is called with verbose=0."""
|
| 116 |
+
mock_model = MagicMock()
|
| 117 |
+
mock_model.predict.return_value = np.array([[0.5]])
|
| 118 |
+
mock_get_model.return_value = mock_model
|
| 119 |
+
|
| 120 |
+
stats = np.random.rand(1, 100)
|
| 121 |
+
predict_winner(stats)
|
| 122 |
+
|
| 123 |
+
mock_model.predict.assert_called_once_with(stats, verbose=0)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class TestGetWinnerModel:
|
| 127 |
+
"""Tests for get_winner_model caching."""
|
| 128 |
+
|
| 129 |
+
@patch("src.ml.model.load_model")
|
| 130 |
+
@patch("src.ml.model.Path")
|
| 131 |
+
def test_raises_error_for_missing_model(
|
| 132 |
+
self, mock_path: MagicMock, mock_load: MagicMock
|
| 133 |
+
) -> None:
|
| 134 |
+
"""Test that missing model file raises ModelLoadError."""
|
| 135 |
+
from src.ml.model import get_winner_model
|
| 136 |
+
|
| 137 |
+
# Clear the cache to ensure fresh test
|
| 138 |
+
get_winner_model.clear()
|
| 139 |
+
|
| 140 |
+
mock_path_instance = MagicMock()
|
| 141 |
+
mock_path_instance.exists.return_value = False
|
| 142 |
+
mock_path.return_value = mock_path_instance
|
| 143 |
+
|
| 144 |
+
with pytest.raises(ModelLoadError) as exc_info:
|
| 145 |
+
get_winner_model("nonexistent.keras")
|
| 146 |
+
|
| 147 |
+
assert "not found" in str(exc_info.value)
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Pydantic models."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from src.config import DIFFICULTY_PRESETS
|
| 6 |
+
from src.models.player import DifficultySettings, PlayerStats
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestPlayerStats:
|
| 10 |
+
"""Tests for PlayerStats model."""
|
| 11 |
+
|
| 12 |
+
def test_from_db_row(self, sample_player_data: list) -> None:
|
| 13 |
+
"""Test creating PlayerStats from database row tuple."""
|
| 14 |
+
row = sample_player_data[0] # LeBron James data
|
| 15 |
+
|
| 16 |
+
player = PlayerStats.from_db_row(row)
|
| 17 |
+
|
| 18 |
+
assert player.full_name == "LeBron James"
|
| 19 |
+
assert player.pts == 39223
|
| 20 |
+
assert player.ast == 10141
|
| 21 |
+
assert player.is_active is True
|
| 22 |
+
|
| 23 |
+
def test_validates_negative_stats(self) -> None:
|
| 24 |
+
"""Test that negative stats are rejected."""
|
| 25 |
+
with pytest.raises(ValueError):
|
| 26 |
+
PlayerStats(
|
| 27 |
+
full_name="Test Player",
|
| 28 |
+
ast=-1, # Invalid
|
| 29 |
+
blk=0,
|
| 30 |
+
dreb=0,
|
| 31 |
+
fg3a=0,
|
| 32 |
+
fg3m=0,
|
| 33 |
+
fg3_pct=0.0,
|
| 34 |
+
fga=0,
|
| 35 |
+
fgm=0,
|
| 36 |
+
fg_pct=0.0,
|
| 37 |
+
fta=0,
|
| 38 |
+
ftm=0,
|
| 39 |
+
ft_pct=0.0,
|
| 40 |
+
gp=0,
|
| 41 |
+
gs=0,
|
| 42 |
+
min=0,
|
| 43 |
+
oreb=0,
|
| 44 |
+
pf=0,
|
| 45 |
+
pts=0,
|
| 46 |
+
reb=0,
|
| 47 |
+
stl=0,
|
| 48 |
+
tov=0,
|
| 49 |
+
first_name="Test",
|
| 50 |
+
last_name="Player",
|
| 51 |
+
full_name_lower="test player",
|
| 52 |
+
first_name_lower="test",
|
| 53 |
+
last_name_lower="player",
|
| 54 |
+
is_active=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def test_validates_percentage_range(self) -> None:
|
| 58 |
+
"""Test that percentages must be 0-1."""
|
| 59 |
+
with pytest.raises(ValueError):
|
| 60 |
+
PlayerStats(
|
| 61 |
+
full_name="Test Player",
|
| 62 |
+
ast=0,
|
| 63 |
+
blk=0,
|
| 64 |
+
dreb=0,
|
| 65 |
+
fg3a=0,
|
| 66 |
+
fg3m=0,
|
| 67 |
+
fg3_pct=1.5, # Invalid - over 1.0
|
| 68 |
+
fga=0,
|
| 69 |
+
fgm=0,
|
| 70 |
+
fg_pct=0.0,
|
| 71 |
+
fta=0,
|
| 72 |
+
ftm=0,
|
| 73 |
+
ft_pct=0.0,
|
| 74 |
+
gp=0,
|
| 75 |
+
gs=0,
|
| 76 |
+
min=0,
|
| 77 |
+
oreb=0,
|
| 78 |
+
pf=0,
|
| 79 |
+
pts=0,
|
| 80 |
+
reb=0,
|
| 81 |
+
stl=0,
|
| 82 |
+
tov=0,
|
| 83 |
+
first_name="Test",
|
| 84 |
+
last_name="Player",
|
| 85 |
+
full_name_lower="test player",
|
| 86 |
+
first_name_lower="test",
|
| 87 |
+
last_name_lower="player",
|
| 88 |
+
is_active=True,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TestDifficultySettings:
|
| 93 |
+
"""Tests for DifficultySettings model."""
|
| 94 |
+
|
| 95 |
+
@pytest.mark.parametrize("preset_name", list(DIFFICULTY_PRESETS.keys()))
|
| 96 |
+
def test_from_preset_valid(self, preset_name: str) -> None:
|
| 97 |
+
"""Test creating DifficultySettings from valid presets."""
|
| 98 |
+
settings = DifficultySettings.from_preset(preset_name)
|
| 99 |
+
|
| 100 |
+
assert settings.name == preset_name
|
| 101 |
+
expected = DIFFICULTY_PRESETS[preset_name]
|
| 102 |
+
assert settings.pts_threshold == expected[0]
|
| 103 |
+
assert settings.reb_threshold == expected[1]
|
| 104 |
+
assert settings.ast_threshold == expected[2]
|
| 105 |
+
assert settings.stl_threshold == expected[3]
|
| 106 |
+
|
| 107 |
+
def test_from_preset_invalid(self) -> None:
|
| 108 |
+
"""Test that invalid preset name raises ValueError."""
|
| 109 |
+
with pytest.raises(ValueError) as exc_info:
|
| 110 |
+
DifficultySettings.from_preset("Invalid Preset")
|
| 111 |
+
|
| 112 |
+
assert "Unknown difficulty preset" in str(exc_info.value)
|
| 113 |
+
|
| 114 |
+
def test_as_tuple(self) -> None:
|
| 115 |
+
"""Test converting settings to tuple."""
|
| 116 |
+
settings = DifficultySettings.from_preset("Regular")
|
| 117 |
+
|
| 118 |
+
result = settings.as_tuple()
|
| 119 |
+
|
| 120 |
+
assert result == DIFFICULTY_PRESETS["Regular"]
|
| 121 |
+
assert isinstance(result, tuple)
|
| 122 |
+
assert len(result) == 4
|
| 123 |
+
|
| 124 |
+
def test_regular_preset_values(self) -> None:
|
| 125 |
+
"""Test Regular preset has expected values."""
|
| 126 |
+
settings = DifficultySettings.from_preset("Regular")
|
| 127 |
+
|
| 128 |
+
assert settings.pts_threshold == 850
|
| 129 |
+
assert settings.reb_threshold == 400
|
| 130 |
+
assert settings.ast_threshold == 200
|
| 131 |
+
assert settings.stl_threshold == 60
|
| 132 |
+
|
| 133 |
+
def test_dream_team_preset_values(self) -> None:
|
| 134 |
+
"""Test Dream Team preset has highest values."""
|
| 135 |
+
settings = DifficultySettings.from_preset("Dream Team")
|
| 136 |
+
|
| 137 |
+
assert settings.pts_threshold == 1450
|
| 138 |
+
assert settings.reb_threshold == 700
|
| 139 |
+
assert settings.ast_threshold == 500
|
| 140 |
+
assert settings.stl_threshold == 120
|
tests/test_validation.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for input validation module."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from src.validation.inputs import (
|
| 6 |
+
PlayerSearchInput,
|
| 7 |
+
is_valid_search_term,
|
| 8 |
+
validate_search_term,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestPlayerSearchInput:
|
| 13 |
+
"""Tests for PlayerSearchInput validation."""
|
| 14 |
+
|
| 15 |
+
def test_valid_simple_name(self) -> None:
|
| 16 |
+
"""Test valid simple name passes validation."""
|
| 17 |
+
result = PlayerSearchInput(search_term="James")
|
| 18 |
+
assert result.search_term == "James"
|
| 19 |
+
|
| 20 |
+
def test_valid_full_name(self) -> None:
|
| 21 |
+
"""Test valid full name passes validation."""
|
| 22 |
+
result = PlayerSearchInput(search_term="LeBron James")
|
| 23 |
+
assert result.search_term == "LeBron James"
|
| 24 |
+
|
| 25 |
+
def test_valid_name_with_apostrophe(self) -> None:
|
| 26 |
+
"""Test name with apostrophe passes validation."""
|
| 27 |
+
result = PlayerSearchInput(search_term="Shaquille O'Neal")
|
| 28 |
+
assert result.search_term == "Shaquille O'Neal"
|
| 29 |
+
|
| 30 |
+
def test_valid_name_with_period(self) -> None:
|
| 31 |
+
"""Test name with period passes validation."""
|
| 32 |
+
result = PlayerSearchInput(search_term="J.R. Smith")
|
| 33 |
+
assert result.search_term == "J.R. Smith"
|
| 34 |
+
|
| 35 |
+
def test_valid_name_with_hyphen(self) -> None:
|
| 36 |
+
"""Test name with hyphen passes validation."""
|
| 37 |
+
result = PlayerSearchInput(search_term="Kareem Abdul-Jabbar")
|
| 38 |
+
assert result.search_term == "Kareem Abdul-Jabbar"
|
| 39 |
+
|
| 40 |
+
def test_strips_whitespace(self) -> None:
|
| 41 |
+
"""Test that whitespace is stripped."""
|
| 42 |
+
result = PlayerSearchInput(search_term=" James ")
|
| 43 |
+
assert result.search_term == "James"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestSqlInjectionRejection:
|
| 47 |
+
"""Tests for SQL injection pattern rejection."""
|
| 48 |
+
|
| 49 |
+
@pytest.mark.parametrize(
|
| 50 |
+
"malicious_input",
|
| 51 |
+
[
|
| 52 |
+
"'; DROP TABLE NBA;--",
|
| 53 |
+
"James'; DELETE FROM NBA--",
|
| 54 |
+
"' OR '1'='1",
|
| 55 |
+
"James' UNION SELECT * FROM passwords--",
|
| 56 |
+
"James; SELECT * FROM users",
|
| 57 |
+
"/*comment*/James",
|
| 58 |
+
"James*/DROP TABLE/*",
|
| 59 |
+
"' OR 1=1--",
|
| 60 |
+
"James' AND 1=1--",
|
| 61 |
+
"Robert'); DROP TABLE Students;--",
|
| 62 |
+
],
|
| 63 |
+
)
|
| 64 |
+
def test_rejects_sql_injection(self, malicious_input: str) -> None:
|
| 65 |
+
"""Test that SQL injection patterns are rejected."""
|
| 66 |
+
with pytest.raises(ValueError) as exc_info:
|
| 67 |
+
PlayerSearchInput(search_term=malicious_input)
|
| 68 |
+
|
| 69 |
+
# Should mention invalid characters
|
| 70 |
+
assert "Invalid" in str(exc_info.value) or "invalid" in str(exc_info.value)
|
| 71 |
+
|
| 72 |
+
@pytest.mark.parametrize(
|
| 73 |
+
"invalid_input",
|
| 74 |
+
[
|
| 75 |
+
"James<script>",
|
| 76 |
+
"James ",
|
| 77 |
+
"James@#$%",
|
| 78 |
+
"James\\nNewline",
|
| 79 |
+
"James\x00null",
|
| 80 |
+
],
|
| 81 |
+
)
|
| 82 |
+
def test_rejects_special_characters(self, invalid_input: str) -> None:
|
| 83 |
+
"""Test that special characters are rejected."""
|
| 84 |
+
with pytest.raises(ValueError):
|
| 85 |
+
PlayerSearchInput(search_term=invalid_input)
|
| 86 |
+
|
| 87 |
+
def test_rejects_empty_string(self) -> None:
|
| 88 |
+
"""Test that empty string is rejected."""
|
| 89 |
+
with pytest.raises(ValueError):
|
| 90 |
+
PlayerSearchInput(search_term="")
|
| 91 |
+
|
| 92 |
+
def test_rejects_too_long(self) -> None:
|
| 93 |
+
"""Test that overly long input is rejected."""
|
| 94 |
+
with pytest.raises(ValueError):
|
| 95 |
+
PlayerSearchInput(search_term="A" * 101)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class TestValidateSearchTerm:
|
| 99 |
+
"""Tests for validate_search_term helper function."""
|
| 100 |
+
|
| 101 |
+
def test_returns_cleaned_term(self) -> None:
|
| 102 |
+
"""Test that valid term is returned cleaned."""
|
| 103 |
+
result = validate_search_term(" James ")
|
| 104 |
+
assert result == "James"
|
| 105 |
+
|
| 106 |
+
def test_returns_none_for_invalid(self) -> None:
|
| 107 |
+
"""Test that invalid input returns None."""
|
| 108 |
+
result = validate_search_term("'; DROP TABLE--")
|
| 109 |
+
assert result is None
|
| 110 |
+
|
| 111 |
+
def test_returns_none_for_empty(self) -> None:
|
| 112 |
+
"""Test that empty input returns None."""
|
| 113 |
+
result = validate_search_term("")
|
| 114 |
+
assert result is None
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class TestIsValidSearchTerm:
|
| 118 |
+
"""Tests for is_valid_search_term helper function."""
|
| 119 |
+
|
| 120 |
+
def test_returns_true_for_valid(self) -> None:
|
| 121 |
+
"""Test returns True for valid input."""
|
| 122 |
+
assert is_valid_search_term("James") is True
|
| 123 |
+
assert is_valid_search_term("LeBron James") is True
|
| 124 |
+
assert is_valid_search_term("O'Neal") is True
|
| 125 |
+
|
| 126 |
+
def test_returns_false_for_invalid(self) -> None:
|
| 127 |
+
"""Test returns False for invalid input."""
|
| 128 |
+
assert is_valid_search_term("'; DROP--") is False
|
| 129 |
+
assert is_valid_search_term("") is False
|
| 130 |
+
assert is_valid_search_term("<script>") is False
|