Hatmanstack Claude Opus 4.5 commited on
Commit
6424951
·
1 Parent(s): a63f84a

Refactor app with security fixes, error handling, and type safety

Browse files

- Fix SQL injection: use parameterized queries in src/database/queries.py
- Fix XSS: add HTML escaping in src/utils/html.py
- Re-enable CORS/XSRF protection in devcontainer.json
- Add Pydantic models for validation in src/models/player.py
- Add session state management with safe defaults
- Cache ML model loading with @st .cache_resource
- Fix N+1 query with batch IN clause
- Add loop guards to prevent infinite loops
- Add comprehensive test suite
- Add pyproject.toml with mypy strict mode and ruff config
- Move compile_model.py to scripts/ with main() guard

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

.devcontainer/devcontainer.json CHANGED
@@ -19,7 +19,7 @@
19
  },
20
  "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
21
  "postAttachCommand": {
22
- "server": "streamlit run app.py --server.enableCORS false --server.enableXsrfProtection false"
23
  },
24
  "portsAttributes": {
25
  "8501": {
 
19
  },
20
  "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
21
  "postAttachCommand": {
22
+ "server": "streamlit run app.py"
23
  },
24
  "portsAttributes": {
25
  "8501": {
app.py CHANGED
@@ -1,13 +1,23 @@
 
 
1
  import streamlit as st
2
- import pandas as pd
3
- import snowflake.connector
4
 
5
- def on_page_load():
 
 
 
 
6
  st.set_page_config(layout="wide")
 
 
7
  on_page_load()
8
 
9
- st.markdown("<h1 style='text-align: center; color: steelblue;'>NBA</h1>", unsafe_allow_html=True)
10
 
11
- st.markdown("<h5 style='text-align: center; color: white;'>A Simple app to test your skill in building a Team based on career stats to compete with a Computer</h5>", unsafe_allow_html=True)
 
 
 
 
12
 
13
 
 
1
+ """NBA Team Builder Application - Entry Point."""
2
+
3
  import streamlit as st
 
 
4
 
5
+ from src.utils.html import safe_heading, safe_paragraph
6
+
7
+
8
+ def on_page_load() -> None:
9
+ """Configure page settings."""
10
  st.set_page_config(layout="wide")
11
+
12
+
13
  on_page_load()
14
 
15
+ safe_heading("NBA", level=1, color="steelblue")
16
 
17
+ safe_paragraph(
18
+ "A Simple app to test your skill in building a Team based on "
19
+ "career stats to compete with a Computer",
20
+ color="white",
21
+ )
22
 
23
 
compile_model.py DELETED
@@ -1,73 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
- from tensorflow import keras
4
- from tensorflow.keras import layers
5
- from tensorflow.keras.losses import BinaryCrossentropy
6
- from sklearn.model_selection import train_test_split
7
- from sklearn.model_selection import RandomizedSearchCV
8
- from scikeras.wrappers import KerasClassifier
9
-
10
- def create_stats(roster, schedule):
11
- home_stats = []
12
- away_stats = []
13
- S = []
14
-
15
- # Loading Relavent Columns from f-test
16
- cols = ['TEAM','PTS/G', 'ORB', 'DRB', 'AST', 'STL', 'BLK', 'TOV', '3P%', 'FT%','2P']
17
- new_roster = roster[cols]
18
- for i in schedule['Home/Neutral']:
19
- home_stats.append((new_roster[new_roster['TEAM'] == i]).values.tolist())
20
- for i in schedule['Visitor/Neutral']:
21
- away_stats.append((new_roster.loc[new_roster['TEAM'] == i]).values.tolist())
22
- for i in range(len(home_stats)):
23
- arr = []
24
- for j in range(len(home_stats[i])):
25
- del home_stats[i][j][0]
26
- arr += home_stats[i][j]
27
- for j in range(len(away_stats[i])):
28
- del away_stats[i][j][0]
29
- arr += away_stats[i][j]
30
-
31
- # Create numpy array with all the players on the Home Team's Stats followed by the Away Team's stats
32
- S.append(np.nan_to_num(np.array(arr), copy=False))
33
- return S
34
-
35
- roster = pd.read_csv('player_stats.txt', delimiter=',')
36
- schedule = pd.read_csv('schedule.txt', delimiter=',')
37
-
38
- # Create winning condition to train on
39
- schedule['winner'] = schedule.apply(lambda x: 0 if x['PTS'] > x['PTS.1'] else 1, axis=1)
40
-
41
- X = np.array(create_stats(roster, schedule))
42
- y = np.array(schedule['winner'])
43
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
44
-
45
- def create_model(optimizer='rmsprop', init='glorot_uniform'):
46
- inputs = keras.Input(shape=(100,))
47
- dense = layers.Dense(50, activation="relu")
48
- x = dense(inputs)
49
- x = layers.Dense(64, activation="relu")(x)
50
- outputs = layers.Dense(1, activation='sigmoid')(x)
51
- model = keras.Model(inputs=inputs, outputs=outputs, name="nba_model")
52
- model.compile(loss=BinaryCrossentropy(from_logits=False), optimizer=optimizer, metrics=["accuracy"])
53
-
54
- return model
55
-
56
- model = KerasClassifier(model=create_model, verbose=0, init='glorot_uniform')
57
-
58
- optimizer = ['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam']
59
- init = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']
60
- epochs = [500, 1000, 1500]
61
- batches = [50, 100, 200]
62
- param_grid = dict(optimizer=optimizer, epochs=epochs, batch_size=batches, init=init)
63
-
64
- random_search = RandomizedSearchCV(estimator=model, param_distributions=param_grid, n_iter=100, verbose=3)
65
- random_search_result = random_search.fit(X_train, y_train)
66
- best_model = random_search_result.best_estimator_
67
-
68
- best_model.model_.save('winner.keras')
69
- best_parameters = random_search_result.best_params_
70
- print("Best parameters: ", best_parameters)
71
-
72
- test_accuracy = random_search_result.best_estimator_.score(X_test, y_test)
73
- print("Test accuracy: ", test_accuracy)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/1_home_team.py CHANGED
@@ -1,111 +1,155 @@
1
- import streamlit as st
 
 
 
2
  import pandas as pd
3
- import snowflake.connector
 
 
 
 
 
 
 
 
 
4
 
5
- def on_page_load():
 
 
6
  st.set_page_config(layout="wide")
 
 
7
  on_page_load()
8
 
 
 
 
9
  col1, col2, col3 = st.columns(3)
10
 
11
  with col2:
12
- st.markdown("<h1 style='text-align: center; color: steelblue;'>Build Your Team</h1>", unsafe_allow_html=True)
13
- player_add = st.text_input('Who\'re you picking?', 'James')
14
- player = player_add.lower()
15
- st.markdown("<p style='text-align: center; color: steelblue;'>Search for a player to populate the dropdown menu then pick and save your team before searching for another player.</p>", unsafe_allow_html=True)
16
- search_string = 'select full_name from NBA where full_name_lower=\'{}\' or first_name_lower=\'{}\' or last_name_lower=\'{}\';'.format(player, player, player)
17
-
18
- if 'home_team' not in st.session_state:
19
- st.session_state['home_team'] = []
20
- if 'away_team' not in st.session_state:
21
- st.session_state['away_team'] = []
22
- if 'away_stats' not in st.session_state:
23
- st.session_state['away_stats'] = []
24
- if 'home_team_df' not in st.session_state:
25
- st.session_state['home_team_df'] = pd.DataFrame()
26
- if 'radio_index' not in st.session_state:
27
- st.session_state['radio_index'] = 0
28
-
29
- def find_player():
30
- cnx = snowflake.connector.connect(**st.secrets["snowflake"])
31
- data = get_player(cnx)
32
- cnx.close()
33
- return data
34
-
35
- def get_player(cnx):
36
- with cnx.cursor() as cur:
37
- cur.execute(search_string)
38
- return cur.fetchall()
39
-
40
- player_search = find_player()
41
-
42
- def find_home_team():
43
- test =[]
44
- cnx = snowflake.connector.connect(**st.secrets["snowflake"])
45
- for i in st.session_state.home_team:
46
- with cnx.cursor() as cur:
47
- cur.execute('SELECT * FROM NBA WHERE FULL_NAME=\'{}\''.format(i))
48
- test.append(cur.fetchall()[0])
49
- cnx.close()
50
- df = pd.DataFrame(test, columns=['FULL_NAME', 'AST', 'BLK', 'DREB', 'FG3A', 'FG3M', 'FG3_PCT', 'FGA', 'FGM', 'FG_PCT', 'FTA', 'FTM', 'FT_PCT','GP', 'GS', 'MIN', 'OREB', 'PF', 'PTS', 'REB', 'STL', 'TOV', 'FIRST_NAME', 'LAST_NAME', 'FULL_NAME_LOWER', 'FIRST_NAME_LOWER', 'LAST_NAME_LOWER', 'IS_ACTIVE'])
51
- st.session_state.home_team_df = df
52
- return df
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  home_team_df = find_home_team()
55
 
56
- player_search = [player[0] for player in player_search]
57
  if not home_team_df.empty:
58
- name_list = home_team_df['FULL_NAME'].tolist()
59
- player_search += name_list
60
 
61
- def save_state():
62
- saved_players = home_team_df['FULL_NAME'].tolist()
 
 
63
  holder = saved_players + player_selected
 
64
  if len(player_selected) > len(saved_players):
65
- for i in holder:
66
- if i not in st.session_state.home_team:
67
- st.session_state.home_team.append(i)
68
  elif len(player_selected) < len(saved_players):
69
- for i in saved_players:
70
- if i not in player_selected:
71
- st.session_state.home_team.remove(i)
72
  st.rerun()
73
 
74
- col1, col2 = st.columns([7,1])
 
75
  with col1:
76
- player_selected = st.multiselect("Search Results:", player_search, home_team_df['FULL_NAME'].tolist(), label_visibility="collapsed")
 
 
 
 
 
 
77
  with col2:
78
- if st.button('Save Team'):
79
  save_state()
80
 
81
- st.markdown("<h1 style='text-align: center; color: steelblue;'>Preview</h1>", unsafe_allow_html=True)
82
-
83
  st.dataframe(home_team_df)
84
- radio_index = st.session_state.radio_index
 
85
  col1, col2, col3, col4, col5 = st.columns(5)
 
86
  with col3:
87
- st.markdown("<h3 style='text-align: center; color: steelblue;'>Away Team</h3>", unsafe_allow_html=True)
88
  difficulty = st.radio(
89
- label="Difficulty", index=radio_index, options=['Regular','93\' Bulls', 'All-Stars', 'Dream Team'],
90
- label_visibility="collapsed", )
91
-
92
- if difficulty == 'Regular':
93
- st.session_state.away_stats = [850, 400, 200, 60]
94
- st.session_state.radio_index = 0
95
- elif difficulty == '93\' Bulls':
96
- st.session_state.away_stats = [1050, 500, 300, 80]
97
- st.session_state.radio_index = 1
98
- elif difficulty == 'All-Stars':
99
- st.session_state.away_stats = [1250, 600, 400, 100]
100
- st.session_state.radio_index = 2
101
- elif difficulty == 'Dream Team':
102
- st.session_state.away_stats = [1450, 700, 500, 120]
103
- st.session_state.radio_index = 3
104
  else:
105
  st.write("You didn't select a difficulty.")
106
-
107
-
108
-
109
-
110
-
111
-
 
1
+ """Home team builder page."""
2
+
3
+ import logging
4
+
5
  import pandas as pd
6
+ import streamlit as st
7
+
8
+ from src.config import DIFFICULTY_PRESETS, PLAYER_COLUMNS
9
+ from src.database.connection import DatabaseConnectionError, QueryExecutionError, get_connection
10
+ from src.database.queries import get_players_by_full_names, search_player_by_name
11
+ from src.state.session import init_session_state
12
+ from src.utils.html import safe_heading, safe_paragraph
13
+ from src.validation.inputs import validate_search_term
14
+
15
+ logger = logging.getLogger("streamlit_nba")
16
 
17
+
18
+ def on_page_load() -> None:
19
+ """Configure page settings."""
20
  st.set_page_config(layout="wide")
21
+
22
+
23
  on_page_load()
24
 
25
+ # Initialize session state before any access
26
+ init_session_state()
27
+
28
  col1, col2, col3 = st.columns(3)
29
 
30
  with col2:
31
+ safe_heading("Build Your Team", level=1, color="steelblue")
32
+ player_add = st.text_input("Who're you picking?", "James")
33
+
34
+ safe_paragraph(
35
+ "Search for a player to populate the dropdown menu then pick and "
36
+ "save your team before searching for another player.",
37
+ color="steelblue",
38
+ )
39
+
40
+
41
+ def find_player(search_term: str) -> list[str]:
42
+ """Search for players by name with validation and error handling.
43
+
44
+ Args:
45
+ search_term: User-provided search term
46
+
47
+ Returns:
48
+ List of matching player full names
49
+ """
50
+ # Validate input
51
+ validated_term = validate_search_term(search_term)
52
+ if validated_term is None:
53
+ st.warning("Invalid search term. Please use only letters, numbers, and basic punctuation.")
54
+ return []
55
+
56
+ try:
57
+ with get_connection() as conn:
58
+ results = search_player_by_name(conn, validated_term)
59
+ return [player[0] for player in results]
60
+ except DatabaseConnectionError as e:
61
+ st.error(f"Could not connect to database. Please try again later.")
62
+ logger.error(f"Database connection error: {e}")
63
+ return []
64
+ except QueryExecutionError as e:
65
+ st.error("Error searching for players. Please try again.")
66
+ logger.error(f"Query error: {e}")
67
+ return []
68
+
69
+
70
+ def find_home_team() -> pd.DataFrame:
71
+ """Load home team data from database using batch query.
72
+
73
+ Returns:
74
+ DataFrame with home team player data
75
+ """
76
+ team_names: list[str] = st.session_state.get("home_team", [])
77
+ if not team_names:
78
+ return pd.DataFrame(columns=PLAYER_COLUMNS)
79
+
80
+ try:
81
+ with get_connection() as conn:
82
+ # Single batch query instead of N+1 queries
83
+ df = get_players_by_full_names(conn, team_names)
84
+ st.session_state.home_team_df = df
85
+ return df
86
+ except DatabaseConnectionError as e:
87
+ st.error("Could not connect to database. Please try again later.")
88
+ logger.error(f"Database connection error: {e}")
89
+ return pd.DataFrame(columns=PLAYER_COLUMNS)
90
+ except QueryExecutionError as e:
91
+ st.error("Error loading team data. Please try again.")
92
+ logger.error(f"Query error: {e}")
93
+ return pd.DataFrame(columns=PLAYER_COLUMNS)
94
+
95
+
96
+ # Load data
97
+ player_search = find_player(player_add)
98
  home_team_df = find_home_team()
99
 
100
+ # Combine search results with current team
101
  if not home_team_df.empty:
102
+ name_list = home_team_df["FULL_NAME"].tolist()
103
+ player_search = player_search + [n for n in name_list if n not in player_search]
104
 
105
+
106
+ def save_state() -> None:
107
+ """Save the selected players to session state."""
108
+ saved_players = home_team_df["FULL_NAME"].tolist() if not home_team_df.empty else []
109
  holder = saved_players + player_selected
110
+
111
  if len(player_selected) > len(saved_players):
112
+ for player in holder:
113
+ if player not in st.session_state.home_team:
114
+ st.session_state.home_team.append(player)
115
  elif len(player_selected) < len(saved_players):
116
+ for player in saved_players:
117
+ if player not in player_selected:
118
+ st.session_state.home_team.remove(player)
119
  st.rerun()
120
 
121
+
122
+ col1, col2 = st.columns([7, 1])
123
  with col1:
124
+ default_selection = home_team_df["FULL_NAME"].tolist() if not home_team_df.empty else []
125
+ player_selected = st.multiselect(
126
+ "Search Results:",
127
+ player_search,
128
+ default_selection,
129
+ label_visibility="collapsed",
130
+ )
131
  with col2:
132
+ if st.button("Save Team"):
133
  save_state()
134
 
135
+ safe_heading("Preview", level=1, color="steelblue")
136
+
137
  st.dataframe(home_team_df)
138
+
139
+ radio_index: int = st.session_state.get("radio_index", 0)
140
  col1, col2, col3, col4, col5 = st.columns(5)
141
+
142
  with col3:
143
+ safe_heading("Away Team", level=3, color="steelblue")
144
  difficulty = st.radio(
145
+ label="Difficulty",
146
+ index=radio_index,
147
+ options=list(DIFFICULTY_PRESETS.keys()),
148
+ label_visibility="collapsed",
149
+ )
150
+
151
+ if difficulty and difficulty in DIFFICULTY_PRESETS:
152
+ st.session_state.away_stats = list(DIFFICULTY_PRESETS[difficulty])
153
+ st.session_state.radio_index = list(DIFFICULTY_PRESETS.keys()).index(difficulty)
 
 
 
 
 
 
154
  else:
155
  st.write("You didn't select a difficulty.")
 
 
 
 
 
 
pages/2_play_game.py CHANGED
@@ -1,114 +1,194 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import snowflake.connector
4
- import numpy as np
5
- import tensorflow as tf
6
  import random
7
- from tensorflow.keras.models import load_model
8
 
 
 
9
 
10
- def on_page_load():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  st.set_page_config(layout="wide")
 
 
12
  on_page_load()
13
 
14
- stats = st.session_state.away_stats
 
 
 
 
15
  teams_good = True
16
- winner_prediction = 0
17
- away_point_prediction = 0
18
- home_point_prediction = 0
19
-
20
- query_string = ('SELECT * FROM (select * from NBA where PTS > {}) sample (2 rows) UNION '.format(stats[0]))
21
- query_string += ('SELECT * FROM (select * from NBA where REB > {}) sample (1 rows) UNION '.format(stats[1]))
22
- query_string += ('SELECT * FROM (select * from NBA where AST > {}) sample (1 rows) UNION '.format(stats[2]))
23
- query_string += ('SELECT * FROM (select * from NBA where STL > {}) sample (1 rows);'.format(stats[3]))
24
-
25
- def get_away_team(cnx, query_string):
26
- with cnx.cursor() as cur:
27
- cur.execute(query_string)
28
- players = cur.fetchall()
29
- while len(players) != 5:
30
- cur.execute(query_string)
31
- players = cur.fetchall()
32
- return players
33
-
34
- def find_away_team():
35
- cnx = snowflake.connector.connect(**st.secrets["snowflake"])
36
-
37
- data = get_away_team(cnx, query_string)
38
- cnx.close()
39
- df = pd.DataFrame(data, columns=['FULL_NAME', 'AST', 'BLK', 'DREB', 'FG3A', 'FG3M', 'FG3_PCT', 'FGA', 'FGM', 'FG_PCT', 'FTA', 'FTM', 'FT_PCT','GP', 'GS', 'MIN', 'OREB', 'PF', 'PTS', 'REB', 'STL', 'TOV', 'FIRST_NAME', 'LAST_NAME', 'FULL_NAME_LOWER', 'FIRST_NAME_LOWER', 'LAST_NAME_LOWER', 'IS_ACTIVE'])
40
- return df
41
-
42
- if not st.session_state.home_team_df.shape[0] == 5:
43
- st.markdown("<h3 style='text-align: center; color: red;'>Your Team Doesn't Have 5</h3>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  away_data = pd.DataFrame()
45
  teams_good = False
46
- winner = ''
 
47
  else:
48
- away_data = find_away_team()
49
-
50
- def analyze_stats(home_stats, away_stats):
51
- home=[]
52
- away=[]
53
- for j in range(len(home_stats)):
54
- home += home_stats[j]
55
- for j in range(len(away_stats)):
56
- away += away_stats[j]
57
- return np.array(home).reshape(1,-1), np.array(away).reshape(1,-1), np.array(home + away).reshape(1, -1)
58
-
59
- def get_score_board(p_pred, w_score):
60
- score = []
61
- quarter_score = int(w_score/4)
62
- score.append(quarter_score + random.randint(-7, 7))
63
- score.append(quarter_score + random.randint(-3, 3))
64
- score.append(quarter_score + random.randint(-8, 8))
65
- score.append(w_score - (score[0] + score[1] + score[2]))
66
- score.append(w_score)
67
- return score
68
-
69
- if teams_good:
70
- #first pass algo to determine winner
71
- cols = ['PTS', 'OREB', 'DREB', 'AST', 'STL', 'BLK', 'TOV', 'FG3_PCT', 'FT_PCT', 'FGM']
72
- home_stats = st.session_state.home_team_df[cols].values.tolist()
73
- away_stats = away_data[cols].values.tolist()
74
- home, away, winner = analyze_stats(home_stats, away_stats)
75
-
76
- winner_model = load_model('winner.keras')
77
-
78
- winner_sigmoid= winner_model.predict(winner)
79
- winner_prediction = np.round(winner_sigmoid[0][0])
80
-
81
- score = []
82
- winner_score = random.randint(90, 130)
83
- loser_score = random.randint(80, 120)
84
- while winner_score <= loser_score:
85
- winner_score = random.randint(90, 130)
86
- loser_score = random.randint(80, 120)
87
-
88
-
89
- if winner_prediction == 1:
90
- score.append(get_score_board(winner_prediction, winner_score))
91
- score.append(get_score_board(away_point_prediction, loser_score))
92
- winner = 'Winner'
93
- else:
94
- score.append(get_score_board(winner_prediction, loser_score))
95
- score.append(get_score_board(away_point_prediction, winner_score))
96
- winner = 'Loser'
97
-
98
- box_score = pd.DataFrame(score , columns=['1', '2', '3', '4', 'Final'], index=['Home Team', 'Away Team'] )
99
- print(f"Prediction: {winner_sigmoid[0][0]}")
100
-
101
- st.markdown("<h1 style='text-align: center; color: steelblue;'>Home Team</h1>", unsafe_allow_html=True)
102
- st.dataframe(st.session_state.home_team_df)
103
- if teams_good:
104
- print(f"Teams Good")
105
- st.markdown(f"<h3 style='text-align: center; color: steelblue;'>{winner}</h3>", unsafe_allow_html=True)
 
 
 
 
106
  col1, col2, col3 = st.columns(3)
107
  with col2:
108
- st.dataframe(box_score)
109
- st.markdown("<h1 style='text-align: center; color: steelblue;'>Away Team</h1>", unsafe_allow_html=True)
 
110
  st.dataframe(away_data)
111
 
112
  if st.button("Play New Team"):
113
- print("New Team")
114
-
 
1
+ """Game play page with prediction and scoring."""
2
+
3
+ import logging
 
 
4
  import random
 
5
 
6
+ import pandas as pd
7
+ import streamlit as st
8
 
9
+ from src.config import (
10
+ DEFAULT_LOSER_SCORE,
11
+ DEFAULT_WINNER_SCORE,
12
+ LOSER_SCORE_RANGE,
13
+ MAX_QUERY_ATTEMPTS,
14
+ STAT_COLUMNS,
15
+ TEAM_SIZE,
16
+ WINNER_SCORE_RANGE,
17
+ )
18
+ from src.database.connection import DatabaseConnectionError, QueryExecutionError, get_connection
19
+ from src.database.queries import get_away_team_by_stats
20
+ from src.ml.model import ModelLoadError, analyze_team_stats, predict_winner
21
+ from src.state.session import get_away_stats, get_home_team_df, init_session_state
22
+ from src.utils.html import safe_heading
23
+
24
+ logger = logging.getLogger("streamlit_nba")
25
+
26
+
27
+ def on_page_load() -> None:
28
+ """Configure page settings."""
29
  st.set_page_config(layout="wide")
30
+
31
+
32
  on_page_load()
33
 
34
+ # Initialize session state BEFORE any access
35
+ init_session_state()
36
+
37
+ # Get stats safely with fallback
38
+ stats = get_away_stats()
39
  teams_good = True
40
+
41
+
42
+ def find_away_team(stat_thresholds: list[int]) -> pd.DataFrame:
43
+ """Generate away team based on difficulty stats.
44
+
45
+ Args:
46
+ stat_thresholds: List of [pts, reb, ast, stl] thresholds
47
+
48
+ Returns:
49
+ DataFrame with away team data, or empty DataFrame on error
50
+ """
51
+ try:
52
+ with get_connection() as conn:
53
+ return get_away_team_by_stats(
54
+ conn,
55
+ pts_threshold=stat_thresholds[0],
56
+ reb_threshold=stat_thresholds[1],
57
+ ast_threshold=stat_thresholds[2],
58
+ stl_threshold=stat_thresholds[3],
59
+ max_attempts=MAX_QUERY_ATTEMPTS,
60
+ )
61
+ except DatabaseConnectionError as e:
62
+ st.error("Could not connect to database. Please try again later.")
63
+ logger.error(f"Database connection error: {e}")
64
+ return pd.DataFrame()
65
+ except QueryExecutionError as e:
66
+ st.error("Could not generate away team. Please try again.")
67
+ logger.error(f"Query error: {e}")
68
+ return pd.DataFrame()
69
+
70
+
71
+ def get_score_board(final_score: int) -> list[int]:
72
+ """Generate quarter-by-quarter scores that sum to final score.
73
+
74
+ Args:
75
+ final_score: Total game score
76
+
77
+ Returns:
78
+ List of [Q1, Q2, Q3, Q4, Final] scores
79
+ """
80
+ quarter_score = final_score // 4
81
+ scores = [
82
+ quarter_score + random.randint(-7, 7),
83
+ quarter_score + random.randint(-3, 3),
84
+ quarter_score + random.randint(-8, 8),
85
+ ]
86
+ # Q4 makes up the difference to hit exact final
87
+ scores.append(final_score - sum(scores))
88
+ scores.append(final_score)
89
+ return scores
90
+
91
+
92
+ def generate_game_scores() -> tuple[int, int]:
93
+ """Generate winner and loser scores with loop guard.
94
+
95
+ Returns:
96
+ Tuple of (winner_score, loser_score)
97
+ """
98
+ for _ in range(MAX_QUERY_ATTEMPTS):
99
+ winner_score = random.randint(*WINNER_SCORE_RANGE)
100
+ loser_score = random.randint(*LOSER_SCORE_RANGE)
101
+ if winner_score > loser_score:
102
+ return winner_score, loser_score
103
+
104
+ # Fallback to guaranteed valid scores
105
+ logger.warning("Score generation fell back to defaults")
106
+ return DEFAULT_WINNER_SCORE, DEFAULT_LOSER_SCORE
107
+
108
+
109
+ # Check if home team is valid
110
+ home_team_df = get_home_team_df()
111
+
112
+ if home_team_df.empty or home_team_df.shape[0] != TEAM_SIZE:
113
+ safe_heading(
114
+ f"Your Team Doesn't Have {TEAM_SIZE} Players",
115
+ level=3,
116
+ color="red",
117
+ )
118
  away_data = pd.DataFrame()
119
  teams_good = False
120
+ winner_label = ""
121
+ box_score = pd.DataFrame()
122
  else:
123
+ away_data = find_away_team(stats)
124
+ if away_data.empty:
125
+ teams_good = False
126
+ winner_label = ""
127
+ box_score = pd.DataFrame()
128
+
129
+ # Run prediction if both teams are valid
130
+ if teams_good and not away_data.empty:
131
+ try:
132
+ # Extract stats for ML model
133
+ home_stats = home_team_df[STAT_COLUMNS].values.tolist()
134
+ away_stats_data = away_data[STAT_COLUMNS].values.tolist()
135
+
136
+ # Prepare data and predict
137
+ _, _, combined = analyze_team_stats(home_stats, away_stats_data)
138
+ probability, prediction = predict_winner(combined)
139
+
140
+ # Generate scores
141
+ winner_score, loser_score = generate_game_scores()
142
+
143
+ # Build scoreboard based on prediction
144
+ if prediction == 1:
145
+ score_data = [
146
+ get_score_board(winner_score),
147
+ get_score_board(loser_score),
148
+ ]
149
+ winner_label = "Winner"
150
+ else:
151
+ score_data = [
152
+ get_score_board(loser_score),
153
+ get_score_board(winner_score),
154
+ ]
155
+ winner_label = "Loser"
156
+
157
+ box_score = pd.DataFrame(
158
+ score_data,
159
+ columns=["1", "2", "3", "4", "Final"],
160
+ index=["Home Team", "Away Team"],
161
+ )
162
+
163
+ logger.info(f"Prediction: {probability:.4f}")
164
+
165
+ except ModelLoadError as e:
166
+ st.error("Could not load prediction model. Please contact support.")
167
+ logger.error(f"Model load error: {e}")
168
+ teams_good = False
169
+ winner_label = ""
170
+ box_score = pd.DataFrame()
171
+ except ValueError as e:
172
+ st.error("Error processing team stats. Please try again.")
173
+ logger.error(f"Stats processing error: {e}")
174
+ teams_good = False
175
+ winner_label = ""
176
+ box_score = pd.DataFrame()
177
+
178
+ # Display results
179
+ safe_heading("Home Team", level=1, color="steelblue")
180
+ st.dataframe(home_team_df)
181
+
182
+ if teams_good and winner_label:
183
+ logger.info("Teams Good")
184
+ safe_heading(winner_label, level=3, color="steelblue")
185
  col1, col2, col3 = st.columns(3)
186
  with col2:
187
+ st.dataframe(box_score)
188
+
189
+ safe_heading("Away Team", level=1, color="steelblue")
190
  st.dataframe(away_data)
191
 
192
  if st.button("Play New Team"):
193
+ logger.info("New Team requested")
194
+ st.rerun()
pyproject.toml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "streamlit-nba"
3
+ version = "1.0.0"
4
+ description = "NBA team builder and game prediction Streamlit application"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "streamlit>=1.28.0",
8
+ "snowflake-connector-python>=3.5.0",
9
+ "tensorflow>=2.15.0",
10
+ "numpy>=1.24.0",
11
+ "pandas>=2.0.0",
12
+ "pydantic>=2.5.0",
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ dev = [
17
+ "pytest>=7.4.0",
18
+ "pytest-cov>=4.1.0",
19
+ "mypy>=1.7.0",
20
+ "ruff>=0.1.6",
21
+ "pandas-stubs>=2.0.0",
22
+ ]
23
+
24
+ [tool.mypy]
25
+ python_version = "3.11"
26
+ strict = true
27
+ warn_return_any = true
28
+ warn_unused_configs = true
29
+ disallow_untyped_defs = true
30
+ disallow_incomplete_defs = true
31
+ check_untyped_defs = true
32
+ disallow_untyped_decorators = true
33
+ no_implicit_optional = true
34
+ warn_redundant_casts = true
35
+ warn_unused_ignores = true
36
+ warn_no_return = true
37
+ warn_unreachable = true
38
+
39
+ # Third-party library ignores
40
+ [[tool.mypy.overrides]]
41
+ module = [
42
+ "streamlit.*",
43
+ "snowflake.*",
44
+ "tensorflow.*",
45
+ "keras.*",
46
+ "sklearn.*",
47
+ "scikeras.*",
48
+ ]
49
+ ignore_missing_imports = true
50
+
51
+ [tool.ruff]
52
+ target-version = "py311"
53
+ line-length = 88
54
+
55
+ [tool.ruff.lint]
56
+ select = [
57
+ "E", # pycodestyle errors
58
+ "W", # pycodestyle warnings
59
+ "F", # pyflakes
60
+ "I", # isort
61
+ "B", # flake8-bugbear
62
+ "C4", # flake8-comprehensions
63
+ "UP", # pyupgrade
64
+ "ARG", # flake8-unused-arguments
65
+ "SIM", # flake8-simplify
66
+ "TCH", # flake8-type-checking
67
+ "PTH", # flake8-use-pathlib
68
+ "PL", # pylint
69
+ "RUF", # ruff-specific
70
+ "S", # flake8-bandit (security)
71
+ ]
72
+ ignore = [
73
+ "S101", # assert used (ok in tests)
74
+ "PLR0913", # too many arguments
75
+ ]
76
+
77
+ [tool.ruff.lint.per-file-ignores]
78
+ "tests/*" = ["S101", "ARG001", "PLR2004"]
79
+
80
+ [tool.pytest.ini_options]
81
+ testpaths = ["tests"]
82
+ python_files = ["test_*.py"]
83
+ python_functions = ["test_*"]
84
+ addopts = [
85
+ "-v",
86
+ "--tb=short",
87
+ "--strict-markers",
88
+ ]
89
+
90
+ [tool.coverage.run]
91
+ source = ["src"]
92
+ branch = true
93
+ omit = [
94
+ "tests/*",
95
+ "scripts/*",
96
+ ]
97
+
98
+ [tool.coverage.report]
99
+ exclude_lines = [
100
+ "pragma: no cover",
101
+ "if TYPE_CHECKING:",
102
+ "raise NotImplementedError",
103
+ ]
104
+ fail_under = 80
requirements-dev.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ -r requirements.txt
2
+ pytest>=7.4.0
3
+ pytest-cov>=4.1.0
4
+ mypy>=1.7.0
5
+ ruff>=0.1.6
6
+ pandas-stubs>=2.0.0
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
- snowflake-connector-python
2
- tensorflow
3
- numpy
4
- pandas
 
 
 
1
+ streamlit>=1.28.0
2
+ snowflake-connector-python>=3.5.0
3
+ tensorflow>=2.15.0
4
+ numpy>=1.24.0
5
+ pandas>=2.0.0
6
+ pydantic>=2.5.0
scripts/compile_model.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """NBA game winner prediction model training script.
3
+
4
+ This script trains a neural network to predict game winners based on
5
+ team statistics. It uses RandomizedSearchCV to find optimal hyperparameters.
6
+
7
+ Usage:
8
+ python scripts/compile_model.py
9
+ """
10
+
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ from scikeras.wrappers import KerasClassifier
17
+ from sklearn.model_selection import RandomizedSearchCV, train_test_split
18
+ from tensorflow import keras
19
+ from tensorflow.keras import layers
20
+ from tensorflow.keras.losses import BinaryCrossentropy
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format="%(asctime)s - %(levelname)s - %(message)s",
26
+ )
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Data file paths
30
+ ROSTER_FILE = Path("player_stats.txt")
31
+ SCHEDULE_FILE = Path("schedule.txt")
32
+ OUTPUT_MODEL = Path("winner.keras")
33
+
34
+ # Feature columns from roster data
35
+ FEATURE_COLS: list[str] = [
36
+ "TEAM",
37
+ "PTS/G",
38
+ "ORB",
39
+ "DRB",
40
+ "AST",
41
+ "STL",
42
+ "BLK",
43
+ "TOV",
44
+ "3P%",
45
+ "FT%",
46
+ "2P",
47
+ ]
48
+
49
+ # Hyperparameter search space
50
+ OPTIMIZERS: list[str] = [
51
+ "SGD",
52
+ "RMSprop",
53
+ "Adagrad",
54
+ "Adadelta",
55
+ "Adam",
56
+ "Adamax",
57
+ "Nadam",
58
+ ]
59
+ INITIALIZERS: list[str] = [
60
+ "uniform",
61
+ "lecun_uniform",
62
+ "normal",
63
+ "zero",
64
+ "glorot_normal",
65
+ "glorot_uniform",
66
+ "he_normal",
67
+ "he_uniform",
68
+ ]
69
+ EPOCHS: list[int] = [500, 1000, 1500]
70
+ BATCH_SIZES: list[int] = [50, 100, 200]
71
+
72
+
73
+ def create_stats(
74
+ roster: pd.DataFrame, schedule: pd.DataFrame
75
+ ) -> list[np.ndarray]:
76
+ """Create feature arrays from roster and schedule data.
77
+
78
+ Args:
79
+ roster: DataFrame with player statistics
80
+ schedule: DataFrame with game schedule and scores
81
+
82
+ Returns:
83
+ List of numpy arrays, one per game with combined team stats
84
+ """
85
+ home_stats: list[list] = []
86
+ away_stats: list[list] = []
87
+ features: list[np.ndarray] = []
88
+
89
+ new_roster = roster[FEATURE_COLS]
90
+
91
+ # Get stats for each team in each game
92
+ for team in schedule["Home/Neutral"]:
93
+ home_stats.append(new_roster[new_roster["TEAM"] == team].values.tolist())
94
+
95
+ for team in schedule["Visitor/Neutral"]:
96
+ away_stats.append(new_roster[new_roster["TEAM"] == team].values.tolist())
97
+
98
+ # Combine home and away stats for each game
99
+ for i in range(len(home_stats)):
100
+ arr: list[float] = []
101
+
102
+ for j in range(len(home_stats[i])):
103
+ del home_stats[i][j][0] # Remove team name
104
+ arr.extend(home_stats[i][j])
105
+
106
+ for j in range(len(away_stats[i])):
107
+ del away_stats[i][j][0] # Remove team name
108
+ arr.extend(away_stats[i][j])
109
+
110
+ # Handle NaN values
111
+ features.append(np.nan_to_num(np.array(arr), copy=False))
112
+
113
+ return features
114
+
115
+
116
+ def create_model(
117
+ optimizer: str = "rmsprop", init: str = "glorot_uniform"
118
+ ) -> keras.Model:
119
+ """Create the neural network model architecture.
120
+
121
+ Args:
122
+ optimizer: Optimizer name
123
+ init: Weight initializer name
124
+
125
+ Returns:
126
+ Compiled Keras model
127
+ """
128
+ inputs = keras.Input(shape=(100,))
129
+ x = layers.Dense(50, activation="relu", kernel_initializer=init)(inputs)
130
+ x = layers.Dense(64, activation="relu", kernel_initializer=init)(x)
131
+ outputs = layers.Dense(1, activation="sigmoid")(x)
132
+
133
+ model = keras.Model(inputs=inputs, outputs=outputs, name="nba_model")
134
+ model.compile(
135
+ loss=BinaryCrossentropy(from_logits=False),
136
+ optimizer=optimizer,
137
+ metrics=["accuracy"],
138
+ )
139
+
140
+ return model
141
+
142
+
143
+ def train_model(
144
+ x_train: np.ndarray,
145
+ y_train: np.ndarray,
146
+ x_test: np.ndarray,
147
+ y_test: np.ndarray,
148
+ n_iterations: int = 100,
149
+ ) -> tuple[keras.Model, dict, float]:
150
+ """Train model with hyperparameter search.
151
+
152
+ Args:
153
+ x_train: Training features
154
+ y_train: Training labels
155
+ x_test: Test features
156
+ y_test: Test labels
157
+ n_iterations: Number of random search iterations
158
+
159
+ Returns:
160
+ Tuple of (best_model, best_params, test_accuracy)
161
+ """
162
+ model = KerasClassifier(
163
+ model=create_model,
164
+ verbose=0,
165
+ init="glorot_uniform",
166
+ )
167
+
168
+ param_grid = {
169
+ "optimizer": OPTIMIZERS,
170
+ "epochs": EPOCHS,
171
+ "batch_size": BATCH_SIZES,
172
+ "init": INITIALIZERS,
173
+ }
174
+
175
+ logger.info(f"Starting randomized search with {n_iterations} iterations")
176
+
177
+ random_search = RandomizedSearchCV(
178
+ estimator=model,
179
+ param_distributions=param_grid,
180
+ n_iter=n_iterations,
181
+ verbose=3,
182
+ )
183
+
184
+ random_search_result = random_search.fit(x_train, y_train)
185
+
186
+ best_model = random_search_result.best_estimator_
187
+ best_params = random_search_result.best_params_
188
+ test_accuracy = best_model.score(x_test, y_test)
189
+
190
+ return best_model.model_, best_params, test_accuracy
191
+
192
+
193
+ def main() -> None:
194
+ """Main training pipeline."""
195
+ logger.info("Loading data files")
196
+
197
+ if not ROSTER_FILE.exists():
198
+ logger.error(f"Roster file not found: {ROSTER_FILE}")
199
+ raise FileNotFoundError(f"Missing {ROSTER_FILE}")
200
+
201
+ if not SCHEDULE_FILE.exists():
202
+ logger.error(f"Schedule file not found: {SCHEDULE_FILE}")
203
+ raise FileNotFoundError(f"Missing {SCHEDULE_FILE}")
204
+
205
+ roster = pd.read_csv(ROSTER_FILE, delimiter=",")
206
+ schedule = pd.read_csv(SCHEDULE_FILE, delimiter=",")
207
+
208
+ logger.info(f"Loaded {len(roster)} players and {len(schedule)} games")
209
+
210
+ # Create target variable: 0 = home wins, 1 = away wins
211
+ schedule["winner"] = schedule.apply(
212
+ lambda x: 0 if x["PTS"] > x["PTS.1"] else 1, axis=1
213
+ )
214
+
215
+ # Create feature arrays
216
+ logger.info("Creating feature arrays")
217
+ X = np.array(create_stats(roster, schedule))
218
+ y = np.array(schedule["winner"])
219
+
220
+ logger.info(f"Feature shape: {X.shape}, Target shape: {y.shape}")
221
+
222
+ # Split data
223
+ X_train, X_test, y_train, y_test = train_test_split(
224
+ X, y, test_size=0.2, random_state=42
225
+ )
226
+
227
+ logger.info(f"Train size: {len(X_train)}, Test size: {len(X_test)}")
228
+
229
+ # Train model
230
+ best_model, best_params, test_accuracy = train_model(
231
+ X_train, y_train, X_test, y_test
232
+ )
233
+
234
+ # Save model
235
+ logger.info(f"Saving model to {OUTPUT_MODEL}")
236
+ best_model.save(OUTPUT_MODEL)
237
+
238
+ logger.info(f"Best parameters: {best_params}")
239
+ logger.info(f"Test accuracy: {test_accuracy:.4f}")
240
+
241
+
242
+ if __name__ == "__main__":
243
+ main()
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """NBA Streamlit application source package."""
src/config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application configuration, constants, and logging setup."""
2
+
3
+ import logging
4
+ from typing import Final
5
+
6
+ # Database column names for player data
7
+ PLAYER_COLUMNS: Final[list[str]] = [
8
+ "FULL_NAME",
9
+ "AST",
10
+ "BLK",
11
+ "DREB",
12
+ "FG3A",
13
+ "FG3M",
14
+ "FG3_PCT",
15
+ "FGA",
16
+ "FGM",
17
+ "FG_PCT",
18
+ "FTA",
19
+ "FTM",
20
+ "FT_PCT",
21
+ "GP",
22
+ "GS",
23
+ "MIN",
24
+ "OREB",
25
+ "PF",
26
+ "PTS",
27
+ "REB",
28
+ "STL",
29
+ "TOV",
30
+ "FIRST_NAME",
31
+ "LAST_NAME",
32
+ "FULL_NAME_LOWER",
33
+ "FIRST_NAME_LOWER",
34
+ "LAST_NAME_LOWER",
35
+ "IS_ACTIVE",
36
+ ]
37
+
38
+ # Columns used for ML model features
39
+ STAT_COLUMNS: Final[list[str]] = [
40
+ "PTS",
41
+ "OREB",
42
+ "DREB",
43
+ "AST",
44
+ "STL",
45
+ "BLK",
46
+ "TOV",
47
+ "FG3_PCT",
48
+ "FT_PCT",
49
+ "FGM",
50
+ ]
51
+
52
+ # Game configuration
53
+ TEAM_SIZE: Final[int] = 5
54
+ MAX_QUERY_ATTEMPTS: Final[int] = 10
55
+
56
+ # Difficulty presets: (PTS, REB, AST, STL)
57
+ DIFFICULTY_PRESETS: Final[dict[str, tuple[int, int, int, int]]] = {
58
+ "Regular": (850, 400, 200, 60),
59
+ "93' Bulls": (1050, 500, 300, 80),
60
+ "All-Stars": (1250, 600, 400, 100),
61
+ "Dream Team": (1450, 700, 500, 120),
62
+ }
63
+
64
+ # Score ranges for game simulation
65
+ WINNER_SCORE_RANGE: Final[tuple[int, int]] = (90, 130)
66
+ LOSER_SCORE_RANGE: Final[tuple[int, int]] = (80, 120)
67
+
68
+ # Default fallback scores when generation fails
69
+ DEFAULT_WINNER_SCORE: Final[int] = 100
70
+ DEFAULT_LOSER_SCORE: Final[int] = 90
71
+
72
+
73
+ def setup_logging(level: int = logging.INFO) -> logging.Logger:
74
+ """Configure and return the application logger.
75
+
76
+ Args:
77
+ level: Logging level (default: INFO)
78
+
79
+ Returns:
80
+ Configured logger instance
81
+ """
82
+ logging.basicConfig(
83
+ level=level,
84
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
85
+ datefmt="%Y-%m-%d %H:%M:%S",
86
+ )
87
+ logger = logging.getLogger("streamlit_nba")
88
+ logger.setLevel(level)
89
+ return logger
90
+
91
+
92
+ # Module-level logger instance
93
+ logger: Final[logging.Logger] = setup_logging()
src/database/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Database module for connection management and queries."""
2
+
3
+ from src.database.connection import (
4
+ get_connection,
5
+ DatabaseConnectionError,
6
+ QueryExecutionError,
7
+ )
8
+ from src.database.queries import (
9
+ search_player_by_name,
10
+ get_player_by_full_name,
11
+ get_players_by_full_names,
12
+ get_away_team_by_stats,
13
+ )
14
+
15
+ __all__ = [
16
+ "get_connection",
17
+ "DatabaseConnectionError",
18
+ "QueryExecutionError",
19
+ "search_player_by_name",
20
+ "get_player_by_full_name",
21
+ "get_players_by_full_names",
22
+ "get_away_team_by_stats",
23
+ ]
src/database/connection.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Database connection management with error handling."""
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from typing import Generator
6
+
7
+ import snowflake.connector
8
+ import streamlit as st
9
+ from snowflake.connector import SnowflakeConnection
10
+ from snowflake.connector.errors import DatabaseError, ProgrammingError
11
+
12
+ logger = logging.getLogger("streamlit_nba")
13
+
14
+
15
+ class DatabaseConnectionError(Exception):
16
+ """Raised when database connection fails."""
17
+
18
+ pass
19
+
20
+
21
+ class QueryExecutionError(Exception):
22
+ """Raised when query execution fails."""
23
+
24
+ pass
25
+
26
+
27
+ @st.cache_resource
28
+ def _get_connection_pool() -> SnowflakeConnection:
29
+ """Create and cache a Snowflake connection.
30
+
31
+ Returns:
32
+ Cached Snowflake connection
33
+
34
+ Raises:
35
+ DatabaseConnectionError: If connection cannot be established
36
+ """
37
+ try:
38
+ return snowflake.connector.connect(**st.secrets["snowflake"])
39
+ except DatabaseError as e:
40
+ logger.error(f"Failed to connect to database: {e}")
41
+ raise DatabaseConnectionError(f"Could not connect to database: {e}") from e
42
+ except KeyError as e:
43
+ logger.error("Snowflake credentials not found in secrets")
44
+ raise DatabaseConnectionError(
45
+ "Database credentials not configured. Please check st.secrets."
46
+ ) from e
47
+
48
+
49
+ @contextmanager
50
+ def get_connection() -> Generator[SnowflakeConnection, None, None]:
51
+ """Context manager for database connections with error handling.
52
+
53
+ Yields:
54
+ Active Snowflake connection
55
+
56
+ Raises:
57
+ DatabaseConnectionError: If connection fails
58
+
59
+ Example:
60
+ with get_connection() as conn:
61
+ # use connection
62
+ """
63
+ try:
64
+ conn = snowflake.connector.connect(**st.secrets["snowflake"])
65
+ yield conn
66
+ except DatabaseError as e:
67
+ logger.error(f"Database connection error: {e}")
68
+ raise DatabaseConnectionError(f"Database connection failed: {e}") from e
69
+ except KeyError as e:
70
+ logger.error("Snowflake credentials not found in secrets")
71
+ raise DatabaseConnectionError(
72
+ "Database credentials not configured. Please check st.secrets."
73
+ ) from e
74
+ finally:
75
+ try:
76
+ conn.close()
77
+ except Exception:
78
+ pass # Connection may already be closed
79
+
80
+
81
+ def execute_query(
82
+ conn: SnowflakeConnection,
83
+ query: str,
84
+ params: tuple | list | None = None,
85
+ ) -> list[tuple]:
86
+ """Execute a parameterized query safely.
87
+
88
+ Args:
89
+ conn: Active database connection
90
+ query: SQL query with %s placeholders
91
+ params: Query parameters (optional)
92
+
93
+ Returns:
94
+ List of result tuples
95
+
96
+ Raises:
97
+ QueryExecutionError: If query execution fails
98
+ """
99
+ try:
100
+ with conn.cursor() as cur:
101
+ if params:
102
+ cur.execute(query, params)
103
+ else:
104
+ cur.execute(query)
105
+ return cur.fetchall()
106
+ except ProgrammingError as e:
107
+ logger.error(f"Query execution error: {e}")
108
+ raise QueryExecutionError(f"Query failed: {e}") from e
109
+ except DatabaseError as e:
110
+ logger.error(f"Database error during query: {e}")
111
+ raise QueryExecutionError(f"Database error: {e}") from e
src/database/queries.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Parameterized database queries for player data."""
2
+
3
+ import logging
4
+ from typing import Any
5
+
6
+ import pandas as pd
7
+ from snowflake.connector import SnowflakeConnection
8
+
9
+ from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
10
+ from src.database.connection import QueryExecutionError, execute_query
11
+
12
+ logger = logging.getLogger("streamlit_nba")
13
+
14
+
15
+ def search_player_by_name(conn: SnowflakeConnection, name: str) -> list[tuple[str]]:
16
+ """Search for players by name (first, last, or full name).
17
+
18
+ Args:
19
+ conn: Active database connection
20
+ name: Search term (case-insensitive)
21
+
22
+ Returns:
23
+ List of tuples containing matching full names
24
+ """
25
+ name_lower = name.lower().strip()
26
+ query = """
27
+ SELECT full_name FROM NBA
28
+ WHERE full_name_lower = %s
29
+ OR first_name_lower = %s
30
+ OR last_name_lower = %s
31
+ """
32
+ return execute_query(conn, query, (name_lower, name_lower, name_lower))
33
+
34
+
35
+ def get_player_by_full_name(
36
+ conn: SnowflakeConnection, full_name: str
37
+ ) -> tuple[Any, ...] | None:
38
+ """Get a single player's full record by exact name match.
39
+
40
+ Args:
41
+ conn: Active database connection
42
+ full_name: Exact full name of player
43
+
44
+ Returns:
45
+ Player data tuple or None if not found
46
+ """
47
+ query = "SELECT * FROM NBA WHERE FULL_NAME = %s"
48
+ results = execute_query(conn, query, (full_name,))
49
+ return results[0] if results else None
50
+
51
+
52
+ def get_players_by_full_names(
53
+ conn: SnowflakeConnection, names: list[str]
54
+ ) -> pd.DataFrame:
55
+ """Get multiple players' records in a single batch query.
56
+
57
+ This fixes the N+1 query problem by using a single IN clause
58
+ instead of multiple individual queries.
59
+
60
+ Args:
61
+ conn: Active database connection
62
+ names: List of exact full names
63
+
64
+ Returns:
65
+ DataFrame with player data
66
+ """
67
+ if not names:
68
+ return pd.DataFrame(columns=PLAYER_COLUMNS)
69
+
70
+ # Build parameterized IN clause
71
+ placeholders = ", ".join(["%s"] * len(names))
72
+ query = f"SELECT * FROM NBA WHERE FULL_NAME IN ({placeholders})"
73
+
74
+ results = execute_query(conn, query, tuple(names))
75
+ return pd.DataFrame(results, columns=PLAYER_COLUMNS)
76
+
77
+
78
+ def get_away_team_by_stats(
79
+ conn: SnowflakeConnection,
80
+ pts_threshold: int,
81
+ reb_threshold: int,
82
+ ast_threshold: int,
83
+ stl_threshold: int,
84
+ max_attempts: int = MAX_QUERY_ATTEMPTS,
85
+ ) -> pd.DataFrame:
86
+ """Get a random away team based on stat thresholds.
87
+
88
+ Uses UNION with SAMPLE to get diverse players meeting stat criteria.
89
+ Includes a max_attempts guard to prevent infinite loops.
90
+
91
+ Args:
92
+ conn: Active database connection
93
+ pts_threshold: Minimum career points
94
+ reb_threshold: Minimum career rebounds
95
+ ast_threshold: Minimum career assists
96
+ stl_threshold: Minimum career steals
97
+ max_attempts: Maximum query attempts before raising error
98
+
99
+ Returns:
100
+ DataFrame with 5 players
101
+
102
+ Raises:
103
+ QueryExecutionError: If unable to get 5 players within max_attempts
104
+ """
105
+ query = """
106
+ SELECT * FROM (SELECT * FROM NBA WHERE PTS > %s) SAMPLE (2 ROWS)
107
+ UNION
108
+ SELECT * FROM (SELECT * FROM NBA WHERE REB > %s) SAMPLE (1 ROWS)
109
+ UNION
110
+ SELECT * FROM (SELECT * FROM NBA WHERE AST > %s) SAMPLE (1 ROWS)
111
+ UNION
112
+ SELECT * FROM (SELECT * FROM NBA WHERE STL > %s) SAMPLE (1 ROWS)
113
+ """
114
+ params = (pts_threshold, reb_threshold, ast_threshold, stl_threshold)
115
+
116
+ for attempt in range(max_attempts):
117
+ results = execute_query(conn, query, params)
118
+ if len(results) == 5:
119
+ logger.info(f"Got away team on attempt {attempt + 1}")
120
+ return pd.DataFrame(results, columns=PLAYER_COLUMNS)
121
+ logger.debug(f"Attempt {attempt + 1}: got {len(results)} players, need 5")
122
+
123
+ # Fallback: if we can't get exactly 5, raise an error
124
+ raise QueryExecutionError(
125
+ f"Could not generate away team with 5 players after {max_attempts} attempts. "
126
+ f"Last attempt returned {len(results)} players."
127
+ )
src/ml/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Machine learning module for game prediction."""
2
+
3
+ from src.ml.model import (
4
+ ModelLoadError,
5
+ analyze_team_stats,
6
+ get_winner_model,
7
+ predict_winner,
8
+ )
9
+
10
+ __all__ = [
11
+ "ModelLoadError",
12
+ "analyze_team_stats",
13
+ "get_winner_model",
14
+ "predict_winner",
15
+ ]
src/ml/model.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Machine learning model loading and prediction with caching."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import streamlit as st
8
+ from tensorflow.keras.models import Model, load_model
9
+
10
+ logger = logging.getLogger("streamlit_nba")
11
+
12
+ # Default model path
13
+ DEFAULT_MODEL_PATH = Path("winner.keras")
14
+
15
+
16
+ class ModelLoadError(Exception):
17
+ """Raised when model loading fails."""
18
+
19
+ pass
20
+
21
+
22
+ @st.cache_resource
23
+ def get_winner_model(model_path: str | Path = DEFAULT_MODEL_PATH) -> Model:
24
+ """Load and cache the winner prediction model.
25
+
26
+ Uses Streamlit's cache_resource to ensure model is only loaded once
27
+ per session, improving performance significantly.
28
+
29
+ Args:
30
+ model_path: Path to the Keras model file
31
+
32
+ Returns:
33
+ Loaded Keras model
34
+
35
+ Raises:
36
+ ModelLoadError: If model cannot be loaded
37
+ """
38
+ path = Path(model_path)
39
+ if not path.exists():
40
+ logger.error(f"Model file not found: {path}")
41
+ raise ModelLoadError(f"Model file not found: {path}")
42
+
43
+ try:
44
+ logger.info(f"Loading model from {path}")
45
+ model = load_model(str(path))
46
+ logger.info("Model loaded successfully")
47
+ return model
48
+ except Exception as e:
49
+ logger.error(f"Failed to load model: {e}")
50
+ raise ModelLoadError(f"Failed to load model: {e}") from e
51
+
52
+
53
+ def predict_winner(combined_stats: np.ndarray) -> tuple[float, int]:
54
+ """Predict game winner from combined team stats.
55
+
56
+ Args:
57
+ combined_stats: Numpy array of shape (1, 100) containing
58
+ home team stats followed by away team stats
59
+
60
+ Returns:
61
+ Tuple of (probability, prediction) where:
62
+ - probability: Float between 0-1 (sigmoid output)
63
+ - prediction: 0 (away wins) or 1 (home wins)
64
+
65
+ Raises:
66
+ ModelLoadError: If model cannot be loaded
67
+ ValueError: If input shape is invalid
68
+ """
69
+ if combined_stats.shape != (1, 100):
70
+ raise ValueError(
71
+ f"Expected input shape (1, 100), got {combined_stats.shape}"
72
+ )
73
+
74
+ model = get_winner_model()
75
+ sigmoid_output = model.predict(combined_stats, verbose=0)
76
+ probability = float(sigmoid_output[0][0])
77
+ prediction = int(np.round(probability))
78
+
79
+ logger.info(f"Prediction: probability={probability:.4f}, winner={prediction}")
80
+ return probability, prediction
81
+
82
+
83
+ def analyze_team_stats(
84
+ home_stats: list[list[float]], away_stats: list[list[float]]
85
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
86
+ """Prepare team stats for model prediction.
87
+
88
+ Flattens per-player stats into team-level arrays suitable for
89
+ the prediction model.
90
+
91
+ Args:
92
+ home_stats: List of stat lists for each home player
93
+ away_stats: List of stat lists for each away player
94
+
95
+ Returns:
96
+ Tuple of (home_array, away_array, combined_array) where:
97
+ - home_array: Shape (1, 50) - home team flattened stats
98
+ - away_array: Shape (1, 50) - away team flattened stats
99
+ - combined_array: Shape (1, 100) - both teams for prediction
100
+ """
101
+ home_flat: list[float] = []
102
+ away_flat: list[float] = []
103
+
104
+ for player_stats in home_stats:
105
+ home_flat.extend(player_stats)
106
+
107
+ for player_stats in away_stats:
108
+ away_flat.extend(player_stats)
109
+
110
+ home_array = np.array(home_flat).reshape(1, -1)
111
+ away_array = np.array(away_flat).reshape(1, -1)
112
+ combined_array = np.array(home_flat + away_flat).reshape(1, -1)
113
+
114
+ return home_array, away_array, combined_array
src/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Pydantic models for data validation."""
2
+
3
+ from src.models.player import PlayerStats, DifficultySettings
4
+
5
+ __all__ = ["PlayerStats", "DifficultySettings"]
src/models/player.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for player and game data."""
2
+
3
+ from typing import ClassVar
4
+
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+ from src.config import DIFFICULTY_PRESETS
8
+
9
+
10
+ class PlayerStats(BaseModel):
11
+ """Model representing a player's career statistics."""
12
+
13
+ full_name: str = Field(..., min_length=1, max_length=100)
14
+ ast: int = Field(..., ge=0, description="Career assists")
15
+ blk: int = Field(..., ge=0, description="Career blocks")
16
+ dreb: int = Field(..., ge=0, description="Career defensive rebounds")
17
+ fg3a: int = Field(..., ge=0, description="Career 3-point attempts")
18
+ fg3m: int = Field(..., ge=0, description="Career 3-pointers made")
19
+ fg3_pct: float = Field(..., ge=0.0, le=1.0, description="3-point percentage")
20
+ fga: int = Field(..., ge=0, description="Career field goal attempts")
21
+ fgm: int = Field(..., ge=0, description="Career field goals made")
22
+ fg_pct: float = Field(..., ge=0.0, le=1.0, description="Field goal percentage")
23
+ fta: int = Field(..., ge=0, description="Career free throw attempts")
24
+ ftm: int = Field(..., ge=0, description="Career free throws made")
25
+ ft_pct: float = Field(..., ge=0.0, le=1.0, description="Free throw percentage")
26
+ gp: int = Field(..., ge=0, description="Games played")
27
+ gs: int = Field(..., ge=0, description="Games started")
28
+ min: int = Field(..., ge=0, description="Career minutes")
29
+ oreb: int = Field(..., ge=0, description="Career offensive rebounds")
30
+ pf: int = Field(..., ge=0, description="Career personal fouls")
31
+ pts: int = Field(..., ge=0, description="Career points")
32
+ reb: int = Field(..., ge=0, description="Career rebounds")
33
+ stl: int = Field(..., ge=0, description="Career steals")
34
+ tov: int = Field(..., ge=0, description="Career turnovers")
35
+ first_name: str = Field(..., max_length=50)
36
+ last_name: str = Field(..., max_length=50)
37
+ full_name_lower: str = Field(..., max_length=100)
38
+ first_name_lower: str = Field(..., max_length=50)
39
+ last_name_lower: str = Field(..., max_length=50)
40
+ is_active: bool = Field(default=False)
41
+
42
+ @classmethod
43
+ def from_db_row(cls, row: tuple) -> "PlayerStats":
44
+ """Create PlayerStats from a database row tuple.
45
+
46
+ Args:
47
+ row: Database row tuple in PLAYER_COLUMNS order
48
+
49
+ Returns:
50
+ PlayerStats instance
51
+ """
52
+ return cls(
53
+ full_name=row[0],
54
+ ast=row[1],
55
+ blk=row[2],
56
+ dreb=row[3],
57
+ fg3a=row[4],
58
+ fg3m=row[5],
59
+ fg3_pct=row[6] or 0.0,
60
+ fga=row[7],
61
+ fgm=row[8],
62
+ fg_pct=row[9] or 0.0,
63
+ fta=row[10],
64
+ ftm=row[11],
65
+ ft_pct=row[12] or 0.0,
66
+ gp=row[13],
67
+ gs=row[14],
68
+ min=row[15],
69
+ oreb=row[16],
70
+ pf=row[17],
71
+ pts=row[18],
72
+ reb=row[19],
73
+ stl=row[20],
74
+ tov=row[21],
75
+ first_name=row[22],
76
+ last_name=row[23],
77
+ full_name_lower=row[24],
78
+ first_name_lower=row[25],
79
+ last_name_lower=row[26],
80
+ is_active=bool(row[27]) if row[27] is not None else False,
81
+ )
82
+
83
+
84
+ class DifficultySettings(BaseModel):
85
+ """Model for game difficulty settings."""
86
+
87
+ VALID_PRESETS: ClassVar[set[str]] = set(DIFFICULTY_PRESETS.keys())
88
+
89
+ name: str = Field(..., min_length=1)
90
+ pts_threshold: int = Field(..., ge=0, description="Minimum career points")
91
+ reb_threshold: int = Field(..., ge=0, description="Minimum career rebounds")
92
+ ast_threshold: int = Field(..., ge=0, description="Minimum career assists")
93
+ stl_threshold: int = Field(..., ge=0, description="Minimum career steals")
94
+
95
+ @field_validator("name")
96
+ @classmethod
97
+ def validate_preset_name(cls, v: str) -> str:
98
+ """Validate that preset name is recognized."""
99
+ if v not in cls.VALID_PRESETS:
100
+ raise ValueError(
101
+ f"Unknown difficulty preset: {v}. "
102
+ f"Valid options: {', '.join(sorted(cls.VALID_PRESETS))}"
103
+ )
104
+ return v
105
+
106
+ @classmethod
107
+ def from_preset(cls, preset_name: str) -> "DifficultySettings":
108
+ """Create DifficultySettings from a named preset.
109
+
110
+ Args:
111
+ preset_name: Name of difficulty preset (e.g., "Regular", "Dream Team")
112
+
113
+ Returns:
114
+ DifficultySettings instance
115
+
116
+ Raises:
117
+ ValueError: If preset_name is not valid
118
+ """
119
+ if preset_name not in DIFFICULTY_PRESETS:
120
+ raise ValueError(
121
+ f"Unknown difficulty preset: {preset_name}. "
122
+ f"Valid options: {', '.join(sorted(DIFFICULTY_PRESETS.keys()))}"
123
+ )
124
+ pts, reb, ast, stl = DIFFICULTY_PRESETS[preset_name]
125
+ return cls(
126
+ name=preset_name,
127
+ pts_threshold=pts,
128
+ reb_threshold=reb,
129
+ ast_threshold=ast,
130
+ stl_threshold=stl,
131
+ )
132
+
133
+ def as_tuple(self) -> tuple[int, int, int, int]:
134
+ """Return thresholds as tuple for backward compatibility.
135
+
136
+ Returns:
137
+ Tuple of (pts, reb, ast, stl) thresholds
138
+ """
139
+ return (
140
+ self.pts_threshold,
141
+ self.reb_threshold,
142
+ self.ast_threshold,
143
+ self.stl_threshold,
144
+ )
src/state/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Session state management module."""
2
+
3
+ from src.state.session import GameState, init_session_state, get_away_stats
4
+
5
+ __all__ = ["GameState", "init_session_state", "get_away_stats"]
src/state/session.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session state management for the Streamlit application."""
2
+
3
+ import logging
4
+ from dataclasses import dataclass, field
5
+
6
+ import pandas as pd
7
+ import streamlit as st
8
+
9
+ from src.config import DIFFICULTY_PRESETS
10
+
11
+ logger = logging.getLogger("streamlit_nba")
12
+
13
+ # Default difficulty preset
14
+ DEFAULT_DIFFICULTY = "Regular"
15
+
16
+
17
+ @dataclass
18
+ class GameState:
19
+ """Dataclass representing the game session state."""
20
+
21
+ home_team: list[str] = field(default_factory=list)
22
+ away_team: list[str] = field(default_factory=list)
23
+ away_stats: list[int] = field(
24
+ default_factory=lambda: list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY])
25
+ )
26
+ home_team_df: pd.DataFrame = field(default_factory=pd.DataFrame)
27
+ radio_index: int = 0
28
+
29
+
30
+ def init_session_state() -> None:
31
+ """Initialize all session state keys with safe defaults.
32
+
33
+ This should be called at the start of each page to ensure
34
+ all required state keys exist before access.
35
+ """
36
+ defaults = {
37
+ "home_team": [],
38
+ "away_team": [],
39
+ "away_stats": list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY]),
40
+ "home_team_df": pd.DataFrame(),
41
+ "radio_index": 0,
42
+ }
43
+
44
+ for key, default_value in defaults.items():
45
+ if key not in st.session_state:
46
+ st.session_state[key] = default_value
47
+ logger.debug(f"Initialized session state: {key}")
48
+
49
+
50
+ def get_away_stats() -> list[int]:
51
+ """Safely get away team stats from session state.
52
+
53
+ Returns:
54
+ List of stat thresholds [pts, reb, ast, stl], or defaults if not set
55
+ """
56
+ init_session_state() # Ensure state is initialized
57
+ stats = st.session_state.get("away_stats")
58
+
59
+ if stats is None or not isinstance(stats, list) or len(stats) != 4:
60
+ logger.warning("Invalid away_stats in session, using defaults")
61
+ default_stats = list(DIFFICULTY_PRESETS[DEFAULT_DIFFICULTY])
62
+ st.session_state["away_stats"] = default_stats
63
+ return default_stats
64
+
65
+ return stats
66
+
67
+
68
+ def get_home_team_df() -> pd.DataFrame:
69
+ """Safely get home team DataFrame from session state.
70
+
71
+ Returns:
72
+ DataFrame with home team player data, or empty DataFrame if not set
73
+ """
74
+ init_session_state()
75
+ df = st.session_state.get("home_team_df")
76
+
77
+ if df is None or not isinstance(df, pd.DataFrame):
78
+ logger.warning("Invalid home_team_df in session, using empty DataFrame")
79
+ return pd.DataFrame()
80
+
81
+ return df
82
+
83
+
84
+ def get_home_team_names() -> list[str]:
85
+ """Safely get home team player names from session state.
86
+
87
+ Returns:
88
+ List of player names on home team
89
+ """
90
+ init_session_state()
91
+ team = st.session_state.get("home_team")
92
+
93
+ if team is None or not isinstance(team, list):
94
+ return []
95
+
96
+ return team
97
+
98
+
99
+ def set_difficulty(preset_name: str) -> None:
100
+ """Set the difficulty level by preset name.
101
+
102
+ Args:
103
+ preset_name: Name of difficulty preset
104
+ """
105
+ if preset_name not in DIFFICULTY_PRESETS:
106
+ logger.error(f"Invalid difficulty preset: {preset_name}")
107
+ return
108
+
109
+ index = list(DIFFICULTY_PRESETS.keys()).index(preset_name)
110
+ st.session_state["away_stats"] = list(DIFFICULTY_PRESETS[preset_name])
111
+ st.session_state["radio_index"] = index
112
+ logger.info(f"Set difficulty to {preset_name}")
113
+
114
+
115
+ def add_player_to_team(player_name: str) -> bool:
116
+ """Add a player to the home team.
117
+
118
+ Args:
119
+ player_name: Full name of player to add
120
+
121
+ Returns:
122
+ True if added, False if already on team or team is full
123
+ """
124
+ init_session_state()
125
+ team = st.session_state.get("home_team", [])
126
+
127
+ if len(team) >= 5:
128
+ logger.warning("Cannot add player: team is full")
129
+ return False
130
+
131
+ if player_name in team:
132
+ logger.debug(f"Player {player_name} already on team")
133
+ return False
134
+
135
+ team.append(player_name)
136
+ st.session_state["home_team"] = team
137
+ logger.info(f"Added {player_name} to team")
138
+ return True
139
+
140
+
141
+ def remove_player_from_team(player_name: str) -> bool:
142
+ """Remove a player from the home team.
143
+
144
+ Args:
145
+ player_name: Full name of player to remove
146
+
147
+ Returns:
148
+ True if removed, False if not on team
149
+ """
150
+ init_session_state()
151
+ team = st.session_state.get("home_team", [])
152
+
153
+ if player_name not in team:
154
+ logger.debug(f"Player {player_name} not on team")
155
+ return False
156
+
157
+ team.remove(player_name)
158
+ st.session_state["home_team"] = team
159
+ logger.info(f"Removed {player_name} from team")
160
+ return True
src/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Utility functions for HTML sanitization and other helpers."""
2
+
3
+ from src.utils.html import safe_heading, safe_paragraph
4
+
5
+ __all__ = ["safe_heading", "safe_paragraph"]
src/utils/html.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HTML sanitization utilities for XSS protection."""
2
+
3
+ import html
4
+ from typing import Literal
5
+
6
+ import streamlit as st
7
+
8
+ # Valid heading levels
9
+ HeadingLevel = Literal[1, 2, 3, 4, 5, 6]
10
+
11
+
12
+ def escape_html(text: str) -> str:
13
+ """Escape HTML special characters to prevent XSS.
14
+
15
+ Args:
16
+ text: Raw text that may contain HTML
17
+
18
+ Returns:
19
+ Escaped text safe for HTML insertion
20
+ """
21
+ return html.escape(str(text))
22
+
23
+
24
+ def safe_heading(
25
+ text: str,
26
+ level: HeadingLevel = 1,
27
+ color: str = "steelblue",
28
+ align: str = "center",
29
+ ) -> None:
30
+ """Render a heading with escaped text to prevent XSS.
31
+
32
+ Args:
33
+ text: Heading text (will be escaped)
34
+ level: Heading level 1-6
35
+ color: CSS color value
36
+ align: CSS text-align value
37
+ """
38
+ # Escape all user-provided values
39
+ safe_text = escape_html(text)
40
+ safe_color = escape_html(color)
41
+ safe_align = escape_html(align)
42
+
43
+ st.markdown(
44
+ f"<h{level} style='text-align: {safe_align}; color: {safe_color};'>"
45
+ f"{safe_text}</h{level}>",
46
+ unsafe_allow_html=True,
47
+ )
48
+
49
+
50
+ def safe_paragraph(
51
+ text: str,
52
+ color: str = "white",
53
+ align: str = "center",
54
+ ) -> None:
55
+ """Render a paragraph with escaped text to prevent XSS.
56
+
57
+ Args:
58
+ text: Paragraph text (will be escaped)
59
+ color: CSS color value
60
+ align: CSS text-align value
61
+ """
62
+ safe_text = escape_html(text)
63
+ safe_color = escape_html(color)
64
+ safe_align = escape_html(align)
65
+
66
+ st.markdown(
67
+ f"<p style='text-align: {safe_align}; color: {safe_color};'>"
68
+ f"{safe_text}</p>",
69
+ unsafe_allow_html=True,
70
+ )
71
+
72
+
73
+ def safe_styled_text(
74
+ text: str,
75
+ tag: str = "span",
76
+ color: str | None = None,
77
+ align: str | None = None,
78
+ **styles: str,
79
+ ) -> str:
80
+ """Generate HTML string with escaped text and validated styles.
81
+
82
+ Args:
83
+ text: Text content (will be escaped)
84
+ tag: HTML tag to use
85
+ color: Optional CSS color
86
+ align: Optional CSS text-align
87
+ **styles: Additional CSS properties
88
+
89
+ Returns:
90
+ Safe HTML string
91
+ """
92
+ safe_text = escape_html(text)
93
+ safe_tag = escape_html(tag)
94
+
95
+ style_parts: list[str] = []
96
+ if color:
97
+ style_parts.append(f"color: {escape_html(color)}")
98
+ if align:
99
+ style_parts.append(f"text-align: {escape_html(align)}")
100
+ for prop, value in styles.items():
101
+ # Convert underscores to hyphens for CSS properties
102
+ css_prop = prop.replace("_", "-")
103
+ style_parts.append(f"{escape_html(css_prop)}: {escape_html(value)}")
104
+
105
+ style_str = "; ".join(style_parts)
106
+ if style_str:
107
+ return f"<{safe_tag} style='{style_str}'>{safe_text}</{safe_tag}>"
108
+ return f"<{safe_tag}>{safe_text}</{safe_tag}>"
src/validation/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Input validation module."""
2
+
3
+ from src.validation.inputs import (
4
+ PlayerSearchInput,
5
+ is_valid_search_term,
6
+ validate_search_term,
7
+ )
8
+
9
+ __all__ = ["PlayerSearchInput", "is_valid_search_term", "validate_search_term"]
src/validation/inputs.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Input validation for user-provided data."""
2
+
3
+ import re
4
+
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+ # Patterns that indicate SQL injection attempts
8
+ SQL_INJECTION_PATTERNS: list[str] = [
9
+ r"['\";]", # Quote characters and semicolons
10
+ r"--", # SQL comment
11
+ r"/\*", # Block comment start
12
+ r"\*/", # Block comment end
13
+ r"\bUNION\b", # UNION keyword
14
+ r"\bSELECT\b", # SELECT keyword
15
+ r"\bINSERT\b", # INSERT keyword
16
+ r"\bUPDATE\b", # UPDATE keyword
17
+ r"\bDELETE\b", # DELETE keyword
18
+ r"\bDROP\b", # DROP keyword
19
+ r"\bEXEC\b", # EXEC keyword
20
+ r"\bOR\s+\d+=\d+", # OR 1=1 pattern
21
+ r"\bAND\s+\d+=\d+", # AND 1=1 pattern
22
+ ]
23
+
24
+ # Compiled regex for efficiency
25
+ SQL_INJECTION_REGEX = re.compile(
26
+ "|".join(SQL_INJECTION_PATTERNS), re.IGNORECASE
27
+ )
28
+
29
+
30
+ class PlayerSearchInput(BaseModel):
31
+ """Validated player search input."""
32
+
33
+ search_term: str = Field(
34
+ ...,
35
+ min_length=1,
36
+ max_length=100,
37
+ description="Player name search term",
38
+ )
39
+
40
+ @field_validator("search_term")
41
+ @classmethod
42
+ def validate_no_sql_injection(cls, v: str) -> str:
43
+ """Reject inputs containing SQL injection patterns.
44
+
45
+ Args:
46
+ v: Input search term
47
+
48
+ Returns:
49
+ Validated search term
50
+
51
+ Raises:
52
+ ValueError: If SQL injection pattern detected
53
+ """
54
+ if SQL_INJECTION_REGEX.search(v):
55
+ raise ValueError(
56
+ "Invalid characters in search term. "
57
+ "Please use only letters, numbers, spaces, and hyphens."
58
+ )
59
+ return v.strip()
60
+
61
+ @field_validator("search_term")
62
+ @classmethod
63
+ def validate_reasonable_characters(cls, v: str) -> str:
64
+ """Ensure search term contains only reasonable characters.
65
+
66
+ Args:
67
+ v: Input search term
68
+
69
+ Returns:
70
+ Validated search term
71
+
72
+ Raises:
73
+ ValueError: If invalid characters found
74
+ """
75
+ # Allow letters, numbers, spaces, hyphens, periods, and apostrophes
76
+ # (e.g., "O'Neal", "J.R. Smith")
77
+ if not re.match(r"^[a-zA-Z0-9\s\-.']+$", v):
78
+ raise ValueError(
79
+ "Search term contains invalid characters. "
80
+ "Please use only letters, numbers, spaces, hyphens, "
81
+ "periods, and apostrophes."
82
+ )
83
+ return v
84
+
85
+
86
+ def validate_search_term(term: str) -> str | None:
87
+ """Validate a player search term.
88
+
89
+ Args:
90
+ term: Raw search input
91
+
92
+ Returns:
93
+ Validated and cleaned search term, or None if invalid
94
+ """
95
+ try:
96
+ validated = PlayerSearchInput(search_term=term)
97
+ return validated.search_term
98
+ except ValueError:
99
+ return None
100
+
101
+
102
+ def is_valid_search_term(term: str) -> bool:
103
+ """Check if a search term is valid without raising exceptions.
104
+
105
+ Args:
106
+ term: Raw search input
107
+
108
+ Returns:
109
+ True if valid, False otherwise
110
+ """
111
+ return validate_search_term(term) is not None
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Test suite for NBA Streamlit application."""
tests/conftest.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytest fixtures for NBA Streamlit application tests."""
2
+
3
+ from typing import Any
4
+ from unittest.mock import MagicMock
5
+
6
+ import pandas as pd
7
+ import pytest
8
+
9
+
10
+ @pytest.fixture
11
+ def mock_snowflake_connection() -> MagicMock:
12
+ """Create a mock Snowflake connection.
13
+
14
+ Returns:
15
+ Mock connection object that simulates Snowflake connection behavior
16
+ """
17
+ mock_conn = MagicMock()
18
+ mock_cursor = MagicMock()
19
+ mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
20
+ mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
21
+ return mock_conn
22
+
23
+
24
+ @pytest.fixture
25
+ def sample_player_data() -> list[tuple[Any, ...]]:
26
+ """Create sample player data matching database schema.
27
+
28
+ Returns:
29
+ List of tuples with sample player data
30
+ """
31
+ return [
32
+ (
33
+ "LeBron James", # FULL_NAME
34
+ 10141, # AST
35
+ 1107, # BLK
36
+ 5972, # DREB
37
+ 2891, # FG3A
38
+ 1043, # FG3M
39
+ 0.361, # FG3_PCT
40
+ 24856, # FGA
41
+ 12621, # FGM
42
+ 0.508, # FG_PCT
43
+ 11067, # FTA
44
+ 7938, # FTM
45
+ 0.717, # FT_PCT
46
+ 1421, # GP
47
+ 1421, # GS
48
+ 54218, # MIN
49
+ 1663, # OREB
50
+ 2159, # PF
51
+ 39223, # PTS
52
+ 10988, # REB
53
+ 2219, # STL
54
+ 5015, # TOV
55
+ "LeBron", # FIRST_NAME
56
+ "James", # LAST_NAME
57
+ "lebron james", # FULL_NAME_LOWER
58
+ "lebron", # FIRST_NAME_LOWER
59
+ "james", # LAST_NAME_LOWER
60
+ True, # IS_ACTIVE
61
+ ),
62
+ (
63
+ "Michael Jordan",
64
+ 5633,
65
+ 893,
66
+ 4578,
67
+ 1778,
68
+ 581,
69
+ 0.327,
70
+ 24537,
71
+ 12192,
72
+ 0.497,
73
+ 8772,
74
+ 7327,
75
+ 0.835,
76
+ 1072,
77
+ 1039,
78
+ 41011,
79
+ 1463,
80
+ 2783,
81
+ 32292,
82
+ 6672,
83
+ 2514,
84
+ 2924,
85
+ "Michael",
86
+ "Jordan",
87
+ "michael jordan",
88
+ "michael",
89
+ "jordan",
90
+ False,
91
+ ),
92
+ ]
93
+
94
+
95
+ @pytest.fixture
96
+ def sample_player_df(sample_player_data: list[tuple]) -> pd.DataFrame:
97
+ """Create sample player DataFrame.
98
+
99
+ Args:
100
+ sample_player_data: List of player tuples
101
+
102
+ Returns:
103
+ DataFrame with sample player data
104
+ """
105
+ from src.config import PLAYER_COLUMNS
106
+
107
+ return pd.DataFrame(sample_player_data, columns=PLAYER_COLUMNS)
108
+
109
+
110
+ @pytest.fixture
111
+ def sample_team_stats() -> list[list[float]]:
112
+ """Create sample team stats for ML model input.
113
+
114
+ Returns:
115
+ List of player stat lists (5 players x 10 stats)
116
+ """
117
+ return [
118
+ [1500.0, 100.0, 200.0, 300.0, 50.0, 30.0, 100.0, 0.35, 0.80, 500.0],
119
+ [1200.0, 80.0, 180.0, 250.0, 40.0, 25.0, 90.0, 0.38, 0.75, 450.0],
120
+ [1000.0, 60.0, 150.0, 200.0, 35.0, 20.0, 80.0, 0.40, 0.82, 400.0],
121
+ [800.0, 50.0, 120.0, 150.0, 30.0, 15.0, 70.0, 0.33, 0.78, 350.0],
122
+ [600.0, 40.0, 100.0, 100.0, 25.0, 10.0, 60.0, 0.36, 0.85, 300.0],
123
+ ]
tests/test_database.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for database module."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pandas as pd
6
+ import pytest
7
+
8
+ from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
9
+ from src.database.connection import QueryExecutionError
10
+ from src.database.queries import (
11
+ get_away_team_by_stats,
12
+ get_player_by_full_name,
13
+ get_players_by_full_names,
14
+ search_player_by_name,
15
+ )
16
+
17
+
18
+ class TestSearchPlayerByName:
19
+ """Tests for search_player_by_name function."""
20
+
21
+ def test_uses_parameterized_query(
22
+ self, mock_snowflake_connection: MagicMock
23
+ ) -> None:
24
+ """Verify parameterized queries are used (not string formatting)."""
25
+ mock_cursor = MagicMock()
26
+ mock_cursor.fetchall.return_value = [("LeBron James",)]
27
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
28
+ mock_cursor
29
+ )
30
+
31
+ search_player_by_name(mock_snowflake_connection, "james")
32
+
33
+ # Verify execute was called with params tuple, not string formatting
34
+ mock_cursor.execute.assert_called_once()
35
+ call_args = mock_cursor.execute.call_args
36
+ query = call_args[0][0]
37
+ params = call_args[0][1]
38
+
39
+ # Query should use %s placeholders
40
+ assert "%s" in query
41
+ # Should not contain the actual search term in the query string
42
+ assert "james" not in query.lower()
43
+ # Params should be a tuple with the search term
44
+ assert params == ("james", "james", "james")
45
+
46
+ def test_returns_list_of_tuples(
47
+ self, mock_snowflake_connection: MagicMock
48
+ ) -> None:
49
+ """Test that results are returned as list of tuples."""
50
+ mock_cursor = MagicMock()
51
+ mock_cursor.fetchall.return_value = [
52
+ ("LeBron James",),
53
+ ("James Harden",),
54
+ ]
55
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
56
+ mock_cursor
57
+ )
58
+
59
+ result = search_player_by_name(mock_snowflake_connection, "james")
60
+
61
+ assert result == [("LeBron James",), ("James Harden",)]
62
+
63
+
64
+ class TestGetPlayersByFullNames:
65
+ """Tests for get_players_by_full_names batch query."""
66
+
67
+ def test_single_query_for_multiple_names(
68
+ self, mock_snowflake_connection: MagicMock, sample_player_data: list
69
+ ) -> None:
70
+ """Verify batch query uses single IN clause instead of N queries."""
71
+ mock_cursor = MagicMock()
72
+ mock_cursor.fetchall.return_value = sample_player_data
73
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
74
+ mock_cursor
75
+ )
76
+
77
+ names = ["LeBron James", "Michael Jordan"]
78
+ get_players_by_full_names(mock_snowflake_connection, names)
79
+
80
+ # Should only execute one query
81
+ assert mock_cursor.execute.call_count == 1
82
+
83
+ call_args = mock_cursor.execute.call_args
84
+ query = call_args[0][0]
85
+ params = call_args[0][1]
86
+
87
+ # Query should have IN clause with placeholders
88
+ assert "IN" in query.upper()
89
+ assert "%s" in query
90
+ # Params should be tuple of names
91
+ assert params == ("LeBron James", "Michael Jordan")
92
+
93
+ def test_returns_dataframe(
94
+ self, mock_snowflake_connection: MagicMock, sample_player_data: list
95
+ ) -> None:
96
+ """Test that results are returned as DataFrame."""
97
+ mock_cursor = MagicMock()
98
+ mock_cursor.fetchall.return_value = sample_player_data
99
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
100
+ mock_cursor
101
+ )
102
+
103
+ result = get_players_by_full_names(
104
+ mock_snowflake_connection, ["LeBron James", "Michael Jordan"]
105
+ )
106
+
107
+ assert isinstance(result, pd.DataFrame)
108
+ assert list(result.columns) == PLAYER_COLUMNS
109
+ assert len(result) == 2
110
+
111
+ def test_empty_names_returns_empty_dataframe(
112
+ self, mock_snowflake_connection: MagicMock
113
+ ) -> None:
114
+ """Test that empty input returns empty DataFrame without query."""
115
+ mock_cursor = MagicMock()
116
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
117
+ mock_cursor
118
+ )
119
+
120
+ result = get_players_by_full_names(mock_snowflake_connection, [])
121
+
122
+ assert isinstance(result, pd.DataFrame)
123
+ assert result.empty
124
+ # Should not execute any query
125
+ mock_cursor.execute.assert_not_called()
126
+
127
+
128
+ class TestGetAwayTeamByStats:
129
+ """Tests for get_away_team_by_stats with max_attempts guard."""
130
+
131
+ def test_max_attempts_raises_error(
132
+ self, mock_snowflake_connection: MagicMock
133
+ ) -> None:
134
+ """Test that max_attempts limit prevents infinite loop."""
135
+ mock_cursor = MagicMock()
136
+ # Always return wrong number of players
137
+ mock_cursor.fetchall.return_value = [("Player1",), ("Player2",)]
138
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
139
+ mock_cursor
140
+ )
141
+
142
+ with pytest.raises(QueryExecutionError) as exc_info:
143
+ get_away_team_by_stats(
144
+ mock_snowflake_connection,
145
+ pts_threshold=1000,
146
+ reb_threshold=500,
147
+ ast_threshold=300,
148
+ stl_threshold=100,
149
+ max_attempts=3,
150
+ )
151
+
152
+ assert "3 attempts" in str(exc_info.value)
153
+ assert mock_cursor.execute.call_count == 3
154
+
155
+ def test_success_on_first_try(
156
+ self, mock_snowflake_connection: MagicMock, sample_player_data: list
157
+ ) -> None:
158
+ """Test successful query on first attempt."""
159
+ mock_cursor = MagicMock()
160
+ # Return exactly 5 players
161
+ mock_cursor.fetchall.return_value = sample_player_data * 3 # 6 players
162
+ mock_cursor.fetchall.return_value = [
163
+ sample_player_data[0],
164
+ sample_player_data[1],
165
+ sample_player_data[0],
166
+ sample_player_data[1],
167
+ sample_player_data[0],
168
+ ]
169
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
170
+ mock_cursor
171
+ )
172
+
173
+ result = get_away_team_by_stats(
174
+ mock_snowflake_connection,
175
+ pts_threshold=1000,
176
+ reb_threshold=500,
177
+ ast_threshold=300,
178
+ stl_threshold=100,
179
+ )
180
+
181
+ assert isinstance(result, pd.DataFrame)
182
+ assert len(result) == 5
183
+ # Should only need one query
184
+ assert mock_cursor.execute.call_count == 1
185
+
186
+ def test_uses_parameterized_query(
187
+ self, mock_snowflake_connection: MagicMock, sample_player_data: list
188
+ ) -> None:
189
+ """Verify parameterized queries are used for stat thresholds."""
190
+ mock_cursor = MagicMock()
191
+ mock_cursor.fetchall.return_value = [
192
+ sample_player_data[0],
193
+ sample_player_data[1],
194
+ sample_player_data[0],
195
+ sample_player_data[1],
196
+ sample_player_data[0],
197
+ ]
198
+ mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
199
+ mock_cursor
200
+ )
201
+
202
+ get_away_team_by_stats(
203
+ mock_snowflake_connection,
204
+ pts_threshold=1000,
205
+ reb_threshold=500,
206
+ ast_threshold=300,
207
+ stl_threshold=100,
208
+ )
209
+
210
+ call_args = mock_cursor.execute.call_args
211
+ query = call_args[0][0]
212
+ params = call_args[0][1]
213
+
214
+ # Query should use %s placeholders
215
+ assert "%s" in query
216
+ # Should not contain actual numbers in query
217
+ assert "1000" not in query
218
+ assert "500" not in query
219
+ # Params should be tuple of thresholds
220
+ assert params == (1000, 500, 300, 100)
tests/test_ml.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ML model module."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ from src.ml.model import ModelLoadError, analyze_team_stats, predict_winner
9
+
10
+
11
+ class TestAnalyzeTeamStats:
12
+ """Tests for analyze_team_stats function."""
13
+
14
+ def test_flattens_stats_correctly(
15
+ self, sample_team_stats: list[list[float]]
16
+ ) -> None:
17
+ """Test that team stats are flattened correctly."""
18
+ home_array, away_array, combined = analyze_team_stats(
19
+ sample_team_stats, sample_team_stats
20
+ )
21
+
22
+ # Each team has 5 players x 10 stats = 50 values
23
+ assert home_array.shape == (1, 50)
24
+ assert away_array.shape == (1, 50)
25
+ # Combined has both teams = 100 values
26
+ assert combined.shape == (1, 100)
27
+
28
+ def test_combined_contains_both_teams(
29
+ self, sample_team_stats: list[list[float]]
30
+ ) -> None:
31
+ """Test that combined array contains both teams' stats."""
32
+ home_stats = [[1.0, 2.0], [3.0, 4.0]] # 2 players, 2 stats each
33
+ away_stats = [[5.0, 6.0], [7.0, 8.0]]
34
+
35
+ home_array, away_array, combined = analyze_team_stats(
36
+ home_stats, away_stats
37
+ )
38
+
39
+ # Home should be first 4 values, away next 4
40
+ np.testing.assert_array_equal(
41
+ combined[0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
42
+ )
43
+
44
+
45
+ class TestPredictWinner:
46
+ """Tests for predict_winner function."""
47
+
48
+ @patch("src.ml.model.get_winner_model")
49
+ def test_returns_probability_and_prediction(
50
+ self, mock_get_model: MagicMock
51
+ ) -> None:
52
+ """Test that function returns (probability, prediction) tuple."""
53
+ mock_model = MagicMock()
54
+ mock_model.predict.return_value = np.array([[0.75]])
55
+ mock_get_model.return_value = mock_model
56
+
57
+ stats = np.random.rand(1, 100)
58
+ probability, prediction = predict_winner(stats)
59
+
60
+ assert isinstance(probability, float)
61
+ assert isinstance(prediction, int)
62
+ assert 0.0 <= probability <= 1.0
63
+ assert prediction in (0, 1)
64
+
65
+ @patch("src.ml.model.get_winner_model")
66
+ def test_high_probability_predicts_win(
67
+ self, mock_get_model: MagicMock
68
+ ) -> None:
69
+ """Test that high probability (>0.5) predicts home win (1)."""
70
+ mock_model = MagicMock()
71
+ mock_model.predict.return_value = np.array([[0.8]])
72
+ mock_get_model.return_value = mock_model
73
+
74
+ stats = np.random.rand(1, 100)
75
+ probability, prediction = predict_winner(stats)
76
+
77
+ assert probability == 0.8
78
+ assert prediction == 1
79
+
80
+ @patch("src.ml.model.get_winner_model")
81
+ def test_low_probability_predicts_loss(
82
+ self, mock_get_model: MagicMock
83
+ ) -> None:
84
+ """Test that low probability (<0.5) predicts home loss (0)."""
85
+ mock_model = MagicMock()
86
+ mock_model.predict.return_value = np.array([[0.3]])
87
+ mock_get_model.return_value = mock_model
88
+
89
+ stats = np.random.rand(1, 100)
90
+ probability, prediction = predict_winner(stats)
91
+
92
+ assert probability == 0.3
93
+ assert prediction == 0
94
+
95
+ @patch("src.ml.model.get_winner_model")
96
+ def test_invalid_shape_raises_error(
97
+ self, mock_get_model: MagicMock
98
+ ) -> None:
99
+ """Test that invalid input shape raises ValueError."""
100
+ mock_model = MagicMock()
101
+ mock_get_model.return_value = mock_model
102
+
103
+ # Wrong shape
104
+ stats = np.random.rand(1, 50)
105
+
106
+ with pytest.raises(ValueError) as exc_info:
107
+ predict_winner(stats)
108
+
109
+ assert "Expected input shape (1, 100)" in str(exc_info.value)
110
+
111
+ @patch("src.ml.model.get_winner_model")
112
+ def test_model_called_with_verbose_zero(
113
+ self, mock_get_model: MagicMock
114
+ ) -> None:
115
+ """Test that model.predict is called with verbose=0."""
116
+ mock_model = MagicMock()
117
+ mock_model.predict.return_value = np.array([[0.5]])
118
+ mock_get_model.return_value = mock_model
119
+
120
+ stats = np.random.rand(1, 100)
121
+ predict_winner(stats)
122
+
123
+ mock_model.predict.assert_called_once_with(stats, verbose=0)
124
+
125
+
126
+ class TestGetWinnerModel:
127
+ """Tests for get_winner_model caching."""
128
+
129
+ @patch("src.ml.model.load_model")
130
+ @patch("src.ml.model.Path")
131
+ def test_raises_error_for_missing_model(
132
+ self, mock_path: MagicMock, mock_load: MagicMock
133
+ ) -> None:
134
+ """Test that missing model file raises ModelLoadError."""
135
+ from src.ml.model import get_winner_model
136
+
137
+ # Clear the cache to ensure fresh test
138
+ get_winner_model.clear()
139
+
140
+ mock_path_instance = MagicMock()
141
+ mock_path_instance.exists.return_value = False
142
+ mock_path.return_value = mock_path_instance
143
+
144
+ with pytest.raises(ModelLoadError) as exc_info:
145
+ get_winner_model("nonexistent.keras")
146
+
147
+ assert "not found" in str(exc_info.value)
tests/test_models.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Pydantic models."""
2
+
3
+ import pytest
4
+
5
+ from src.config import DIFFICULTY_PRESETS
6
+ from src.models.player import DifficultySettings, PlayerStats
7
+
8
+
9
+ class TestPlayerStats:
10
+ """Tests for PlayerStats model."""
11
+
12
+ def test_from_db_row(self, sample_player_data: list) -> None:
13
+ """Test creating PlayerStats from database row tuple."""
14
+ row = sample_player_data[0] # LeBron James data
15
+
16
+ player = PlayerStats.from_db_row(row)
17
+
18
+ assert player.full_name == "LeBron James"
19
+ assert player.pts == 39223
20
+ assert player.ast == 10141
21
+ assert player.is_active is True
22
+
23
+ def test_validates_negative_stats(self) -> None:
24
+ """Test that negative stats are rejected."""
25
+ with pytest.raises(ValueError):
26
+ PlayerStats(
27
+ full_name="Test Player",
28
+ ast=-1, # Invalid
29
+ blk=0,
30
+ dreb=0,
31
+ fg3a=0,
32
+ fg3m=0,
33
+ fg3_pct=0.0,
34
+ fga=0,
35
+ fgm=0,
36
+ fg_pct=0.0,
37
+ fta=0,
38
+ ftm=0,
39
+ ft_pct=0.0,
40
+ gp=0,
41
+ gs=0,
42
+ min=0,
43
+ oreb=0,
44
+ pf=0,
45
+ pts=0,
46
+ reb=0,
47
+ stl=0,
48
+ tov=0,
49
+ first_name="Test",
50
+ last_name="Player",
51
+ full_name_lower="test player",
52
+ first_name_lower="test",
53
+ last_name_lower="player",
54
+ is_active=True,
55
+ )
56
+
57
+ def test_validates_percentage_range(self) -> None:
58
+ """Test that percentages must be 0-1."""
59
+ with pytest.raises(ValueError):
60
+ PlayerStats(
61
+ full_name="Test Player",
62
+ ast=0,
63
+ blk=0,
64
+ dreb=0,
65
+ fg3a=0,
66
+ fg3m=0,
67
+ fg3_pct=1.5, # Invalid - over 1.0
68
+ fga=0,
69
+ fgm=0,
70
+ fg_pct=0.0,
71
+ fta=0,
72
+ ftm=0,
73
+ ft_pct=0.0,
74
+ gp=0,
75
+ gs=0,
76
+ min=0,
77
+ oreb=0,
78
+ pf=0,
79
+ pts=0,
80
+ reb=0,
81
+ stl=0,
82
+ tov=0,
83
+ first_name="Test",
84
+ last_name="Player",
85
+ full_name_lower="test player",
86
+ first_name_lower="test",
87
+ last_name_lower="player",
88
+ is_active=True,
89
+ )
90
+
91
+
92
+ class TestDifficultySettings:
93
+ """Tests for DifficultySettings model."""
94
+
95
+ @pytest.mark.parametrize("preset_name", list(DIFFICULTY_PRESETS.keys()))
96
+ def test_from_preset_valid(self, preset_name: str) -> None:
97
+ """Test creating DifficultySettings from valid presets."""
98
+ settings = DifficultySettings.from_preset(preset_name)
99
+
100
+ assert settings.name == preset_name
101
+ expected = DIFFICULTY_PRESETS[preset_name]
102
+ assert settings.pts_threshold == expected[0]
103
+ assert settings.reb_threshold == expected[1]
104
+ assert settings.ast_threshold == expected[2]
105
+ assert settings.stl_threshold == expected[3]
106
+
107
+ def test_from_preset_invalid(self) -> None:
108
+ """Test that invalid preset name raises ValueError."""
109
+ with pytest.raises(ValueError) as exc_info:
110
+ DifficultySettings.from_preset("Invalid Preset")
111
+
112
+ assert "Unknown difficulty preset" in str(exc_info.value)
113
+
114
+ def test_as_tuple(self) -> None:
115
+ """Test converting settings to tuple."""
116
+ settings = DifficultySettings.from_preset("Regular")
117
+
118
+ result = settings.as_tuple()
119
+
120
+ assert result == DIFFICULTY_PRESETS["Regular"]
121
+ assert isinstance(result, tuple)
122
+ assert len(result) == 4
123
+
124
+ def test_regular_preset_values(self) -> None:
125
+ """Test Regular preset has expected values."""
126
+ settings = DifficultySettings.from_preset("Regular")
127
+
128
+ assert settings.pts_threshold == 850
129
+ assert settings.reb_threshold == 400
130
+ assert settings.ast_threshold == 200
131
+ assert settings.stl_threshold == 60
132
+
133
+ def test_dream_team_preset_values(self) -> None:
134
+ """Test Dream Team preset has highest values."""
135
+ settings = DifficultySettings.from_preset("Dream Team")
136
+
137
+ assert settings.pts_threshold == 1450
138
+ assert settings.reb_threshold == 700
139
+ assert settings.ast_threshold == 500
140
+ assert settings.stl_threshold == 120
tests/test_validation.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for input validation module."""
2
+
3
+ import pytest
4
+
5
+ from src.validation.inputs import (
6
+ PlayerSearchInput,
7
+ is_valid_search_term,
8
+ validate_search_term,
9
+ )
10
+
11
+
12
+ class TestPlayerSearchInput:
13
+ """Tests for PlayerSearchInput validation."""
14
+
15
+ def test_valid_simple_name(self) -> None:
16
+ """Test valid simple name passes validation."""
17
+ result = PlayerSearchInput(search_term="James")
18
+ assert result.search_term == "James"
19
+
20
+ def test_valid_full_name(self) -> None:
21
+ """Test valid full name passes validation."""
22
+ result = PlayerSearchInput(search_term="LeBron James")
23
+ assert result.search_term == "LeBron James"
24
+
25
+ def test_valid_name_with_apostrophe(self) -> None:
26
+ """Test name with apostrophe passes validation."""
27
+ result = PlayerSearchInput(search_term="Shaquille O'Neal")
28
+ assert result.search_term == "Shaquille O'Neal"
29
+
30
+ def test_valid_name_with_period(self) -> None:
31
+ """Test name with period passes validation."""
32
+ result = PlayerSearchInput(search_term="J.R. Smith")
33
+ assert result.search_term == "J.R. Smith"
34
+
35
+ def test_valid_name_with_hyphen(self) -> None:
36
+ """Test name with hyphen passes validation."""
37
+ result = PlayerSearchInput(search_term="Kareem Abdul-Jabbar")
38
+ assert result.search_term == "Kareem Abdul-Jabbar"
39
+
40
+ def test_strips_whitespace(self) -> None:
41
+ """Test that whitespace is stripped."""
42
+ result = PlayerSearchInput(search_term=" James ")
43
+ assert result.search_term == "James"
44
+
45
+
46
+ class TestSqlInjectionRejection:
47
+ """Tests for SQL injection pattern rejection."""
48
+
49
+ @pytest.mark.parametrize(
50
+ "malicious_input",
51
+ [
52
+ "'; DROP TABLE NBA;--",
53
+ "James'; DELETE FROM NBA--",
54
+ "' OR '1'='1",
55
+ "James' UNION SELECT * FROM passwords--",
56
+ "James; SELECT * FROM users",
57
+ "/*comment*/James",
58
+ "James*/DROP TABLE/*",
59
+ "' OR 1=1--",
60
+ "James' AND 1=1--",
61
+ "Robert'); DROP TABLE Students;--",
62
+ ],
63
+ )
64
+ def test_rejects_sql_injection(self, malicious_input: str) -> None:
65
+ """Test that SQL injection patterns are rejected."""
66
+ with pytest.raises(ValueError) as exc_info:
67
+ PlayerSearchInput(search_term=malicious_input)
68
+
69
+ # Should mention invalid characters
70
+ assert "Invalid" in str(exc_info.value) or "invalid" in str(exc_info.value)
71
+
72
+ @pytest.mark.parametrize(
73
+ "invalid_input",
74
+ [
75
+ "James<script>",
76
+ "James&nbsp;",
77
+ "James@#$%",
78
+ "James\\nNewline",
79
+ "James\x00null",
80
+ ],
81
+ )
82
+ def test_rejects_special_characters(self, invalid_input: str) -> None:
83
+ """Test that special characters are rejected."""
84
+ with pytest.raises(ValueError):
85
+ PlayerSearchInput(search_term=invalid_input)
86
+
87
+ def test_rejects_empty_string(self) -> None:
88
+ """Test that empty string is rejected."""
89
+ with pytest.raises(ValueError):
90
+ PlayerSearchInput(search_term="")
91
+
92
+ def test_rejects_too_long(self) -> None:
93
+ """Test that overly long input is rejected."""
94
+ with pytest.raises(ValueError):
95
+ PlayerSearchInput(search_term="A" * 101)
96
+
97
+
98
+ class TestValidateSearchTerm:
99
+ """Tests for validate_search_term helper function."""
100
+
101
+ def test_returns_cleaned_term(self) -> None:
102
+ """Test that valid term is returned cleaned."""
103
+ result = validate_search_term(" James ")
104
+ assert result == "James"
105
+
106
+ def test_returns_none_for_invalid(self) -> None:
107
+ """Test that invalid input returns None."""
108
+ result = validate_search_term("'; DROP TABLE--")
109
+ assert result is None
110
+
111
+ def test_returns_none_for_empty(self) -> None:
112
+ """Test that empty input returns None."""
113
+ result = validate_search_term("")
114
+ assert result is None
115
+
116
+
117
+ class TestIsValidSearchTerm:
118
+ """Tests for is_valid_search_term helper function."""
119
+
120
+ def test_returns_true_for_valid(self) -> None:
121
+ """Test returns True for valid input."""
122
+ assert is_valid_search_term("James") is True
123
+ assert is_valid_search_term("LeBron James") is True
124
+ assert is_valid_search_term("O'Neal") is True
125
+
126
+ def test_returns_false_for_invalid(self) -> None:
127
+ """Test returns False for invalid input."""
128
+ assert is_valid_search_term("'; DROP--") is False
129
+ assert is_valid_search_term("") is False
130
+ assert is_valid_search_term("<script>") is False