| """Tests for sage.data.loader — data loading and temporal split functions.""" |
|
|
| import pandas as pd |
| import pytest |
|
|
| from sage.data.loader import ( |
| create_temporal_splits, |
| filter_5_core, |
| verify_temporal_boundaries, |
| ) |
|
|
|
|
class TestVerifyTemporalBoundaries:
    """Tests for verify_temporal_boundaries function.

    Covers the happy path (per-split ``(min, max)`` timestamp bounds),
    empty splits, missing ``timestamp`` columns, overlapping and touching
    boundaries, returned value types, and verbose logging.
    """

    def test_valid_splits_returns_boundaries(self):
        """Valid temporal splits should return boundary dict."""
        train_df = pd.DataFrame({"timestamp": [100, 200, 300]})
        val_df = pd.DataFrame({"timestamp": [400, 500]})
        test_df = pd.DataFrame({"timestamp": [600, 700, 800]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        # Each split maps to its (min, max) timestamp pair.
        assert result == {
            "train": (100, 300),
            "val": (400, 500),
            "test": (600, 800),
        }

    def test_empty_train_raises_clear_error(self):
        """Empty train split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": []})
        val_df = pd.DataFrame({"timestamp": [100, 200]})
        test_df = pd.DataFrame({"timestamp": [300, 400]})

        with pytest.raises(ValueError, match="Train split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_empty_val_raises_clear_error(self):
        """Empty validation split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": []})
        test_df = pd.DataFrame({"timestamp": [300, 400]})

        with pytest.raises(ValueError, match="Validation split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_empty_test_raises_clear_error(self):
        """Empty test split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": []})

        with pytest.raises(ValueError, match="Test split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_train(self):
        """Missing timestamp column in train should raise ValueError."""
        train_df = pd.DataFrame({"other_col": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Train split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_val(self):
        """Missing timestamp column in val should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"other_col": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Validation split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_test(self):
        """Missing timestamp column in test should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"other_col": [500, 600]})

        with pytest.raises(ValueError, match="Test split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_train_val_overlap_raises_error(self):
        """Train/val overlap (train_max > val_min) should raise ValueError."""
        # train_max = 500 lands after val_min = 300.
        train_df = pd.DataFrame({"timestamp": [100, 200, 500]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [600, 700]})

        with pytest.raises(ValueError, match="Train/val overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_train_val_exact_boundary_raises_error(self):
        """Train/val exact boundary (train_max == val_min) is temporal leakage."""
        # Both splits contain timestamp 300, so the boundary must be strict.
        train_df = pd.DataFrame({"timestamp": [100, 200, 300]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Train/val overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_val_test_overlap_raises_error(self):
        """Val/test overlap (val_max > test_min) should raise ValueError."""
        # val_max = 600 lands after test_min = 500.
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 600]})
        test_df = pd.DataFrame({"timestamp": [500, 700]})

        with pytest.raises(ValueError, match="Val/test overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_val_test_exact_boundary_raises_error(self):
        """Val/test exact boundary (val_max == test_min) is temporal leakage."""
        # Both splits contain timestamp 400, so the boundary must be strict.
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [400, 500]})

        with pytest.raises(ValueError, match="Val/test overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_single_row_splits_valid(self):
        """Single-row splits with valid boundaries should pass."""
        train_df = pd.DataFrame({"timestamp": [100]})
        val_df = pd.DataFrame({"timestamp": [200]})
        test_df = pd.DataFrame({"timestamp": [300]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        # With one row, min and max coincide.
        assert result["train"] == (100, 100)
        assert result["val"] == (200, 200)
        assert result["test"] == (300, 300)

    def test_verbose_logging(self, caplog):
        """Verbose mode should log boundary information."""
        # The root logger defaults to WARNING. Lower the capture level so the
        # verbose summary is recorded even if it is emitted below WARNING and
        # regardless of any log_level set in the project's pytest config.
        caplog.set_level("INFO")
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        verify_temporal_boundaries(train_df, val_df, test_df, verbose=True)

        assert "Temporal boundaries verified" in caplog.text

    def test_returns_int_timestamps(self):
        """Boundary values should be integers, not numpy types."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        # numpy.int64 is not a subclass of int, so this catches leaked numpy scalars.
        for split_name, (start, end) in result.items():
            assert isinstance(start, int), f"{split_name} start is not int"
            assert isinstance(end, int), f"{split_name} end is not int"

    def test_millisecond_timestamps(self):
        """Should handle millisecond timestamps (real-world format)."""
        # Epoch milliseconds: 2023-01-01/02, 2023-06-01/02 and 2023-12-01/02 UTC.
        train_df = pd.DataFrame({"timestamp": [1672531200000, 1672617600000]})
        val_df = pd.DataFrame({"timestamp": [1685577600000, 1685664000000]})
        test_df = pd.DataFrame({"timestamp": [1701388800000, 1701475200000]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        assert result["train"][0] == 1672531200000
        assert result["test"][1] == 1701475200000

    def test_empty_check_before_column_check(self):
        """Empty split error should appear before missing column error."""
        # train_df is both empty AND lacks a 'timestamp' column; the
        # empty-split message is the one expected to surface.
        train_df = pd.DataFrame({"other": []})
        val_df = pd.DataFrame({"timestamp": [100]})
        test_df = pd.DataFrame({"timestamp": [200]})

        with pytest.raises(ValueError, match="Train split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)
|
|
|
|
class TestCreateTemporalSplits:
    """Tests for create_temporal_splits function.

    Covers split sizes for default and custom ratios, temporal ordering,
    ratio validation, empty-split warnings, saving to disk, verbose logging,
    and that rows are neither lost nor duplicated across splits.
    """

    def test_default_ratios_70_10_20(self):
        """Default ratios should produce 70/10/20 split."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "data": list(range(100)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert len(train) == 70
        assert len(val) == 10
        assert len(test) == 20

    def test_preserves_temporal_order(self):
        """Train timestamps should precede val, val should precede test."""
        # Input is deliberately unsorted so the function must order by time.
        df = pd.DataFrame(
            {
                "timestamp": [500, 100, 300, 200, 400, 600, 700, 800, 900, 1000],
                "data": list(range(10)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert train["timestamp"].max() < val["timestamp"].min()
        assert val["timestamp"].max() < test["timestamp"].min()

    def test_custom_ratios(self):
        """Custom ratios should be respected."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "data": list(range(100)),
            }
        )

        train, val, test = create_temporal_splits(
            df, train_ratio=0.5, val_ratio=0.3, save=False, verbose=False
        )

        assert len(train) == 50
        assert len(val) == 30
        assert len(test) == 20

    def test_floating_point_bug_fixed(self):
        """0.7 + 0.1 floating point issue should not lose samples."""
        # 0.7 + 0.1 evaluates to 0.7999999999999999, so a cumulative cut point
        # such as int(10 * (0.7 + 0.1)) would truncate to 7 and shrink val.
        df = pd.DataFrame(
            {
                "timestamp": list(range(10)),
                "data": list(range(10)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        # Expected 7/1/2 for 10 rows with the default 70/10/20 ratios.
        assert len(train) == 7
        assert len(val) == 1
        assert len(test) == 2
        # Nothing may be dropped by rounding.
        assert len(train) + len(val) + len(test) == 10

    def test_empty_dataframe_raises_error(self):
        """Empty DataFrame should raise ValueError."""
        df = pd.DataFrame({"timestamp": []})

        with pytest.raises(ValueError, match="DataFrame is empty"):
            create_temporal_splits(df, save=False, verbose=False)

    def test_missing_timestamp_column_raises_error(self):
        """Missing timestamp column should raise ValueError."""
        df = pd.DataFrame({"other_col": [1, 2, 3]})

        with pytest.raises(ValueError, match="missing 'timestamp' column"):
            create_temporal_splits(df, save=False, verbose=False)

    def test_negative_train_ratio_raises_error(self):
        """Negative train_ratio should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="train_ratio must be between 0 and 1"):
            create_temporal_splits(df, train_ratio=-0.2, save=False, verbose=False)

    def test_train_ratio_greater_than_1_raises_error(self):
        """train_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="train_ratio must be between 0 and 1"):
            create_temporal_splits(df, train_ratio=1.5, save=False, verbose=False)

    def test_negative_val_ratio_raises_error(self):
        """Negative val_ratio should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="val_ratio must be between 0 and 1"):
            create_temporal_splits(df, val_ratio=-0.1, save=False, verbose=False)

    def test_val_ratio_greater_than_1_raises_error(self):
        """val_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="val_ratio must be between 0 and 1"):
            create_temporal_splits(df, val_ratio=1.5, save=False, verbose=False)

    def test_ratios_sum_greater_than_1_raises_error(self):
        """train_ratio + val_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        # `match` is a regex searched in the message; escape the literal '+'.
        with pytest.raises(ValueError, match=r"train_ratio \+ val_ratio must be <= 1"):
            create_temporal_splits(
                df, train_ratio=0.8, val_ratio=0.5, save=False, verbose=False
            )

    def test_ratios_sum_exactly_1_valid(self):
        """train_ratio + val_ratio = 1 is valid (empty test set)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.3, save=False, verbose=False
        )

        assert len(train) == 7
        assert len(val) == 3
        assert len(test) == 0

    def test_zero_train_ratio_valid(self):
        """train_ratio=0 is valid (all data in val/test)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.0, val_ratio=0.5, save=False, verbose=False
        )

        assert len(train) == 0
        assert len(val) == 5
        assert len(test) == 5

    def test_zero_val_ratio_valid(self):
        """val_ratio=0 is valid (no validation set)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.0, save=False, verbose=False
        )

        assert len(train) == 7
        assert len(val) == 0
        assert len(test) == 3

    def test_warns_on_empty_val_split(self, caplog):
        """Should warn when validation split ends up empty."""
        # With only 3 rows, a 5% validation ratio yields less than one row.
        df = pd.DataFrame({"timestamp": [1, 2, 3]})

        create_temporal_splits(
            df, train_ratio=0.9, val_ratio=0.05, save=False, verbose=False
        )

        assert "Validation split is empty" in caplog.text

    def test_warns_on_empty_test_split(self, caplog):
        """Should warn when test split ends up empty."""
        # Ratios sum to 1, leaving nothing for the test split.
        df = pd.DataFrame({"timestamp": list(range(10))})

        create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.3, save=False, verbose=False
        )

        assert "Test split is empty" in caplog.text

    def test_saves_to_disk_when_requested(self, tmp_path, monkeypatch):
        """Should save splits to parquet files when save=True."""
        # Redirect the module-level output directory into pytest's tmp dir.
        monkeypatch.setattr("sage.data.loader.SPLITS_DIR", tmp_path)

        df = pd.DataFrame(
            {
                "timestamp": list(range(10)),
                "data": list(range(10)),
            }
        )

        create_temporal_splits(df, save=True, verbose=False)

        assert (tmp_path / "train.parquet").exists()
        assert (tmp_path / "val.parquet").exists()
        assert (tmp_path / "test.parquet").exists()

    def test_verbose_logs_split_sizes(self, caplog):
        """Verbose mode should log split sizes."""
        # The root logger defaults to WARNING. Lower the capture level so the
        # verbose size report is recorded even if emitted below WARNING.
        caplog.set_level("INFO")
        df = pd.DataFrame({"timestamp": list(range(100))})

        create_temporal_splits(df, save=False, verbose=True)

        assert "Train:" in caplog.text
        assert "Val:" in caplog.text
        assert "Test:" in caplog.text

    def test_all_data_preserved(self):
        """Total rows across splits should equal original."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(1000)),
                "data": list(range(1000)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert len(train) + len(val) + len(test) == len(df)

    def test_no_data_leakage_across_splits(self):
        """No row should appear in multiple splits."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "id": [f"row_{i}" for i in range(100)],
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        train_ids = set(train["id"])
        val_ids = set(val["id"])
        test_ids = set(test["id"])

        assert train_ids.isdisjoint(val_ids)
        assert train_ids.isdisjoint(test_ids)
        assert val_ids.isdisjoint(test_ids)
|
|
|
|
class TestFilter5Core:
    """Tests for filter_5_core function.

    Covers basic and iterative (to convergence) filtering, threshold
    inclusivity, empty and fully filtered inputs, column and index handling,
    and retention logging.
    """

    def test_basic_filtering(self):
        """Users and items with < min_interactions are removed."""
        # Users: u1=10, u2=3, u3=7 rows. Items: p1=8, p2=2, p3=10 rows.
        # u2 and p2 fall below 5; removing them leaves u1/u3 and p1/p3 intact.
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10 + ["u2"] * 3 + ["u3"] * 7,
                "parent_asin": ["p1"] * 8 + ["p2"] * 2 + ["p3"] * 10,
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert "u2" not in result["user_id"].values
        assert "p2" not in result["parent_asin"].values

    def test_convergence_required(self):
        """Filtering iterates until no more removals possible."""
        # u1 starts with exactly 5 rows (2x p1, 3x p2), so one pass keeps u1.
        # p2 has only 3 rows and is dropped, leaving u1 with 2 rows; only a
        # second pass removes u1.
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 5 + ["u2"] * 10,
                "parent_asin": ["p1"] * 2 + ["p2"] * 3 + ["p1"] * 10,
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert "u1" not in result["user_id"].values
        assert "p2" not in result["parent_asin"].values

    def test_empty_input_returns_empty(self):
        """Empty DataFrame returns empty DataFrame."""
        df = pd.DataFrame({"user_id": [], "parent_asin": []})

        result = filter_5_core(df)

        assert result.empty

    def test_all_filtered_out_returns_empty(self):
        """When all users/items have < min_interactions, returns empty."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert result.empty

    def test_min_interactions_1_keeps_all(self):
        """min_interactions=1 keeps all data."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        result = filter_5_core(df, min_interactions=1)

        assert len(result) == len(df)

    def test_preserves_other_columns(self):
        """Non-filter columns are preserved."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": ["p1"] * 10,
                "rating": [4.0] * 10,
                "text": ["review"] * 10,
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert "rating" in result.columns
        assert "text" in result.columns
        assert result["rating"].iloc[0] == 4.0

    def test_resets_index(self):
        """Result has clean 0-based index."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": ["p1"] * 10,
            }
        )
        # Start from a non-zero index so a missing reset would be visible.
        df.index = range(100, 110)

        result = filter_5_core(df, min_interactions=5)

        assert list(result.index) == list(range(len(result)))

    def test_logs_retention_stats(self, caplog):
        """Should log retention percentage."""
        # The root logger defaults to WARNING. Lower the capture level so the
        # retention summary is recorded even if emitted below WARNING.
        caplog.set_level("INFO")
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10 + ["u2"] * 10,
                "parent_asin": ["p1"] * 10 + ["p2"] * 10,
            }
        )

        filter_5_core(df, min_interactions=5)

        assert "retained" in caplog.text
        assert "%" in caplog.text

    def test_logs_warning_when_all_filtered(self, caplog):
        """Should warn when all data is filtered out."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        filter_5_core(df, min_interactions=5)

        assert "All data filtered out" in caplog.text

    def test_exact_threshold_kept(self):
        """Users/items with exactly min_interactions are kept."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 5 + ["u2"] * 5,
                "parent_asin": ["p1"] * 5 + ["p2"] * 5,
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert len(result) == 10
        assert set(result["user_id"]) == {"u1", "u2"}
        assert set(result["parent_asin"]) == {"p1", "p2"}

    def test_large_min_interactions(self):
        """Large min_interactions filters aggressively."""
        # Rows: u1 has 30x p1 + 20x p2 (50 total); u2 has 10x p2.
        # u2 (10 < 20) is dropped. p2 then has exactly 20 rows, which meets
        # the inclusive threshold, so u1's 50 rows survive.
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 50 + ["u2"] * 10,
                "parent_asin": ["p1"] * 30 + ["p2"] * 30,
            }
        )

        result = filter_5_core(df, min_interactions=20)

        assert len(result) == 50
        assert set(result["user_id"]) == {"u1"}
        assert set(result["parent_asin"]) == {"p1", "p2"}

    def test_handles_single_user(self):
        """Single user whose items each appear once is filtered out entirely."""
        # u1 has 10 rows, but every item p0..p9 has only one interaction, so
        # all items fall below the threshold and u1 is left with nothing.
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": [f"p{i}" for i in range(10)],
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert result.empty

    def test_dense_interaction_matrix(self):
        """Dense data where everyone interacts with everything."""
        users = ["u1", "u2", "u3"]
        items = ["p1", "p2", "p3"]
        # Every (user, item) pair appears twice, so each user and each item
        # has 3 * 2 = 6 interactions, above the threshold of 5.
        data = [
            {"user_id": u, "parent_asin": p}
            for u in users
            for p in items
            for _ in range(2)
        ]
        df = pd.DataFrame(data)

        result = filter_5_core(df, min_interactions=5)

        assert len(result) == len(df)
|
|