Hatmanstack commited on
Commit
20852d6
·
1 Parent(s): 92a832f

Transition data layer to local CSV using pandas

Browse files
snowflake_nba.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/database/connection.py CHANGED
@@ -1,112 +1,71 @@
1
- """Database connection management with error handling."""
2
 
3
  import logging
4
  from collections.abc import Generator
5
  from contextlib import contextmanager
6
- from typing import Any
7
 
8
- import snowflake.connector
9
  import streamlit as st
10
- from snowflake.connector import SnowflakeConnection
11
- from snowflake.connector.errors import DatabaseError, ProgrammingError
12
 
13
  logger = logging.getLogger("streamlit_nba")
14
 
 
 
15
 
16
  class DatabaseConnectionError(Exception):
17
- """Raised when database connection fails."""
18
 
19
  pass
20
 
21
 
22
  class QueryExecutionError(Exception):
23
- """Raised when query execution fails."""
24
 
25
  pass
26
 
27
 
28
- @st.cache_resource
29
- def _get_connection_pool() -> SnowflakeConnection:
30
- """Create and cache a Snowflake connection.
31
 
32
  Returns:
33
- Cached Snowflake connection
34
 
35
  Raises:
36
- DatabaseConnectionError: If connection cannot be established
37
  """
 
 
 
 
38
  try:
39
- return snowflake.connector.connect(**st.secrets["snowflake"])
40
- except DatabaseError as e:
41
- logger.error(f"Failed to connect to database: {e}")
42
- raise DatabaseConnectionError(f"Could not connect to database: {e}") from e
43
- except KeyError as e:
44
- logger.error("Snowflake credentials not found in secrets")
45
- raise DatabaseConnectionError(
46
- "Database credentials not configured. Please check st.secrets."
47
- ) from e
48
 
49
 
50
  @contextmanager
51
- def get_connection() -> Generator[SnowflakeConnection, None, None]:
52
- """Context manager for database connections with error handling.
53
 
54
  Yields:
55
- Active Snowflake connection
56
 
57
  Raises:
58
- DatabaseConnectionError: If connection fails
59
-
60
- Example:
61
- with get_connection() as conn:
62
- # use connection
63
  """
64
  try:
65
- conn = snowflake.connector.connect(**st.secrets["snowflake"])
66
- yield conn
67
- except DatabaseError as e:
68
- logger.error(f"Database connection error: {e}")
69
- raise DatabaseConnectionError(f"Database connection failed: {e}") from e
70
- except KeyError as e:
71
- logger.error("Snowflake credentials not found in secrets")
72
- raise DatabaseConnectionError(
73
- "Database credentials not configured. Please check st.secrets."
74
- ) from e
75
  finally:
76
- try:
77
- conn.close()
78
- except Exception:
79
- pass # Connection may already be closed
80
-
81
-
82
- def execute_query(
83
- conn: SnowflakeConnection,
84
- query: str,
85
- params: tuple[Any, ...] | list[Any] | None = None,
86
- ) -> list[tuple[Any, ...]]:
87
- """Execute a parameterized query safely.
88
-
89
- Args:
90
- conn: Active database connection
91
- query: SQL query with %s placeholders
92
- params: Query parameters (optional)
93
-
94
- Returns:
95
- List of result tuples
96
-
97
- Raises:
98
- QueryExecutionError: If query execution fails
99
- """
100
- try:
101
- with conn.cursor() as cur:
102
- if params:
103
- cur.execute(query, params)
104
- else:
105
- cur.execute(query)
106
- return cur.fetchall()
107
- except ProgrammingError as e:
108
- logger.error(f"Query execution error: {e}")
109
- raise QueryExecutionError(f"Query failed: {e}") from e
110
- except DatabaseError as e:
111
- logger.error(f"Database error during query: {e}")
112
- raise QueryExecutionError(f"Database error: {e}") from e
 
1
+ """Local CSV data management with error handling."""
2
 
3
  import logging
4
  from collections.abc import Generator
5
  from contextlib import contextmanager
6
+ from pathlib import Path
7
 
8
+ import pandas as pd
9
  import streamlit as st
 
 
10
 
11
  logger = logging.getLogger("streamlit_nba")
12
 
13
+ CSV_PATH = Path("snowflake_nba.csv")
14
+
15
 
16
  class DatabaseConnectionError(Exception):
17
+ """Raised when local data file cannot be found or loaded."""
18
 
19
  pass
20
 
21
 
22
  class QueryExecutionError(Exception):
23
+ """Raised when data query fails."""
24
 
25
  pass
26
 
27
 
28
+ @st.cache_data
29
+ def load_data() -> pd.DataFrame:
30
+ """Load and cache the local CSV data.
31
 
32
  Returns:
33
+ DataFrame containing player data
34
 
35
  Raises:
36
+ DatabaseConnectionError: If file cannot be loaded
37
  """
38
+ if not CSV_PATH.exists():
39
+ logger.error(f"Data file not found: {CSV_PATH}")
40
+ raise DatabaseConnectionError(f"Data file not found: {CSV_PATH}")
41
+
42
  try:
43
+ df = pd.read_csv(CSV_PATH)
44
+ # Ensure column names match expected Snowflake names (uppercase)
45
+ df.columns = [col.upper() for col in df.columns]
46
+ return df
47
+ except Exception as e:
48
+ logger.error(f"Failed to load CSV data: {e}")
49
+ raise DatabaseConnectionError(f"Could not load data from {CSV_PATH}: {e}") from e
 
 
50
 
51
 
52
  @contextmanager
53
+ def get_connection() -> Generator[pd.DataFrame, None, None]:
54
+ """Context manager for local data access with error handling.
55
 
56
  Yields:
57
+ DataFrame with player data
58
 
59
  Raises:
60
+ DatabaseConnectionError: If data cannot be loaded
 
 
 
 
61
  """
62
  try:
63
+ yield load_data()
64
+ except DatabaseConnectionError as e:
65
+ logger.error(f"Data access error: {e}")
66
+ raise
67
+ except Exception as e:
68
+ logger.error(f"Unexpected error accessing data: {e}")
69
+ raise DatabaseConnectionError(f"Data access failed: {e}") from e
 
 
 
70
  finally:
71
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/database/queries.py CHANGED
@@ -1,64 +1,61 @@
1
- """Parameterized database queries for player data."""
2
 
3
  import logging
4
  from typing import Any
5
 
6
  import pandas as pd
7
- from snowflake.connector import SnowflakeConnection
8
 
9
  from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
10
- from src.database.connection import QueryExecutionError, execute_query
11
 
12
  logger = logging.getLogger("streamlit_nba")
13
 
14
 
15
- def search_player_by_name(conn: SnowflakeConnection, name: str) -> list[tuple[str]]:
16
  """Search for players by name (first, last, or full name).
17
 
18
  Args:
19
- conn: Active database connection
20
  name: Search term (case-insensitive)
21
 
22
  Returns:
23
  List of tuples containing matching full names
24
  """
25
  name_lower = name.lower().strip()
26
- query = """
27
- SELECT full_name FROM NBA
28
- WHERE full_name_lower = %s
29
- OR first_name_lower = %s
30
- OR last_name_lower = %s
31
- """
32
- return execute_query(conn, query, (name_lower, name_lower, name_lower))
33
 
34
 
35
  def get_player_by_full_name(
36
- conn: SnowflakeConnection, full_name: str
37
  ) -> tuple[Any, ...] | None:
38
  """Get a single player's full record by exact name match.
39
 
40
  Args:
41
- conn: Active database connection
42
  full_name: Exact full name of player
43
 
44
  Returns:
45
  Player data tuple or None if not found
46
  """
47
- query = "SELECT * FROM NBA WHERE FULL_NAME = %s"
48
- results = execute_query(conn, query, (full_name,))
49
- return results[0] if results else None
 
50
 
51
 
52
  def get_players_by_full_names(
53
- conn: SnowflakeConnection, names: list[str]
54
  ) -> pd.DataFrame:
55
  """Get multiple players' records in a single batch query.
56
 
57
- This fixes the N+1 query problem by using a single IN clause
58
- instead of multiple individual queries.
59
-
60
  Args:
61
- conn: Active database connection
62
  names: List of exact full names
63
 
64
  Returns:
@@ -67,16 +64,11 @@ def get_players_by_full_names(
67
  if not names:
68
  return pd.DataFrame(columns=PLAYER_COLUMNS)
69
 
70
- # Build parameterized IN clause
71
- placeholders = ", ".join(["%s"] * len(names))
72
- query = f"SELECT * FROM NBA WHERE FULL_NAME IN ({placeholders})"
73
-
74
- results = execute_query(conn, query, tuple(names))
75
- return pd.DataFrame(results, columns=PLAYER_COLUMNS)
76
 
77
 
78
  def get_away_team_by_stats(
79
- conn: SnowflakeConnection,
80
  pts_threshold: int,
81
  reb_threshold: int,
82
  ast_threshold: int,
@@ -85,11 +77,10 @@ def get_away_team_by_stats(
85
  ) -> pd.DataFrame:
86
  """Get a random away team based on stat thresholds.
87
 
88
- Uses UNION with SAMPLE to get diverse players meeting stat criteria.
89
- Includes a max_attempts guard to prevent infinite loops.
90
 
91
  Args:
92
- conn: Active database connection
93
  pts_threshold: Minimum career points
94
  reb_threshold: Minimum career rebounds
95
  ast_threshold: Minimum career assists
@@ -100,28 +91,26 @@ def get_away_team_by_stats(
100
  DataFrame with 5 players
101
 
102
  Raises:
103
- QueryExecutionError: If unable to get 5 players within max_attempts
104
- """
105
- query = """
106
- SELECT * FROM (SELECT * FROM NBA WHERE PTS > %s) SAMPLE (2 ROWS)
107
- UNION
108
- SELECT * FROM (SELECT * FROM NBA WHERE REB > %s) SAMPLE (1 ROWS)
109
- UNION
110
- SELECT * FROM (SELECT * FROM NBA WHERE AST > %s) SAMPLE (1 ROWS)
111
- UNION
112
- SELECT * FROM (SELECT * FROM NBA WHERE STL > %s) SAMPLE (1 ROWS)
113
  """
114
- params = (pts_threshold, reb_threshold, ast_threshold, stl_threshold)
115
-
116
  for attempt in range(max_attempts):
117
- results = execute_query(conn, query, params)
118
- if len(results) == 5:
119
- logger.info(f"Got away team on attempt {attempt + 1}")
120
- return pd.DataFrame(results, columns=PLAYER_COLUMNS)
121
- logger.debug(f"Attempt {attempt + 1}: got {len(results)} players, need 5")
 
 
 
 
 
 
 
 
 
 
122
 
123
- # Fallback: if we can't get exactly 5, raise an error
124
  raise QueryExecutionError(
125
  f"Could not generate away team with 5 players after {max_attempts} attempts. "
126
- f"Last attempt returned {len(results)} players."
127
  )
 
1
+ """Local data queries using pandas on loaded CSV data."""
2
 
3
  import logging
4
  from typing import Any
5
 
6
  import pandas as pd
 
7
 
8
  from src.config import MAX_QUERY_ATTEMPTS, PLAYER_COLUMNS
9
+ from src.database.connection import QueryExecutionError
10
 
11
  logger = logging.getLogger("streamlit_nba")
12
 
13
 
14
+ def search_player_by_name(df: pd.DataFrame, name: str) -> list[tuple[str]]:
15
  """Search for players by name (first, last, or full name).
16
 
17
  Args:
18
+ df: Player DataFrame
19
  name: Search term (case-insensitive)
20
 
21
  Returns:
22
  List of tuples containing matching full names
23
  """
24
  name_lower = name.lower().strip()
25
+ mask = (
26
+ (df["FULL_NAME_LOWER"] == name_lower)
27
+ | (df["FIRST_NAME_LOWER"] == name_lower)
28
+ | (df["LAST_NAME_LOWER"] == name_lower)
29
+ )
30
+ results = df[mask]["FULL_NAME"].unique().tolist()
31
+ return [(name,) for name in results]
32
 
33
 
34
  def get_player_by_full_name(
35
+ df: pd.DataFrame, full_name: str
36
  ) -> tuple[Any, ...] | None:
37
  """Get a single player's full record by exact name match.
38
 
39
  Args:
40
+ df: Player DataFrame
41
  full_name: Exact full name of player
42
 
43
  Returns:
44
  Player data tuple or None if not found
45
  """
46
+ result = df[df["FULL_NAME"] == full_name]
47
+ if result.empty:
48
+ return None
49
+ return tuple(result.iloc[0].values)
50
 
51
 
52
  def get_players_by_full_names(
53
+ df: pd.DataFrame, names: list[str]
54
  ) -> pd.DataFrame:
55
  """Get multiple players' records in a single batch query.
56
 
 
 
 
57
  Args:
58
+ df: Player DataFrame
59
  names: List of exact full names
60
 
61
  Returns:
 
64
  if not names:
65
  return pd.DataFrame(columns=PLAYER_COLUMNS)
66
 
67
+ return df[df["FULL_NAME"].isin(names)]
 
 
 
 
 
68
 
69
 
70
  def get_away_team_by_stats(
71
+ df: pd.DataFrame,
72
  pts_threshold: int,
73
  reb_threshold: int,
74
  ast_threshold: int,
 
77
  ) -> pd.DataFrame:
78
  """Get a random away team based on stat thresholds.
79
 
80
+ Replicates Snowflake's SAMPLE and UNION logic using pandas.
 
81
 
82
  Args:
83
+ df: Player DataFrame
84
  pts_threshold: Minimum career points
85
  reb_threshold: Minimum career rebounds
86
  ast_threshold: Minimum career assists
 
91
  DataFrame with 5 players
92
 
93
  Raises:
94
+ RuntimeError: If unable to get 5 players within max_attempts
 
 
 
 
 
 
 
 
 
95
  """
 
 
96
  for attempt in range(max_attempts):
97
+ try:
98
+ df1 = df[df["PTS"] > pts_threshold].sample(n=2)
99
+ df2 = df[df["REB"] > reb_threshold].sample(n=1)
100
+ df3 = df[df["AST"] > ast_threshold].sample(n=1)
101
+ df4 = df[df["STL"] > stl_threshold].sample(n=1)
102
+
103
+ results = pd.concat([df1, df2, df3, df4]).drop_duplicates()
104
+
105
+ if len(results) == 5:
106
+ logger.info(f"Got away team on attempt {attempt + 1}")
107
+ return results
108
+ except ValueError:
109
+ # sample() can raise ValueError if n > population
110
+ logger.debug(f"Attempt {attempt + 1}: stat thresholds too restrictive")
111
+ continue
112
 
 
113
  raise QueryExecutionError(
114
  f"Could not generate away team with 5 players after {max_attempts} attempts. "
115
+ "Try lowering the difficulty."
116
  )
tests/conftest.py CHANGED
@@ -1,26 +1,10 @@
1
  """Pytest fixtures for NBA Streamlit application tests."""
2
 
3
  from typing import Any
4
- from unittest.mock import MagicMock
5
-
6
  import pandas as pd
7
  import pytest
8
 
9
 
10
- @pytest.fixture
11
- def mock_snowflake_connection() -> MagicMock:
12
- """Create a mock Snowflake connection.
13
-
14
- Returns:
15
- Mock connection object that simulates Snowflake connection behavior
16
- """
17
- mock_conn = MagicMock()
18
- mock_cursor = MagicMock()
19
- mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
20
- mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
21
- return mock_conn
22
-
23
-
24
  @pytest.fixture
25
  def sample_player_data() -> list[tuple[Any, ...]]:
26
  """Create sample player data matching database schema.
 
1
  """Pytest fixtures for NBA Streamlit application tests."""
2
 
3
  from typing import Any
 
 
4
  import pandas as pd
5
  import pytest
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @pytest.fixture
9
  def sample_player_data() -> list[tuple[Any, ...]]:
10
  """Create sample player data matching database schema.
tests/test_database.py CHANGED
@@ -1,6 +1,4 @@
1
- """Tests for database module."""
2
-
3
- from unittest.mock import MagicMock
4
 
5
  import pandas as pd
6
  import pytest
@@ -17,130 +15,72 @@ from src.database.queries import (
17
  class TestSearchPlayerByName:
18
  """Tests for search_player_by_name function."""
19
 
20
- def test_uses_parameterized_query(
21
- self, mock_snowflake_connection: MagicMock
22
- ) -> None:
23
- """Verify parameterized queries are used (not string formatting)."""
24
- mock_cursor = MagicMock()
25
- mock_cursor.fetchall.return_value = [("LeBron James",)]
26
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
27
- mock_cursor
28
- )
29
 
30
- search_player_by_name(mock_snowflake_connection, "james")
31
-
32
- # Verify execute was called with params tuple, not string formatting
33
- mock_cursor.execute.assert_called_once()
34
- call_args = mock_cursor.execute.call_args
35
- query = call_args[0][0]
36
- params = call_args[0][1]
37
-
38
- # Query should use %s placeholders
39
- assert "%s" in query
40
- # Should not contain the actual search term in the query string
41
- assert "james" not in query.lower()
42
- # Params should be a tuple with the search term
43
- assert params == ("james", "james", "james")
44
-
45
- def test_returns_list_of_tuples(
46
- self, mock_snowflake_connection: MagicMock
47
- ) -> None:
48
- """Test that results are returned as list of tuples."""
49
- mock_cursor = MagicMock()
50
- mock_cursor.fetchall.return_value = [
51
- ("LeBron James",),
52
- ("James Harden",),
53
- ]
54
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
55
- mock_cursor
56
- )
57
 
58
- result = search_player_by_name(mock_snowflake_connection, "james")
 
 
 
59
 
60
- assert result == [("LeBron James",), ("James Harden",)]
 
 
 
61
 
62
 
63
  class TestGetPlayersByFullNames:
64
  """Tests for get_players_by_full_names batch query."""
65
 
66
- def test_single_query_for_multiple_names(
67
- self, mock_snowflake_connection: MagicMock, sample_player_data: list
68
- ) -> None:
69
- """Verify batch query uses single IN clause instead of N queries."""
70
- mock_cursor = MagicMock()
71
- mock_cursor.fetchall.return_value = sample_player_data
72
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
73
- mock_cursor
74
- )
75
-
76
  names = ["LeBron James", "Michael Jordan"]
77
- get_players_by_full_names(mock_snowflake_connection, names)
78
-
79
- # Should only execute one query
80
- assert mock_cursor.execute.call_count == 1
81
-
82
- call_args = mock_cursor.execute.call_args
83
- query = call_args[0][0]
84
- params = call_args[0][1]
85
-
86
- # Query should have IN clause with placeholders
87
- assert "IN" in query.upper()
88
- assert "%s" in query
89
- # Params should be tuple of names
90
- assert params == ("LeBron James", "Michael Jordan")
91
-
92
- def test_returns_dataframe(
93
- self, mock_snowflake_connection: MagicMock, sample_player_data: list
94
- ) -> None:
95
- """Test that results are returned as DataFrame."""
96
- mock_cursor = MagicMock()
97
- mock_cursor.fetchall.return_value = sample_player_data
98
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
99
- mock_cursor
100
- )
101
-
102
- result = get_players_by_full_names(
103
- mock_snowflake_connection, ["LeBron James", "Michael Jordan"]
104
- )
105
 
106
  assert isinstance(result, pd.DataFrame)
107
- assert list(result.columns) == PLAYER_COLUMNS
108
  assert len(result) == 2
 
 
109
 
110
- def test_empty_names_returns_empty_dataframe(
111
- self, mock_snowflake_connection: MagicMock
112
- ) -> None:
113
- """Test that empty input returns empty DataFrame without query."""
114
- mock_cursor = MagicMock()
115
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
116
- mock_cursor
117
- )
118
-
119
- result = get_players_by_full_names(mock_snowflake_connection, [])
120
 
121
  assert isinstance(result, pd.DataFrame)
122
  assert result.empty
123
- # Should not execute any query
124
- mock_cursor.execute.assert_not_called()
125
 
126
 
127
  class TestGetAwayTeamByStats:
128
- """Tests for get_away_team_by_stats with max_attempts guard."""
129
-
130
- def test_max_attempts_raises_error(
131
- self, mock_snowflake_connection: MagicMock
132
- ) -> None:
133
- """Test that max_attempts limit prevents infinite loop."""
134
- mock_cursor = MagicMock()
135
- # Always return wrong number of players
136
- mock_cursor.fetchall.return_value = [("Player1",), ("Player2",)]
137
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
138
- mock_cursor
139
- )
 
140
 
141
  with pytest.raises(QueryExecutionError) as exc_info:
142
  get_away_team_by_stats(
143
- mock_snowflake_connection,
144
  pts_threshold=1000,
145
  reb_threshold=500,
146
  ast_threshold=300,
@@ -149,28 +89,23 @@ class TestGetAwayTeamByStats:
149
  )
150
 
151
  assert "3 attempts" in str(exc_info.value)
152
- assert mock_cursor.execute.call_count == 3
153
-
154
- def test_success_on_first_try(
155
- self, mock_snowflake_connection: MagicMock, sample_player_data: list
156
- ) -> None:
157
- """Test successful query on first attempt."""
158
- mock_cursor = MagicMock()
159
- # Return exactly 5 players
160
- mock_cursor.fetchall.return_value = sample_player_data * 3 # 6 players
161
- mock_cursor.fetchall.return_value = [
162
- sample_player_data[0],
163
- sample_player_data[1],
164
- sample_player_data[0],
165
- sample_player_data[1],
166
- sample_player_data[0],
167
- ]
168
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
169
- mock_cursor
170
- )
171
 
172
  result = get_away_team_by_stats(
173
- mock_snowflake_connection,
174
  pts_threshold=1000,
175
  reb_threshold=500,
176
  ast_threshold=300,
@@ -179,41 +114,3 @@ class TestGetAwayTeamByStats:
179
 
180
  assert isinstance(result, pd.DataFrame)
181
  assert len(result) == 5
182
- # Should only need one query
183
- assert mock_cursor.execute.call_count == 1
184
-
185
- def test_uses_parameterized_query(
186
- self, mock_snowflake_connection: MagicMock, sample_player_data: list
187
- ) -> None:
188
- """Verify parameterized queries are used for stat thresholds."""
189
- mock_cursor = MagicMock()
190
- mock_cursor.fetchall.return_value = [
191
- sample_player_data[0],
192
- sample_player_data[1],
193
- sample_player_data[0],
194
- sample_player_data[1],
195
- sample_player_data[0],
196
- ]
197
- mock_snowflake_connection.cursor.return_value.__enter__.return_value = (
198
- mock_cursor
199
- )
200
-
201
- get_away_team_by_stats(
202
- mock_snowflake_connection,
203
- pts_threshold=1000,
204
- reb_threshold=500,
205
- ast_threshold=300,
206
- stl_threshold=100,
207
- )
208
-
209
- call_args = mock_cursor.execute.call_args
210
- query = call_args[0][0]
211
- params = call_args[0][1]
212
-
213
- # Query should use %s placeholders
214
- assert "%s" in query
215
- # Should not contain actual numbers in query
216
- assert "1000" not in query
217
- assert "500" not in query
218
- # Params should be tuple of thresholds
219
- assert params == (1000, 500, 300, 100)
 
1
+ """Tests for database module using local pandas data."""
 
 
2
 
3
  import pandas as pd
4
  import pytest
 
15
  class TestSearchPlayerByName:
16
  """Tests for search_player_by_name function."""
17
 
18
+ def test_search_by_full_name(self, sample_player_df: pd.DataFrame) -> None:
19
+ """Verify search finds player by full name."""
20
+ result = search_player_by_name(sample_player_df, "LeBron James")
21
+ assert result == [("LeBron James",)]
 
 
 
 
 
22
 
23
+ def test_search_by_first_name(self, sample_player_df: pd.DataFrame) -> None:
24
+ """Verify search finds player by first name."""
25
+ result = search_player_by_name(sample_player_df, "LeBron")
26
+ assert result == [("LeBron James",)]
27
+
28
+ def test_search_by_last_name(self, sample_player_df: pd.DataFrame) -> None:
29
+ """Verify search finds player by last name."""
30
+ result = search_player_by_name(sample_player_df, "Jordan")
31
+ assert result == [("Michael Jordan",)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def test_search_case_insensitive(self, sample_player_df: pd.DataFrame) -> None:
34
+ """Verify search is case-insensitive."""
35
+ result = search_player_by_name(sample_player_df, "lebron")
36
+ assert result == [("LeBron James",)]
37
 
38
+ def test_returns_empty_on_no_match(self, sample_player_df: pd.DataFrame) -> None:
39
+ """Verify empty list returned when no player found."""
40
+ result = search_player_by_name(sample_player_df, "NonExistent Player")
41
+ assert result == []
42
 
43
 
44
  class TestGetPlayersByFullNames:
45
  """Tests for get_players_by_full_names batch query."""
46
 
47
+ def test_returns_correct_players(self, sample_player_df: pd.DataFrame) -> None:
48
+ """Verify correct players are returned in DataFrame."""
 
 
 
 
 
 
 
 
49
  names = ["LeBron James", "Michael Jordan"]
50
+ result = get_players_by_full_names(sample_player_df, names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  assert isinstance(result, pd.DataFrame)
 
53
  assert len(result) == 2
54
+ assert set(result["FULL_NAME"]) == set(names)
55
+ assert list(result.columns) == PLAYER_COLUMNS
56
 
57
+ def test_empty_names_returns_empty_dataframe(self, sample_player_df: pd.DataFrame) -> None:
58
+ """Test that empty input returns empty DataFrame."""
59
+ result = get_players_by_full_names(sample_player_df, [])
 
 
 
 
 
 
 
60
 
61
  assert isinstance(result, pd.DataFrame)
62
  assert result.empty
63
+ assert list(result.columns) == PLAYER_COLUMNS
 
64
 
65
 
66
  class TestGetAwayTeamByStats:
67
+ """Tests for get_away_team_by_stats."""
68
+
69
+ def test_max_attempts_raises_error(self) -> None:
70
+ """Test that max_attempts limit works when population is too small."""
71
+ # Create a DF with only 2 players
72
+ df = pd.DataFrame([
73
+ {"FULL_NAME": "P1", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101},
74
+ {"FULL_NAME": "P2", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101},
75
+ ])
76
+ # Add missing columns to avoid errors if needed, though queries only use these
77
+ for col in PLAYER_COLUMNS:
78
+ if col not in df.columns:
79
+ df[col] = 0
80
 
81
  with pytest.raises(QueryExecutionError) as exc_info:
82
  get_away_team_by_stats(
83
+ df,
84
  pts_threshold=1000,
85
  reb_threshold=500,
86
  ast_threshold=300,
 
89
  )
90
 
91
  assert "3 attempts" in str(exc_info.value)
92
+
93
+ def test_success_with_enough_players(self) -> None:
94
+ """Test successful generation with sufficient population."""
95
+ # Create a DF with 10 players meeting criteria
96
+ data = []
97
+ for i in range(10):
98
+ data.append({
99
+ "FULL_NAME": f"Player{i}",
100
+ "PTS": 2000, "REB": 1000, "AST": 500, "STL": 200
101
+ })
102
+ df = pd.DataFrame(data)
103
+ for col in PLAYER_COLUMNS:
104
+ if col not in df.columns:
105
+ df[col] = 0
 
 
 
 
 
106
 
107
  result = get_away_team_by_stats(
108
+ df,
109
  pts_threshold=1000,
110
  reb_threshold=500,
111
  ast_threshold=300,
 
114
 
115
  assert isinstance(result, pd.DataFrame)
116
  assert len(result) == 5