"""Unit tests for the remove_unknown_date_exacerbations function."""

import copd
import numpy as np
import pandas as pd
import pytest


def _make_frame(remove_row=None):
    """Build the ten-day sample frame, optionally adding a ``RemoveRow`` column.

    Day index 8 (2022-01-09) is an exacerbation (``IsExac == 1``) whose date is
    marked unknown (``ExacDateUnknown == 1``). The exacerbations at indices 1
    and 6 are not marked unknown.

    Args:
        remove_row: Optional list of per-row values for ``RemoveRow``. When
            given, it becomes the last column, after the three input columns.

    Returns:
        A new ``pd.DataFrame`` with a default ``RangeIndex``.
    """
    columns = {
        'Date': pd.date_range('2022-01-01', '2022-01-10'),
        'IsExac': [0, 1, 0, 0, 0, 0, 1, 0, 1, 0],
        'ExacDateUnknown': [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    }
    if remove_row is not None:
        columns['RemoveRow'] = remove_row
    return pd.DataFrame(columns)


@pytest.fixture
def input_df():
    """Provide sample data that includes one exacerbation with an unknown date."""
    return _make_frame()


def test_check_correct_rows_flagged_default(input_df):
    """Default call: the 7 days ending on the unknown-date exacerbation are flagged."""
    result = copd.remove_unknown_date_exacerbations(input_df)

    # Rows 2..8 (2022-01-03 .. 2022-01-09) get 1. Every other row stays NaN,
    # including the known-date exacerbation on row 1.
    flags = [np.nan] * 2 + [1] * 7 + [np.nan]
    expected = _make_frame(remove_row=flags)

    pd.testing.assert_frame_equal(result, expected)


def test_check_correct_rows_flagged_non_default(input_df):
    """With days_to_remove=5, only the 5 days ending on that exacerbation are flagged."""
    result = copd.remove_unknown_date_exacerbations(input_df, days_to_remove=5)

    # Rows 4..8 (2022-01-05 .. 2022-01-09) get 1; all other rows stay NaN.
    flags = [np.nan] * 4 + [1] * 5 + [np.nan]
    expected = _make_frame(remove_row=flags)

    pd.testing.assert_frame_equal(result, expected)