| import numpy as np |
| import pytest |
|
|
| from pandas import ( |
| DataFrame, |
| Series, |
| ) |
| import pandas._testing as tm |
|
|
# Module-level marker: pytest applies ``single_cpu`` to every test collected
# from this file.
pytestmark = pytest.mark.single_cpu

# Skip the whole module at collection time when numba cannot be imported.
pytest.importorskip("numba")
|
|
|
|
# NOTE(review): this blanket-ignores every warning raised inside the class;
# presumably it silences numba warnings triggered by some ``engine_kwargs``
# combinations -- confirm before narrowing the filter.
@pytest.mark.filterwarnings("ignore")
class TestEWM:
    """Tests for the online EWM object returned by ``obj.ewm(...).online()``.

    The ``nogil``, ``parallel``, ``nopython``, ``adjust``, ``ignore_na`` and
    ``halflife_with_times`` arguments are not defined in this class; they are
    presumably pytest fixtures supplied by a conftest -- TODO confirm.
    """

    def test_invalid_update(self):
        """``mean(update=...)`` before any plain ``mean()`` call raises ValueError."""
        df = DataFrame({"a": range(5), "b": range(5)})
        online_ewm = df.head(2).ewm(0.5).online()
        with pytest.raises(
            ValueError,
            match="Must call mean with update=None first before passing update",
        ):
            online_ewm.mean(update=df.head(1))

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_online_vs_non_online_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na
    ):
        """Online mean (2 seed rows + 3 update rows) equals the regular EWM mean.

        The comparison is run twice, with ``reset()`` between the runs, so the
        second pass checks that a reset online object reproduces the same
        results.
        """
        expected = obj.ewm(0.5, adjust=adjust, ignore_na=ignore_na).mean()
        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}

        online_ewm = (
            obj.head(2)
            .ewm(0.5, adjust=adjust, ignore_na=ignore_na)
            .online(engine_kwargs=engine_kwargs)
        )

        # Two passes: the second one exercises the state left by ``reset()``.
        for _ in range(2):
            # Plain mean over the two seed rows.
            result = online_ewm.mean()
            tm.assert_equal(result, expected.head(2))

            # Feed the remaining three rows as an update.
            result = online_ewm.mean(update=obj.tail(3))
            tm.assert_equal(result, expected.tail(3))

            online_ewm.reset()

    @pytest.mark.xfail(raises=NotImplementedError)
    @pytest.mark.parametrize(
        "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
    )
    def test_update_times_mean(
        self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times
    ):
        """Same comparison as the plain online test, but with ``times``.

        The test is marked ``xfail(raises=NotImplementedError)``: it is
        expected to fail with that exception somewhere in its body.
        """
        times = Series(
            np.array(
                ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"],
                dtype="datetime64[ns]",
            )
        )
        expected = obj.ewm(
            0.5,
            adjust=adjust,
            ignore_na=ignore_na,
            times=times,
            halflife=halflife_with_times,
        ).mean()

        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        online_ewm = (
            obj.head(2)
            .ewm(
                0.5,
                adjust=adjust,
                ignore_na=ignore_na,
                times=times.head(2),
                halflife=halflife_with_times,
            )
            .online(engine_kwargs=engine_kwargs)
        )

        # Two passes: the second one exercises the state left by ``reset()``.
        for _ in range(2):
            result = online_ewm.mean()
            tm.assert_equal(result, expected.head(2))

            result = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3))
            tm.assert_equal(result, expected.tail(3))

            online_ewm.reset()

    @pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"])
    def test_ewm_notimplementederror_raises(self, method):
        """Each listed method of the online EWM object raises NotImplementedError."""
        ser = Series(range(10))
        kwargs = {}
        if method == "aggregate":
            # ``aggregate`` is the only method called with an argument here.
            kwargs["func"] = lambda x: x

        # Build the online object outside ``pytest.raises`` so that only the
        # method call itself can satisfy the expected exception.
        online_ewm = ser.ewm(1).online()

        # Raw string with an escaped dot: the trailing "." must match literally.
        with pytest.raises(NotImplementedError, match=r".* is not implemented\."):
            getattr(online_ewm, method)(**kwargs)
|
|