| """ |
| Tests for the DataGenerator module. |
| |
| This module contains comprehensive tests for the DataGenerator class to ensure |
| proper data generation, splitting, and file operations. |
| """ |
|
|
| import pytest |
| import pandas as pd |
| import numpy as np |
| import tempfile |
| import os |
| from src.data_generator import DataGenerator |
|
|
|
|
class TestDataGenerator:
    """Test cases for the DataGenerator class.

    Covers construction, data generation (shape, dtypes, value ranges and
    reproducibility), train/validation/test splitting, CSV round-tripping,
    and the statistical relationships the generated data is expected to show.
    """

    # Column layout every generated frame is expected to have, in this order.
    EXPECTED_COLUMNS = [
        "temperature",
        "day_of_week",
        "major_event",
        "consumption_kwh",
    ]

    WEEKDAYS = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    WEEKEND_DAYS = ["Saturday", "Sunday"]
    VALID_DAYS = WEEKDAYS + WEEKEND_DAYS

    def test_initialization(self):
        """Test DataGenerator initialization with and without seed."""
        # An explicit seed must be stored on the instance unchanged.
        seeded = DataGenerator(seed=42)
        assert seeded.seed == 42

        # Passing None must be preserved as None rather than replaced.
        unseeded = DataGenerator(seed=None)
        assert unseeded.seed is None

    def test_generate_data_basic(self):
        """Test basic data generation with default parameters."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data()

        # Structure: a DataFrame of 1000 rows with the expected columns.
        assert isinstance(data, pd.DataFrame)
        assert len(data) == 1000
        assert list(data.columns) == self.EXPECTED_COLUMNS

        # Column dtypes: numeric features are floats/ints, the day is text.
        float_dtypes = [np.float64, np.float32]
        int_dtypes = [np.int64, np.int32]
        assert data["temperature"].dtype in float_dtypes
        assert data["day_of_week"].dtype == "object"
        assert data["major_event"].dtype in int_dtypes
        assert data["consumption_kwh"].dtype in float_dtypes

    def test_generate_data_custom_parameters(self):
        """Test data generation with custom parameters."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=500, noise_level=0.2)

        assert len(data) == 500

        # Temperatures must stay inside the closed interval [15, 35].
        temperature = data["temperature"]
        assert temperature.min() >= 15
        assert temperature.max() <= 35

        # Only real day names may appear.
        assert set(data["day_of_week"].unique()) <= set(self.VALID_DAYS)

        # The event flag is binary.
        assert set(data["major_event"].unique()) <= {0, 1}

        # Consumption must be strictly positive on every row.
        assert (data["consumption_kwh"] > 0).all()

    def test_generate_data_reproducibility(self):
        """Test that data generation is reproducible with the same seed."""
        # NOTE(review): the global NumPy RNG is re-seeded before each
        # generator is built. If DataGenerator already seeds itself from
        # ``seed`` these calls are redundant, and if it does not they hide
        # that fact -- confirm against the DataGenerator implementation.
        np.random.seed(42)
        first = DataGenerator(seed=42).generate_data(n_samples=100)

        np.random.seed(42)
        second = DataGenerator(seed=42).generate_data(n_samples=100)

        pd.testing.assert_frame_equal(first, second)

    def test_generate_data_different_seeds(self):
        """Test that different seeds produce different data."""
        generator_a = DataGenerator(seed=42)
        generator_b = DataGenerator(seed=123)

        data_a = generator_a.generate_data(n_samples=100)
        data_b = generator_b.generate_data(n_samples=100)

        # Any difference at all between the two frames is sufficient.
        assert not data_a.equals(data_b)

    def test_split_data_basic(self):
        """Test basic data splitting functionality."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        train_data, val_data, test_data = generator.split_data(data)

        # With no proportions given, 1000 rows are expected to split
        # 700 / 150 / 150.
        assert len(train_data) == 700
        assert len(val_data) == 150
        assert len(test_data) == 150

        # No rows may be lost or duplicated in terms of total count.
        assert len(train_data) + len(val_data) + len(test_data) == len(data)

        # Recombining the three parts yields a frame of the original length.
        recombined = pd.concat([train_data, val_data, test_data])
        assert len(recombined) == len(data)

    def test_split_data_custom_proportions(self):
        """Test data splitting with custom proportions."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        train_data, val_data, test_data = generator.split_data(
            data, train_size=0.6, val_size=0.2, test_size=0.2
        )

        assert len(train_data) == 600
        assert len(val_data) == 200
        assert len(test_data) == 200

    def test_split_data_validation(self):
        """Test that split proportions validation works."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=100)

        # Proportions summing to more than 1 (0.5 + 0.3 + 0.3) are rejected.
        with pytest.raises(AssertionError):
            generator.split_data(data, train_size=0.5, val_size=0.3, test_size=0.3)

        # Proportions summing to less than 1 (0.4 + 0.3 + 0.2) are rejected.
        with pytest.raises(AssertionError):
            generator.split_data(data, train_size=0.4, val_size=0.3, test_size=0.2)

    def test_split_data_reproducibility(self):
        """Test that data splitting is reproducible."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Splitting the same frame twice with the same generator must give
        # identical partitions.
        first_split = generator.split_data(data)
        second_split = generator.split_data(data)

        train1, val1, test1 = first_split
        train2, val2, test2 = second_split

        pd.testing.assert_frame_equal(train1, train2)
        pd.testing.assert_frame_equal(val1, val2)
        pd.testing.assert_frame_equal(test1, test2)

    def test_save_and_load_data(self):
        """Test saving and loading data to/from CSV."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=100)

        # Use a path inside a fresh temporary directory so the target file
        # does NOT exist before save_data runs. A pre-created temporary file
        # would make the existence check below pass even if nothing were
        # written. The directory and its contents are removed on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filepath = os.path.join(tmp_dir, "generated_data.csv")
            assert not os.path.exists(filepath)

            generator.save_data(data, filepath)

            # The file must have been created by save_data itself.
            assert os.path.exists(filepath)

            # A save/load round trip must reproduce the original frame.
            loaded_data = generator.load_data(filepath)
            pd.testing.assert_frame_equal(data, loaded_data)

    def test_data_statistics(self):
        """Test that generated data has reasonable statistics."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Temperature: mean inside the allowed range, non-degenerate spread.
        temperature = data["temperature"]
        assert 15 <= temperature.mean() <= 35
        assert temperature.std() > 0

        # Consumption: positive on average and not constant.
        consumption = data["consumption_kwh"]
        assert consumption.mean() > 0
        assert consumption.std() > 0

        # All seven days must be represented at least once.
        day_counts = data["day_of_week"].value_counts()
        assert len(day_counts) == 7
        assert all(count > 0 for count in day_counts.values)

        # Both event values occur, and "no event" is the more common one.
        event_counts = data["major_event"].value_counts()
        assert 0 in event_counts.index
        assert 1 in event_counts.index
        assert event_counts[0] > event_counts[1]

    def test_noise_level_effect(self):
        """Test that noise level affects data variability."""
        generator = DataGenerator(seed=42)

        # Generate the low-noise sample first, then the high-noise sample,
        # from the same generator instance.
        data_low_noise = generator.generate_data(n_samples=1000, noise_level=0.01)
        data_high_noise = generator.generate_data(n_samples=1000, noise_level=0.5)

        # More noise must translate into a wider spread of consumption.
        low_spread = data_low_noise["consumption_kwh"].std()
        high_spread = data_high_noise["consumption_kwh"].std()
        assert high_spread > low_spread

    def test_temperature_consumption_correlation(self):
        """Test that temperature and consumption have positive correlation."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        correlation = data["temperature"].corr(data["consumption_kwh"])
        assert correlation > 0

    def test_day_of_week_effect(self):
        """Test that different days have different consumption patterns."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Mean consumption per day name.
        day_consumption = data.groupby("day_of_week")["consumption_kwh"].mean()

        # The per-day means must not all be identical.
        assert day_consumption.std() > 0

        # Average the per-day means separately over weekend and weekdays.
        weekend_avg = sum(
            day_consumption[day] for day in self.WEEKEND_DAYS
        ) / len(self.WEEKEND_DAYS)
        weekday_avg = sum(
            day_consumption[day] for day in self.WEEKDAYS
        ) / len(self.WEEKDAYS)

        # The direction of the gap is not checked, only that weekend and
        # weekday averages differ by more than 0.1.
        assert abs(weekend_avg - weekday_avg) > 0.1

    def test_major_event_effect(self):
        """Test that major events increase consumption."""
        generator = DataGenerator(seed=42)
        data = generator.generate_data(n_samples=1000)

        # Mean consumption for rows without (0) and with (1) a major event.
        event_consumption = data.groupby("major_event")["consumption_kwh"].mean()

        assert event_consumption[1] > event_consumption[0]
|
|