# Sage / tests/test_loader.py
# vxa8502's picture
# Harden data loader
# 866804f
"""Tests for sage.data.loader — data loading and temporal split functions."""
import pandas as pd
import pytest
from sage.data.loader import (
create_temporal_splits,
filter_5_core,
verify_temporal_boundaries,
)
class TestVerifyTemporalBoundaries:
    """Tests for verify_temporal_boundaries function.

    Each test builds three small frames (train / val / test) holding only a
    ``timestamp`` column and checks either the returned boundary mapping or
    the ``ValueError`` raised for an invalid combination.
    """

    def test_valid_splits_returns_boundaries(self):
        """Valid temporal splits should return boundary dict."""
        train_df = pd.DataFrame({"timestamp": [100, 200, 300]})
        val_df = pd.DataFrame({"timestamp": [400, 500]})
        test_df = pd.DataFrame({"timestamp": [600, 700, 800]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        assert result == {
            "train": (100, 300),
            "val": (400, 500),
            "test": (600, 800),
        }

    def test_empty_train_raises_clear_error(self):
        """Empty train split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": []})
        val_df = pd.DataFrame({"timestamp": [100, 200]})
        test_df = pd.DataFrame({"timestamp": [300, 400]})

        with pytest.raises(ValueError, match="Train split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_empty_val_raises_clear_error(self):
        """Empty validation split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": []})
        test_df = pd.DataFrame({"timestamp": [300, 400]})

        with pytest.raises(ValueError, match="Validation split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_empty_test_raises_clear_error(self):
        """Empty test split should raise ValueError with clear message."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": []})

        with pytest.raises(ValueError, match="Test split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_train(self):
        """Missing timestamp column in train should raise ValueError."""
        train_df = pd.DataFrame({"other_col": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Train split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_val(self):
        """Missing timestamp column in val should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"other_col": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Validation split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_missing_timestamp_column_test(self):
        """Missing timestamp column in test should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"other_col": [500, 600]})

        with pytest.raises(ValueError, match="Test split missing 'timestamp'"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_train_val_overlap_raises_error(self):
        """Train/val overlap (train_max > val_min) should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200, 500]})  # max=500
        val_df = pd.DataFrame({"timestamp": [300, 400]})  # min=300
        test_df = pd.DataFrame({"timestamp": [600, 700]})

        with pytest.raises(ValueError, match="Train/val overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_train_val_exact_boundary_raises_error(self):
        """Train/val exact boundary (train_max == val_min) is temporal leakage."""
        train_df = pd.DataFrame({"timestamp": [100, 200, 300]})  # max=300
        val_df = pd.DataFrame({"timestamp": [300, 400]})  # min=300 (same!)
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        with pytest.raises(ValueError, match="Train/val overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_val_test_overlap_raises_error(self):
        """Val/test overlap (val_max > test_min) should raise ValueError."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 600]})  # max=600
        test_df = pd.DataFrame({"timestamp": [500, 700]})  # min=500

        with pytest.raises(ValueError, match="Val/test overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_val_test_exact_boundary_raises_error(self):
        """Val/test exact boundary (val_max == test_min) is temporal leakage."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})  # max=400
        test_df = pd.DataFrame({"timestamp": [400, 500]})  # min=400 (same!)

        with pytest.raises(ValueError, match="Val/test overlap"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

    def test_single_row_splits_valid(self):
        """Single-row splits with valid boundaries should pass."""
        train_df = pd.DataFrame({"timestamp": [100]})
        val_df = pd.DataFrame({"timestamp": [200]})
        test_df = pd.DataFrame({"timestamp": [300]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        assert result["train"] == (100, 100)
        assert result["val"] == (200, 200)
        assert result["test"] == (300, 300)

    def test_verbose_logging(self, caplog):
        """Verbose mode should log boundary information."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        # Pin the capture threshold for the duration of the call. Without
        # this the assertion only holds if something else has already lowered
        # the logger level below the root default of WARNING.
        with caplog.at_level("DEBUG"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=True)

        assert "Temporal boundaries verified" in caplog.text

    def test_returns_int_timestamps(self):
        """Boundary values should be integers, not numpy types."""
        train_df = pd.DataFrame({"timestamp": [100, 200]})
        val_df = pd.DataFrame({"timestamp": [300, 400]})
        test_df = pd.DataFrame({"timestamp": [500, 600]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        # numpy integer scalars are not instances of the builtin ``int``,
        # so isinstance() is enough to tell the two apart.
        for split_name, (start, end) in result.items():
            assert isinstance(start, int), f"{split_name} start is not int"
            assert isinstance(end, int), f"{split_name} end is not int"

    def test_millisecond_timestamps(self):
        """Should handle millisecond timestamps (real-world format)."""
        # Real timestamps: 2023-01-01, 2023-06-01, 2023-12-01
        train_df = pd.DataFrame({"timestamp": [1672531200000, 1672617600000]})
        val_df = pd.DataFrame({"timestamp": [1685577600000, 1685664000000]})
        test_df = pd.DataFrame({"timestamp": [1701388800000, 1701475200000]})

        result = verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)

        assert result["train"][0] == 1672531200000
        assert result["test"][1] == 1701475200000

    def test_empty_check_before_column_check(self):
        """Empty split error should appear before missing column error."""
        # Empty df without timestamp column
        train_df = pd.DataFrame({"other": []})
        val_df = pd.DataFrame({"timestamp": [100]})
        test_df = pd.DataFrame({"timestamp": [200]})

        # Should raise "empty" not "missing column"
        with pytest.raises(ValueError, match="Train split is empty"):
            verify_temporal_boundaries(train_df, val_df, test_df, verbose=False)
class TestCreateTemporalSplits:
    """Tests for create_temporal_splits function.

    Covers split sizing for default and custom ratios, argument validation,
    warnings for empty splits, optional persistence to disk, and the
    guarantee that rows are neither lost nor duplicated across splits.
    """

    def test_default_ratios_70_10_20(self):
        """Default ratios should produce 70/10/20 split."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "data": list(range(100)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert len(train) == 70
        assert len(val) == 10
        assert len(test) == 20

    def test_preserves_temporal_order(self):
        """Train timestamps should precede val, val should precede test."""
        # Input rows are deliberately out of chronological order.
        df = pd.DataFrame(
            {
                "timestamp": [500, 100, 300, 200, 400, 600, 700, 800, 900, 1000],
                "data": list(range(10)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert train["timestamp"].max() < val["timestamp"].min()
        assert val["timestamp"].max() < test["timestamp"].min()

    def test_custom_ratios(self):
        """Custom ratios should be respected."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "data": list(range(100)),
            }
        )

        train, val, test = create_temporal_splits(
            df, train_ratio=0.5, val_ratio=0.3, save=False, verbose=False
        )

        assert len(train) == 50
        assert len(val) == 30
        assert len(test) == 20

    def test_floating_point_bug_fixed(self):
        """0.7 + 0.1 floating point issue should not lose samples."""
        # At n=10, the floating point bug would give val_end=7 instead of 8
        # (0.7 + 0.1 evaluates to 0.7999999999999999, which truncates to 7).
        df = pd.DataFrame(
            {
                "timestamp": list(range(10)),
                "data": list(range(10)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        # With round(), we get correct sizes
        assert len(train) == 7
        assert len(val) == 1
        assert len(test) == 2
        # Total should equal original
        assert len(train) + len(val) + len(test) == 10

    def test_empty_dataframe_raises_error(self):
        """Empty DataFrame should raise ValueError."""
        df = pd.DataFrame({"timestamp": []})

        with pytest.raises(ValueError, match="DataFrame is empty"):
            create_temporal_splits(df, save=False, verbose=False)

    def test_missing_timestamp_column_raises_error(self):
        """Missing timestamp column should raise ValueError."""
        df = pd.DataFrame({"other_col": [1, 2, 3]})

        with pytest.raises(ValueError, match="missing 'timestamp' column"):
            create_temporal_splits(df, save=False, verbose=False)

    def test_negative_train_ratio_raises_error(self):
        """Negative train_ratio should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="train_ratio must be between 0 and 1"):
            create_temporal_splits(df, train_ratio=-0.2, save=False, verbose=False)

    def test_train_ratio_greater_than_1_raises_error(self):
        """train_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="train_ratio must be between 0 and 1"):
            create_temporal_splits(df, train_ratio=1.5, save=False, verbose=False)

    def test_negative_val_ratio_raises_error(self):
        """Negative val_ratio should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="val_ratio must be between 0 and 1"):
            create_temporal_splits(df, val_ratio=-0.1, save=False, verbose=False)

    def test_val_ratio_greater_than_1_raises_error(self):
        """val_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        with pytest.raises(ValueError, match="val_ratio must be between 0 and 1"):
            create_temporal_splits(df, val_ratio=1.5, save=False, verbose=False)

    def test_ratios_sum_greater_than_1_raises_error(self):
        """train_ratio + val_ratio > 1 should raise ValueError."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        # ``match`` is a regular expression, so the literal "+" is escaped.
        with pytest.raises(ValueError, match=r"train_ratio \+ val_ratio must be <= 1"):
            create_temporal_splits(
                df, train_ratio=0.8, val_ratio=0.5, save=False, verbose=False
            )

    def test_ratios_sum_exactly_1_valid(self):
        """train_ratio + val_ratio = 1 is valid (empty test set)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.3, save=False, verbose=False
        )

        assert len(train) == 7
        assert len(val) == 3
        assert len(test) == 0

    def test_zero_train_ratio_valid(self):
        """train_ratio=0 is valid (all data in val/test)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.0, val_ratio=0.5, save=False, verbose=False
        )

        assert len(train) == 0
        assert len(val) == 5
        assert len(test) == 5

    def test_zero_val_ratio_valid(self):
        """val_ratio=0 is valid (no validation set)."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        train, val, test = create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.0, save=False, verbose=False
        )

        assert len(train) == 7
        assert len(val) == 0
        assert len(test) == 3

    def test_warns_on_empty_val_split(self, caplog):
        """Should warn when validation split ends up empty."""
        # With only 3 rows, a 5% validation share is too small to receive
        # any row, so the validation split comes out empty.
        df = pd.DataFrame({"timestamp": [1, 2, 3]})

        create_temporal_splits(
            df, train_ratio=0.9, val_ratio=0.05, save=False, verbose=False
        )

        assert "Validation split is empty" in caplog.text

    def test_warns_on_empty_test_split(self, caplog):
        """Should warn when test split ends up empty."""
        df = pd.DataFrame({"timestamp": list(range(10))})

        # train + val ratios sum to 1, leaving nothing for the test split.
        create_temporal_splits(
            df, train_ratio=0.7, val_ratio=0.3, save=False, verbose=False
        )

        assert "Test split is empty" in caplog.text

    def test_saves_to_disk_when_requested(self, tmp_path, monkeypatch):
        """Should save splits to parquet files when save=True."""
        # Redirect the loader's output directory to a per-test temp dir so
        # the test never writes into the real splits location.
        monkeypatch.setattr("sage.data.loader.SPLITS_DIR", tmp_path)
        df = pd.DataFrame(
            {
                "timestamp": list(range(10)),
                "data": list(range(10)),
            }
        )

        create_temporal_splits(df, save=True, verbose=False)

        assert (tmp_path / "train.parquet").exists()
        assert (tmp_path / "val.parquet").exists()
        assert (tmp_path / "test.parquet").exists()

    def test_verbose_logs_split_sizes(self, caplog):
        """Verbose mode should log split sizes."""
        df = pd.DataFrame({"timestamp": list(range(100))})

        # Pin the capture threshold for the duration of the call. Without
        # this the assertions only hold if something else has already lowered
        # the logger level below the root default of WARNING.
        with caplog.at_level("DEBUG"):
            create_temporal_splits(df, save=False, verbose=True)

        assert "Train:" in caplog.text
        assert "Val:" in caplog.text
        assert "Test:" in caplog.text

    def test_all_data_preserved(self):
        """Total rows across splits should equal original."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(1000)),
                "data": list(range(1000)),
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        assert len(train) + len(val) + len(test) == len(df)

    def test_no_data_leakage_across_splits(self):
        """No row should appear in multiple splits."""
        df = pd.DataFrame(
            {
                "timestamp": list(range(100)),
                "id": [f"row_{i}" for i in range(100)],
            }
        )

        train, val, test = create_temporal_splits(df, save=False, verbose=False)

        train_ids = set(train["id"])
        val_ids = set(val["id"])
        test_ids = set(test["id"])
        assert train_ids.isdisjoint(val_ids)
        assert train_ids.isdisjoint(test_ids)
        assert val_ids.isdisjoint(test_ids)
class TestFilter5Core:
    """Tests for filter_5_core function.

    Fixtures are built from two parallel lists (``user_id`` and
    ``parent_asin``); row *i* pairs the *i*-th user entry with the *i*-th
    item entry, and every row counts as one interaction.
    """

    def test_basic_filtering(self):
        """Users and items with < min_interactions are removed."""
        # Create data where some users/items have < 5 interactions
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10 + ["u2"] * 3 + ["u3"] * 7,  # u2 has only 3
                "parent_asin": ["p1"] * 8 + ["p2"] * 2 + ["p3"] * 10,  # p2 has only 2
            }
        )

        result = filter_5_core(df, min_interactions=5)

        # u2 and p2 should be removed
        assert "u2" not in result["user_id"].values
        assert "p2" not in result["parent_asin"].values
        # ...while the rest survives: u1 keeps 8 rows on p1 and u3 keeps 7
        # rows on p3. Checking survivors stops the "not in" assertions above
        # from passing vacuously on an empty result.
        assert set(result["user_id"]) == {"u1", "u3"}
        assert set(result["parent_asin"]) == {"p1", "p3"}

    def test_convergence_required(self):
        """Filtering iterates until no more removals possible."""
        # User u1 has 5 interactions, but 3 are with p2 which gets removed
        # After p2 removal, u1 only has 2 left → also removed
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 5 + ["u2"] * 10,
                "parent_asin": ["p1"] * 2 + ["p2"] * 3 + ["p1"] * 10,  # p2 has only 3
            }
        )

        result = filter_5_core(df, min_interactions=5)

        # p2 removed (only 3) → u1 now has only 2 with p1 → u1 removed
        assert "u1" not in result["user_id"].values
        assert "p2" not in result["parent_asin"].values
        # u2's 10 rows on p1 are untouched by the cascade.
        assert set(result["user_id"]) == {"u2"}
        assert set(result["parent_asin"]) == {"p1"}

    def test_empty_input_returns_empty(self):
        """Empty DataFrame returns empty DataFrame."""
        df = pd.DataFrame({"user_id": [], "parent_asin": []})

        result = filter_5_core(df)

        assert result.empty

    def test_all_filtered_out_returns_empty(self):
        """When all users/items have < min_interactions, returns empty."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert result.empty

    def test_min_interactions_1_keeps_all(self):
        """min_interactions=1 keeps all data."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        result = filter_5_core(df, min_interactions=1)

        assert len(result) == len(df)

    def test_preserves_other_columns(self):
        """Non-filter columns are preserved."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": ["p1"] * 10,
                "rating": [4.0] * 10,
                "text": ["review"] * 10,
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert "rating" in result.columns
        assert "text" in result.columns
        assert result["rating"].iloc[0] == 4.0

    def test_resets_index(self):
        """Result has clean 0-based index."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": ["p1"] * 10,
            }
        )
        df.index = range(100, 110)  # Non-zero starting index

        result = filter_5_core(df, min_interactions=5)

        assert list(result.index) == list(range(len(result)))

    def test_logs_retention_stats(self, caplog):
        """Should log retention percentage."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10 + ["u2"] * 10,
                "parent_asin": ["p1"] * 10 + ["p2"] * 10,
            }
        )

        # Pin the capture threshold for the duration of the call. Without
        # this the assertions only hold if something else has already lowered
        # the logger level below the root default of WARNING.
        with caplog.at_level("DEBUG"):
            filter_5_core(df, min_interactions=5)

        assert "retained" in caplog.text
        assert "%" in caplog.text

    def test_logs_warning_when_all_filtered(self, caplog):
        """Should warn when all data is filtered out."""
        df = pd.DataFrame(
            {
                "user_id": ["u1", "u2", "u3"],
                "parent_asin": ["p1", "p2", "p3"],
            }
        )

        filter_5_core(df, min_interactions=5)

        assert "All data filtered out" in caplog.text

    def test_exact_threshold_kept(self):
        """Users/items with exactly min_interactions are kept."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 5 + ["u2"] * 5,  # Both have exactly 5
                "parent_asin": ["p1"] * 5 + ["p2"] * 5,  # Both have exactly 5
            }
        )

        result = filter_5_core(df, min_interactions=5)

        assert len(result) == 10
        assert set(result["user_id"]) == {"u1", "u2"}
        assert set(result["parent_asin"]) == {"p1", "p2"}

    def test_large_min_interactions(self):
        """A large min_interactions drops sparse users but keeps what still qualifies."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 50 + ["u2"] * 10,
                "parent_asin": ["p1"] * 30 + ["p2"] * 30,
            }
        )

        result = filter_5_core(df, min_interactions=20)

        # Row alignment of the two lists:
        #   u1 -> 30 rows on p1 + 20 rows on p2 (50 total)
        #   u2 -> 10 rows on p2
        # With min=20, u2 (10 rows) is removed. That leaves p2 with exactly
        # 20 rows, which still meets the threshold, so p1, p2 and all 50 of
        # u1's rows survive.
        assert len(result) == 50
        assert set(result["user_id"]) == {"u1"}
        assert set(result["parent_asin"]) == {"p1", "p2"}

    def test_handles_single_user(self):
        """A single user whose items each have one interaction is filtered out."""
        df = pd.DataFrame(
            {
                "user_id": ["u1"] * 10,
                "parent_asin": [f"p{i}" for i in range(10)],  # 10 different items
            }
        )

        result = filter_5_core(df, min_interactions=5)

        # u1 has 10 interactions, but each item only has 1
        # Items get filtered → u1 gets filtered
        assert result.empty

    def test_dense_interaction_matrix(self):
        """Dense data where everyone interacts with everything."""
        users = ["u1", "u2", "u3"]
        items = ["p1", "p2", "p3"]
        # Each user interacts with each item twice = 6 per user, 6 per item
        data = []
        for u in users:
            for p in items:
                data.extend([{"user_id": u, "parent_asin": p}] * 2)
        df = pd.DataFrame(data)

        result = filter_5_core(df, min_interactions=5)

        # 6 interactions per user, 6 per item - all kept
        assert len(result) == len(df)