" - f"{safe_text}
", + f"{safe_text}
", unsafe_allow_html=True, ) - - -def safe_styled_text( - text: str, - tag: str = "span", - color: str | None = None, - align: str | None = None, - **styles: str, -) -> str: - """Generate HTML string with escaped text and validated styles. - - Args: - text: Text content (will be escaped) - tag: HTML tag to use - color: Optional CSS color - align: Optional CSS text-align - **styles: Additional CSS properties - - Returns: - Safe HTML string - """ - safe_text = escape_html(text) - safe_tag = escape_html(tag) - - style_parts: list[str] = [] - if color: - style_parts.append(f"color: {escape_html(color)}") - if align: - style_parts.append(f"text-align: {escape_html(align)}") - for prop, value in styles.items(): - # Convert underscores to hyphens for CSS properties - css_prop = prop.replace("_", "-") - style_parts.append(f"{escape_html(css_prop)}: {escape_html(value)}") - - style_str = "; ".join(style_parts) - if style_str: - return f"<{safe_tag} style='{style_str}'>{safe_text}{safe_tag}>" - return f"<{safe_tag}>{safe_text}{safe_tag}>" diff --git a/src/validation/inputs.py b/src/validation/inputs.py index bcfd32a4ce682aa6505a1e4bc38c087a73d0e6b5..ac76387d286ac309b0c503112a48787666012adf 100644 --- a/src/validation/inputs.py +++ b/src/validation/inputs.py @@ -4,30 +4,6 @@ import re from pydantic import BaseModel, Field, field_validator -# Patterns that indicate SQL injection attempts -SQL_INJECTION_PATTERNS: list[str] = [ - r'[";]', # Double quotes and semicolons (apostrophes allowed for names like O'Neal) - r"--", # SQL comment - r"/\*", # Block comment start - r"\*/", # Block comment end - r"\bUNION\b", # UNION keyword - r"\bSELECT\b", # SELECT keyword - r"\bINSERT\b", # INSERT keyword - r"\bUPDATE\b", # UPDATE keyword - r"\bDELETE\b", # DELETE keyword - r"\bDROP\b", # DROP keyword - r"\bEXEC\b", # EXEC keyword - r"\bOR\s+\d+=\d+", # OR 1=1 pattern - r"\bAND\s+\d+=\d+", # AND 1=1 pattern - r"'\s*OR\s", # ' OR pattern (SQL injection) - r"'\s*AND\s", # ' AND pattern (SQL injection) -] - -# Compiled regex for efficiency -SQL_INJECTION_REGEX = re.compile( - "|".join(SQL_INJECTION_PATTERNS), re.IGNORECASE -) - class PlayerSearchInput(BaseModel): """Validated player search input.""" @@ -39,27 +15,6 @@ class PlayerSearchInput(BaseModel): description="Player name search term", ) - @field_validator("search_term") - @classmethod - def validate_no_sql_injection(cls, v: str) -> str: - """Reject inputs containing SQL injection patterns. - - Args: - v: Input search term - - Returns: - Validated search term - - Raises: - ValueError: If SQL injection pattern detected - """ - if SQL_INJECTION_REGEX.search(v): - raise ValueError( - "Invalid characters in search term. " - "Please use only letters, numbers, spaces, and hyphens." - ) - return v.strip() - @field_validator("search_term") @classmethod def validate_reasonable_characters(cls, v: str) -> str: @@ -69,14 +24,17 @@ class PlayerSearchInput(BaseModel): v: Input search term Returns: - Validated search term + Validated and stripped search term Raises: ValueError: If invalid characters found """ + v = v.strip() + if not v: + raise ValueError("Search term cannot be empty.") # Allow letters, numbers, spaces, hyphens, periods, and apostrophes # (e.g., "O'Neal", "J.R. Smith") - if not re.match(r"^[a-zA-Z0-9\s\-.']+$", v): + if not re.match(r"^[a-zA-Z0-9 \-.']+$", v): raise ValueError( "Search term contains invalid characters. " "Please use only letters, numbers, spaces, hyphens, " diff --git a/tests/test_database.py b/tests/test_database.py index 8c9343fd43d61dd7c91c2e4bc9b75b64228f7bc7..482a6d2bfeb64fb4ba95d423c39a79ceafc4f407 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,10 +1,17 @@ """Tests for database module using local pandas data.""" +from unittest.mock import patch + import pandas as pd import pytest from src.config import PLAYER_COLUMNS -from src.database.connection import QueryExecutionError +from src.database.connection import ( + DatabaseConnectionError, + QueryExecutionError, + get_data, + load_data, +) from src.database.queries import ( get_away_team_by_stats, get_players_by_full_names, @@ -12,6 +19,45 @@ from src.database.queries import ( ) +class TestLoadData: + """Tests for load_data and get_data functions.""" + + def test_load_data_returns_dataframe(self) -> None: + """Test that load_data returns a DataFrame with uppercase columns.""" + df = load_data() + assert isinstance(df, pd.DataFrame) + assert not df.empty + # All columns should be uppercase + for col in df.columns: + assert col == col.upper() + + def test_get_data_returns_dataframe(self) -> None: + """Test that get_data returns a DataFrame.""" + df = get_data() + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @patch("src.database.connection.CSV_PATH") + def test_load_data_missing_file_raises_error(self, mock_path) -> None: # type: ignore[no-untyped-def] + """Test that missing CSV raises DatabaseConnectionError.""" + mock_path.exists.return_value = False + with pytest.raises(DatabaseConnectionError, match="not found"): + load_data() + + @patch("src.database.connection.pd.read_csv") + @patch("src.database.connection.CSV_PATH") + def test_load_data_parser_error_raises_connection_error( + self, + mock_path, + mock_read_csv, # type: ignore[no-untyped-def] + ) -> None: + """Test that CSV parse errors raise DatabaseConnectionError.""" + mock_path.exists.return_value = True + mock_read_csv.side_effect = pd.errors.ParserError("bad csv") + with pytest.raises(DatabaseConnectionError): + load_data() + + class TestSearchPlayerByName: """Tests for search_player_by_name function.""" @@ -76,10 +122,12 @@ class TestGetAwayTeamByStats: def test_max_attempts_raises_error(self) -> None: """Test that max_attempts limit works when population is too small.""" # Create a DF with only 2 players - df = pd.DataFrame([ - {"FULL_NAME": "P1", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101}, - {"FULL_NAME": "P2", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101}, - ]) + df = pd.DataFrame( + [ + {"FULL_NAME": "P1", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101}, + {"FULL_NAME": "P2", "PTS": 1001, "REB": 501, "AST": 301, "STL": 101}, + ] + ) # Add missing columns to avoid errors if needed, though queries only use these for col in PLAYER_COLUMNS: if col not in df.columns: @@ -102,10 +150,15 @@ class TestGetAwayTeamByStats: # Create a DF with 10 players meeting criteria data = [] for i in range(10): - data.append({ - "FULL_NAME": f"Player{i}", - "PTS": 2000, "REB": 1000, "AST": 500, "STL": 200 - }) + data.append( + { + "FULL_NAME": f"Player{i}", + "PTS": 2000, + "REB": 1000, + "AST": 500, + "STL": 200, + } + ) df = pd.DataFrame(data) for col in PLAYER_COLUMNS: if col not in df.columns: @@ -121,3 +174,13 @@ class TestGetAwayTeamByStats: assert isinstance(result, pd.DataFrame) assert len(result) == 5 + + +class TestCsvColumnValidation: + """Integration tests validating CSV data matches config.""" + + def test_csv_columns_match_config(self) -> None: + """Verify that actual CSV columns match PLAYER_COLUMNS in config.""" + df = load_data() + assert not df.empty, "CSV file should not be empty" + assert list(df.columns) == PLAYER_COLUMNS diff --git a/tests/test_ml.py b/tests/test_ml.py index 7e952b72bf53beb7514296bade92df96bc920b82..be462576dce53aabc3130d397eb0b3ddf85584d6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,21 +25,54 @@ class TestAnalyzeTeamStats: # Combined has both teams = 100 values assert combined.shape == (1, 100) - def test_combined_contains_both_teams( - self, sample_team_stats: list[list[float]] - ) -> None: + def test_combined_contains_both_teams(self) -> None: """Test that combined array contains both teams' stats.""" - home_stats = [[1.0, 2.0], [3.0, 4.0]] # 2 players, 2 stats each - away_stats = [[5.0, 6.0], [7.0, 8.0]] + home_stats = [[float(i * 10 + j) for j in range(10)] for i in range(5)] + away_stats = [[float(50 + i * 10 + j) for j in range(10)] for i in range(5)] - _home_array, _away_array, combined = analyze_team_stats( - home_stats, away_stats - ) + _home_array, _away_array, combined = analyze_team_stats(home_stats, away_stats) - # Home should be first 4 values, away next 4 - np.testing.assert_array_equal( - combined[0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] - ) + # Combined should have 100 values: 50 home + 50 away + assert combined.shape == (1, 100) + # First value should be home[0][0], last should be away[4][9] + assert combined[0][0] == 0.0 + assert combined[0][50] == 50.0 + + +class TestAnalyzeTeamStatsValidation: + """Tests for input shape validation in analyze_team_stats.""" + + def test_wrong_number_of_home_players_raises_error(self) -> None: + """Test that wrong number of home players raises ValueError.""" + home_stats = [[1.0] * 10 for _ in range(4)] # 4 players instead of 5 + away_stats = [[1.0] * 10 for _ in range(5)] + + with pytest.raises(ValueError, match="Expected 5 players"): + analyze_team_stats(home_stats, away_stats) + + def test_wrong_number_of_away_players_raises_error(self) -> None: + """Test that wrong number of away players raises ValueError.""" + home_stats = [[1.0] * 10 for _ in range(5)] + away_stats = [[1.0] * 10 for _ in range(6)] # 6 players instead of 5 + + with pytest.raises(ValueError, match="Expected 5 players"): + analyze_team_stats(home_stats, away_stats) + + def test_wrong_stat_count_raises_error(self) -> None: + """Test that wrong number of stats per player raises ValueError.""" + home_stats = [[1.0] * 10 for _ in range(5)] + away_stats = [[1.0] * 10 for _ in range(4)] + [[1.0] * 9] # player with 9 stats + + with pytest.raises(ValueError, match="stats, expected 10"): + analyze_team_stats(home_stats, away_stats) + + def test_home_player_wrong_stat_count_raises_error(self) -> None: + """Test that home player with wrong stat count raises ValueError.""" + home_stats = [[1.0] * 9] + [[1.0] * 10 for _ in range(4)] # first player has 9 + away_stats = [[1.0] * 10 for _ in range(5)] + + with pytest.raises(ValueError, match="stats, expected 10"): + analyze_team_stats(home_stats, away_stats) class TestPredictWinner: @@ -63,9 +96,7 @@ class TestPredictWinner: assert prediction in (0, 1) @patch("src.ml.model.get_winner_model") - def test_high_probability_predicts_win( - self, mock_get_model: MagicMock - ) -> None: + def test_high_probability_predicts_win(self, mock_get_model: MagicMock) -> None: """Test that high probability (>0.5) predicts home win (1).""" mock_model = MagicMock() mock_model.predict.return_value = np.array([[0.8]]) @@ -78,9 +109,7 @@ class TestPredictWinner: assert prediction == 1 @patch("src.ml.model.get_winner_model") - def test_low_probability_predicts_loss( - self, mock_get_model: MagicMock - ) -> None: + def test_low_probability_predicts_loss(self, mock_get_model: MagicMock) -> None: """Test that low probability (<0.5) predicts home loss (0).""" mock_model = MagicMock() mock_model.predict.return_value = np.array([[0.3]]) @@ -93,9 +122,7 @@ class TestPredictWinner: assert prediction == 0 @patch("src.ml.model.get_winner_model") - def test_invalid_shape_raises_error( - self, mock_get_model: MagicMock - ) -> None: + def test_invalid_shape_raises_error(self, mock_get_model: MagicMock) -> None: """Test that invalid input shape raises ValueError.""" mock_model = MagicMock() mock_get_model.return_value = mock_model @@ -109,9 +136,7 @@ class TestPredictWinner: assert "Expected input shape (1, 100)" in str(exc_info.value) @patch("src.ml.model.get_winner_model") - def test_model_called_with_verbose_zero( - self, mock_get_model: MagicMock - ) -> None: + def test_model_called_with_verbose_zero(self, mock_get_model: MagicMock) -> None: """Test that model.predict is called with verbose=0.""" mock_model = MagicMock() mock_model.predict.return_value = np.array([[0.5]]) @@ -123,8 +148,24 @@ class TestPredictWinner: mock_model.predict.assert_called_once_with(stats, verbose=0) +class TestLoadRealModel: + """Integration test loading the real model file.""" + + def test_load_real_model(self) -> None: + """Verify real winner.keras loads with expected input/output shape.""" + from src.ml.model import get_winner_model + + model = get_winner_model() + + assert model is not None + # Model expects 100 features (5 players x 10 stats x 2 teams) + assert model.input_shape == (None, 100) + # Binary classification: single sigmoid output + assert model.output_shape == (None, 1) + + class TestGetWinnerModel: - """Tests for get_winner_model caching.""" + """Tests for get_winner_model loading.""" @patch("src.ml.model.load_model") @patch("src.ml.model.Path") @@ -134,9 +175,6 @@ class TestGetWinnerModel: """Test that missing model file raises ModelLoadError.""" from src.ml.model import get_winner_model - # Clear the cache to ensure fresh test - get_winner_model.clear() - mock_path_instance = MagicMock() mock_path_instance.exists.return_value = False mock_path.return_value = mock_path_instance diff --git a/tests/test_models.py b/tests/test_models.py index a22ece80dbd48d0f3804d486eb8f9e72b557cae6..5b8c6418b06246a2d889b6df5587042f0c0be40d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,90 +3,7 @@ import pytest from src.config import DIFFICULTY_PRESETS -from src.models.player import DifficultySettings, PlayerStats - - -class TestPlayerStats: - """Tests for PlayerStats model.""" - - def test_from_db_row(self, sample_player_data: list) -> None: - """Test creating PlayerStats from database row tuple.""" - row = sample_player_data[0] # LeBron James data - - player = PlayerStats.from_db_row(row) - - assert player.full_name == "LeBron James" - assert player.pts == 39223 - assert player.ast == 10141 - assert player.is_active is True - - def test_validates_negative_stats(self) -> None: - """Test that negative stats are rejected.""" - with pytest.raises(ValueError): - PlayerStats( - full_name="Test Player", - ast=-1, # Invalid - blk=0, - dreb=0, - fg3a=0, - fg3m=0, - fg3_pct=0.0, - fga=0, - fgm=0, - fg_pct=0.0, - fta=0, - ftm=0, - ft_pct=0.0, - gp=0, - gs=0, - min=0, - oreb=0, - pf=0, - pts=0, - reb=0, - stl=0, - tov=0, - first_name="Test", - last_name="Player", - full_name_lower="test player", - first_name_lower="test", - last_name_lower="player", - is_active=True, - ) - - def test_validates_percentage_range(self) -> None: - """Test that percentages must be 0-1.""" - with pytest.raises(ValueError): - PlayerStats( - full_name="Test Player", - ast=0, - blk=0, - dreb=0, - fg3a=0, - fg3m=0, - fg3_pct=1.5, # Invalid - over 1.0 - fga=0, - fgm=0, - fg_pct=0.0, - fta=0, - ftm=0, - ft_pct=0.0, - gp=0, - gs=0, - min=0, - oreb=0, - pf=0, - pts=0, - reb=0, - stl=0, - tov=0, - first_name="Test", - last_name="Player", - full_name_lower="test player", - first_name_lower="test", - last_name_lower="player", - is_active=True, - ) +from src.models.player import DifficultySettings class TestDifficultySettings: diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000000000000000000000000000000000000..8f670b7ffee88d1aac731b7a0af5bdb7ce3f0a04 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,123 @@ +"""Tests for session state management functions.""" + +from unittest.mock import patch + +import pandas as pd + +from src.config import DIFFICULTY_PRESETS +from src.state.session import get_away_stats, get_home_team_df, init_session_state + + +class TestInitSessionState: + """Tests for init_session_state.""" + + def test_initializes_all_expected_keys(self) -> None: + """Verify all default keys are created.""" + state: dict = {} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + init_session_state() + + expected_keys = { + "home_team", + "away_team", + "away_team_df", + "away_stats", + "home_team_df", + "radio_index", + } + assert set(state.keys()) == expected_keys + + def test_does_not_overwrite_existing_values(self) -> None: + """Verify calling init twice does not overwrite existing values.""" + state: dict = {"home_team": ["Player A"]} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + init_session_state() + + assert state["home_team"] == ["Player A"] + + def test_sets_correct_default_away_stats(self) -> None: + """Verify away_stats defaults to Regular difficulty preset.""" + state: dict = {} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + init_session_state() + + assert state["away_stats"] == list(DIFFICULTY_PRESETS["Regular"]) + + def test_sets_empty_dataframes_by_default(self) -> None: + """Verify DataFrames start empty.""" + state: dict = {} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + init_session_state() + + assert isinstance(state["away_team_df"], pd.DataFrame) + assert state["away_team_df"].empty + assert isinstance(state["home_team_df"], pd.DataFrame) + assert state["home_team_df"].empty + + +class TestGetAwayStats: + """Tests for get_away_stats.""" + + def test_returns_stats_from_session(self) -> None: + """Verify returns stats when properly set in session.""" + state: dict = {"away_stats": [100, 200, 300, 400]} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_away_stats() + + assert result == [100, 200, 300, 400] + + def test_returns_defaults_when_invalid(self) -> None: + """Verify returns defaults when away_stats is invalid.""" + state: dict = {"away_stats": "invalid"} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_away_stats() + + assert result == list(DIFFICULTY_PRESETS["Regular"]) + + def test_returns_defaults_on_none(self) -> None: + """Verify returns defaults when away_stats is None.""" + state: dict = {"away_stats": None} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_away_stats() + + assert result == list(DIFFICULTY_PRESETS["Regular"]) + + def test_returns_defaults_on_wrong_length(self) -> None: + """Verify returns defaults when away_stats has wrong length.""" + state: dict = {"away_stats": [1, 2, 3]} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_away_stats() + + assert result == list(DIFFICULTY_PRESETS["Regular"]) + + +class TestGetHomeTeamDf: + """Tests for get_home_team_df.""" + + def test_returns_dataframe_from_session(self) -> None: + """Verify returns DataFrame when set in session.""" + expected_df = pd.DataFrame({"FULL_NAME": ["Player A"]}) + state: dict = {"home_team_df": expected_df} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_home_team_df() + + pd.testing.assert_frame_equal(result, expected_df) + + def test_returns_empty_dataframe_when_not_set(self) -> None: + """Verify returns empty DataFrame when not set.""" + state: dict = {} + with patch("src.state.session.st") as mock_st: + mock_st.session_state = state + result = get_home_team_df() + + assert isinstance(result, pd.DataFrame) + assert result.empty diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de295761ab0f14c4cd6e5df994ec2b96d3f674d9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,106 @@ +"""Tests for HTML utility functions.""" + +from unittest.mock import patch + +from src.utils.html import escape_html + + +class TestEscapeHtml: + """Tests for escape_html function.""" + + def test_escapes_angle_brackets(self) -> None: + """Verify < and > are escaped.""" + assert "<" in escape_html("