| """ |
| Metadata Routing Utility Tests |
| """ |
|
|
| |
| |
|
|
| import re |
|
|
| import numpy as np |
| import pytest |
|
|
| from sklearn import config_context |
| from sklearn.base import ( |
| BaseEstimator, |
| clone, |
| ) |
| from sklearn.exceptions import UnsetMetadataPassedError |
| from sklearn.linear_model import LinearRegression |
| from sklearn.pipeline import Pipeline |
| from sklearn.tests.metadata_routing_common import ( |
| ConsumingClassifier, |
| ConsumingRegressor, |
| ConsumingTransformer, |
| MetaRegressor, |
| MetaTransformer, |
| NonConsumingClassifier, |
| WeightedMetaClassifier, |
| WeightedMetaRegressor, |
| _Registry, |
| assert_request_equal, |
| assert_request_is_empty, |
| check_recorded_metadata, |
| ) |
| from sklearn.utils import metadata_routing |
| from sklearn.utils._metadata_requests import ( |
| COMPOSITE_METHODS, |
| METHODS, |
| SIMPLE_METHODS, |
| MethodMetadataRequest, |
| MethodPair, |
| _MetadataRequester, |
| request_is_alias, |
| request_is_valid, |
| ) |
| from sklearn.utils.metadata_routing import ( |
| MetadataRequest, |
| MetadataRouter, |
| MethodMapping, |
| _RoutingNotSupportedMixin, |
| get_routing_for_object, |
| process_routing, |
| ) |
| from sklearn.utils.validation import check_is_fitted |
|
|
| rng = np.random.RandomState(42) |
| N, M = 100, 4 |
| X = rng.rand(N, M) |
| y = rng.randint(0, 2, size=N) |
| my_groups = rng.randint(0, 10, size=N) |
| my_weights = rng.rand(N) |
| my_other_weights = rng.rand(N) |
|
|
|
|
| class SimplePipeline(BaseEstimator): |
| """A very simple pipeline, assuming the last step is always a predictor. |
| |
| Parameters |
| ---------- |
| steps : iterable of objects |
| An iterable of transformers with the last step being a predictor. |
| """ |
|
|
| def __init__(self, steps): |
| self.steps = steps |
|
|
| def fit(self, X, y, **fit_params): |
| self.steps_ = [] |
| params = process_routing(self, "fit", **fit_params) |
| X_transformed = X |
| for i, step in enumerate(self.steps[:-1]): |
| transformer = clone(step).fit( |
| X_transformed, y, **params.get(f"step_{i}").fit |
| ) |
| self.steps_.append(transformer) |
| X_transformed = transformer.transform( |
| X_transformed, **params.get(f"step_{i}").transform |
| ) |
|
|
| self.steps_.append( |
| clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit) |
| ) |
| return self |
|
|
| def predict(self, X, **predict_params): |
| check_is_fitted(self) |
| X_transformed = X |
| params = process_routing(self, "predict", **predict_params) |
| for i, step in enumerate(self.steps_[:-1]): |
| X_transformed = step.transform(X, **params.get(f"step_{i}").transform) |
|
|
| return self.steps_[-1].predict(X_transformed, **params.predictor.predict) |
|
|
| def get_metadata_routing(self): |
| router = MetadataRouter(owner=self.__class__.__name__) |
| for i, step in enumerate(self.steps[:-1]): |
| router.add( |
| **{f"step_{i}": step}, |
| method_mapping=MethodMapping() |
| .add(caller="fit", callee="fit") |
| .add(caller="fit", callee="transform") |
| .add(caller="predict", callee="transform"), |
| ) |
| router.add( |
| predictor=self.steps[-1], |
| method_mapping=MethodMapping() |
| .add(caller="fit", callee="fit") |
| .add(caller="predict", callee="predict"), |
| ) |
| return router |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_assert_request_is_empty(): |
| requests = MetadataRequest(owner="test") |
| assert_request_is_empty(requests) |
|
|
| requests.fit.add_request(param="foo", alias=None) |
| |
| assert_request_is_empty(requests) |
|
|
| requests.fit.add_request(param="bar", alias="value") |
| with pytest.raises(AssertionError): |
| |
| assert_request_is_empty(requests) |
|
|
| |
| assert_request_is_empty(requests, exclude="fit") |
|
|
| requests.score.add_request(param="carrot", alias=True) |
| with pytest.raises(AssertionError): |
| |
| assert_request_is_empty(requests, exclude="fit") |
|
|
| |
| assert_request_is_empty(requests, exclude=["fit", "score"]) |
|
|
| |
| assert_request_is_empty( |
| MetadataRouter(owner="test") |
| .add_self_request(WeightedMetaRegressor(estimator=None)) |
| .add( |
| estimator=ConsumingRegressor(), |
| method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
| ) |
| ) |
|
|
|
|
| @pytest.mark.parametrize( |
| "estimator", |
| [ |
| ConsumingClassifier(registry=_Registry()), |
| ConsumingRegressor(registry=_Registry()), |
| ConsumingTransformer(registry=_Registry()), |
| WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()), |
| WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()), |
| ], |
| ) |
| @config_context(enable_metadata_routing=True) |
| def test_estimator_puts_self_in_registry(estimator): |
| """Check that an estimator puts itself in the registry upon fit.""" |
| estimator.fit(X, y) |
| assert estimator in estimator.registry |
|
|
|
|
| @pytest.mark.parametrize( |
| "val, res", |
| [ |
| (False, False), |
| (True, False), |
| (None, False), |
| ("$UNUSED$", False), |
| ("$WARN$", False), |
| ("invalid-input", False), |
| ("valid_arg", True), |
| ], |
| ) |
| @config_context(enable_metadata_routing=True) |
| def test_request_type_is_alias(val, res): |
| |
| assert request_is_alias(val) == res |
|
|
|
|
| @pytest.mark.parametrize( |
| "val, res", |
| [ |
| (False, True), |
| (True, True), |
| (None, True), |
| ("$UNUSED$", True), |
| ("$WARN$", True), |
| ("invalid-input", False), |
| ("alias_arg", False), |
| ], |
| ) |
| @config_context(enable_metadata_routing=True) |
| def test_request_type_is_valid(val, res): |
| |
| assert request_is_valid(val) == res |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_default_requests(): |
| class OddEstimator(BaseEstimator): |
| __metadata_request__fit = { |
| |
| "sample_weight": True |
| } |
|
|
| odd_request = get_routing_for_object(OddEstimator()) |
| assert odd_request.fit.requests == {"sample_weight": True} |
|
|
| |
| assert not len(get_routing_for_object(NonConsumingClassifier()).fit.requests) |
| assert_request_is_empty(NonConsumingClassifier().get_metadata_routing()) |
|
|
| trs_request = get_routing_for_object(ConsumingTransformer()) |
| assert trs_request.fit.requests == { |
| "sample_weight": None, |
| "metadata": None, |
| } |
| assert trs_request.transform.requests == {"metadata": None, "sample_weight": None} |
| assert_request_is_empty(trs_request) |
|
|
| est_request = get_routing_for_object(ConsumingClassifier()) |
| assert est_request.fit.requests == { |
| "sample_weight": None, |
| "metadata": None, |
| } |
| assert_request_is_empty(est_request) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_default_request_override(): |
| """Test that default requests are correctly overridden regardless of the ASCII order |
| of the class names, hence testing small and capital letter class name starts. |
| Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28430 |
| """ |
|
|
| class Base(BaseEstimator): |
| __metadata_request__split = {"groups": True} |
|
|
| class class_1(Base): |
| __metadata_request__split = {"groups": "sample_domain"} |
|
|
| class Class_1(Base): |
| __metadata_request__split = {"groups": "sample_domain"} |
|
|
| assert_request_equal( |
| class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} |
| ) |
| assert_request_equal( |
| Class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_process_routing_invalid_method(): |
| with pytest.raises(TypeError, match="Can only route and process input"): |
| process_routing(ConsumingClassifier(), "invalid_method", groups=my_groups) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_process_routing_invalid_object(): |
| class InvalidObject: |
| pass |
|
|
| with pytest.raises(AttributeError, match="either implement the routing method"): |
| process_routing(InvalidObject(), "fit", groups=my_groups) |
|
|
|
|
| @pytest.mark.parametrize("method", METHODS) |
| @pytest.mark.parametrize("default", [None, "default", []]) |
| @config_context(enable_metadata_routing=True) |
| def test_process_routing_empty_params_get_with_default(method, default): |
| empty_params = {} |
| routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params) |
|
|
| |
| params_for_method = routed_params[method] |
| assert isinstance(params_for_method, dict) |
| assert set(params_for_method.keys()) == set(METHODS) |
|
|
| |
| default_params_for_method = routed_params.get(method, default=default) |
| assert default_params_for_method == params_for_method |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_simple_metadata_routing(): |
| |
|
|
| |
| clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) |
| clf.fit(X, y) |
|
|
| |
| |
| clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) |
| clf.fit(X, y, sample_weight=my_weights) |
|
|
| |
| |
| clf = WeightedMetaClassifier(estimator=ConsumingClassifier()) |
| err_message = ( |
| "[sample_weight] are passed but are not explicitly set as requested or" |
| " not requested for ConsumingClassifier.fit" |
| ) |
| with pytest.raises(ValueError, match=re.escape(err_message)): |
| clf.fit(X, y, sample_weight=my_weights) |
|
|
| |
| |
| |
| |
| clf = WeightedMetaClassifier( |
| estimator=ConsumingClassifier().set_fit_request(sample_weight=False) |
| ) |
| |
| |
| |
| clf.fit(X, y, sample_weight=my_weights) |
| check_recorded_metadata(clf.estimator_, method="fit", parent="fit") |
|
|
| |
| clf = WeightedMetaClassifier( |
| estimator=ConsumingClassifier().set_fit_request(sample_weight=True) |
| ) |
| clf.fit(X, y, sample_weight=my_weights) |
| check_recorded_metadata( |
| clf.estimator_, method="fit", parent="fit", sample_weight=my_weights |
| ) |
|
|
| |
| clf = WeightedMetaClassifier( |
| estimator=ConsumingClassifier().set_fit_request( |
| sample_weight="alternative_weight" |
| ) |
| ) |
| clf.fit(X, y, alternative_weight=my_weights) |
| check_recorded_metadata( |
| clf.estimator_, method="fit", parent="fit", sample_weight=my_weights |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_nested_routing(): |
| |
| pipeline = SimplePipeline( |
| [ |
| MetaTransformer( |
| transformer=ConsumingTransformer() |
| .set_fit_request(metadata=True, sample_weight=False) |
| .set_transform_request(sample_weight=True, metadata=False) |
| ), |
| WeightedMetaRegressor( |
| estimator=ConsumingRegressor() |
| .set_fit_request(sample_weight="inner_weights", metadata=False) |
| .set_predict_request(sample_weight=False) |
| ).set_fit_request(sample_weight="outer_weights"), |
| ] |
| ) |
| w1, w2, w3 = [1], [2], [3] |
| pipeline.fit( |
| X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 |
| ) |
| check_recorded_metadata( |
| pipeline.steps_[0].transformer_, |
| method="fit", |
| parent="fit", |
| metadata=my_groups, |
| ) |
| check_recorded_metadata( |
| pipeline.steps_[0].transformer_, |
| method="transform", |
| parent="fit", |
| sample_weight=w1, |
| ) |
| check_recorded_metadata( |
| pipeline.steps_[1], method="fit", parent="fit", sample_weight=w2 |
| ) |
| check_recorded_metadata( |
| pipeline.steps_[1].estimator_, method="fit", parent="fit", sample_weight=w3 |
| ) |
|
|
| pipeline.predict(X, sample_weight=w3) |
| check_recorded_metadata( |
| pipeline.steps_[0].transformer_, |
| method="transform", |
| parent="fit", |
| sample_weight=w3, |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_nested_routing_conflict(): |
| |
| pipeline = SimplePipeline( |
| [ |
| MetaTransformer( |
| transformer=ConsumingTransformer() |
| .set_fit_request(metadata=True, sample_weight=False) |
| .set_transform_request(sample_weight=True) |
| ), |
| WeightedMetaRegressor( |
| estimator=ConsumingRegressor().set_fit_request(sample_weight=True) |
| ).set_fit_request(sample_weight="outer_weights"), |
| ] |
| ) |
| w1, w2 = [1], [2] |
| with pytest.raises( |
| ValueError, |
| match=( |
| re.escape( |
| "In WeightedMetaRegressor, there is a conflict on sample_weight between" |
| " what is requested for this estimator and what is requested by its" |
| " children. You can resolve this conflict by using an alias for the" |
| " child estimator(s) requested metadata." |
| ) |
| ), |
| ): |
| pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_invalid_metadata(): |
| |
| trs = MetaTransformer( |
| transformer=ConsumingTransformer().set_transform_request(sample_weight=True) |
| ) |
| with pytest.raises( |
| TypeError, |
| match=(re.escape("transform got unexpected argument(s) {'other_param'}")), |
| ): |
| trs.fit(X, y).transform(X, other_param=my_weights) |
|
|
| |
| trs = MetaTransformer( |
| transformer=ConsumingTransformer().set_transform_request(sample_weight=False) |
| ) |
| with pytest.raises( |
| TypeError, |
| match=(re.escape("transform got unexpected argument(s) {'sample_weight'}")), |
| ): |
| trs.fit(X, y).transform(X, sample_weight=my_weights) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_get_metadata_routing(): |
| class TestDefaultsBadMethodName(_MetadataRequester): |
| __metadata_request__fit = { |
| "sample_weight": None, |
| "my_param": None, |
| } |
| __metadata_request__score = { |
| "sample_weight": None, |
| "my_param": True, |
| "my_other_param": None, |
| } |
| |
| __metadata_request__other_method = {"my_param": True} |
|
|
| class TestDefaults(_MetadataRequester): |
| __metadata_request__fit = { |
| "sample_weight": None, |
| "my_other_param": None, |
| } |
| __metadata_request__score = { |
| "sample_weight": None, |
| "my_param": True, |
| "my_other_param": None, |
| } |
| __metadata_request__predict = {"my_param": True} |
|
|
| with pytest.raises( |
| AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" |
| ): |
| TestDefaultsBadMethodName().get_metadata_routing() |
|
|
| expected = { |
| "score": { |
| "my_param": True, |
| "my_other_param": None, |
| "sample_weight": None, |
| }, |
| "fit": { |
| "my_other_param": None, |
| "sample_weight": None, |
| }, |
| "predict": {"my_param": True}, |
| } |
| assert_request_equal(TestDefaults().get_metadata_routing(), expected) |
|
|
| est = TestDefaults().set_score_request(my_param="other_param") |
| expected = { |
| "score": { |
| "my_param": "other_param", |
| "my_other_param": None, |
| "sample_weight": None, |
| }, |
| "fit": { |
| "my_other_param": None, |
| "sample_weight": None, |
| }, |
| "predict": {"my_param": True}, |
| } |
| assert_request_equal(est.get_metadata_routing(), expected) |
|
|
| est = TestDefaults().set_fit_request(sample_weight=True) |
| expected = { |
| "score": { |
| "my_param": True, |
| "my_other_param": None, |
| "sample_weight": None, |
| }, |
| "fit": { |
| "my_other_param": None, |
| "sample_weight": True, |
| }, |
| "predict": {"my_param": True}, |
| } |
| assert_request_equal(est.get_metadata_routing(), expected) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_setting_default_requests(): |
| |
| test_cases = dict() |
|
|
| class ExplicitRequest(BaseEstimator): |
| |
| __metadata_request__fit = {"prop": None} |
|
|
| def fit(self, X, y, **kwargs): |
| return self |
|
|
| test_cases[ExplicitRequest] = {"prop": None} |
|
|
| class ExplicitRequestOverwrite(BaseEstimator): |
| |
| |
| __metadata_request__fit = {"prop": True} |
|
|
| def fit(self, X, y, prop=None, **kwargs): |
| return self |
|
|
| test_cases[ExplicitRequestOverwrite] = {"prop": True} |
|
|
| class ImplicitRequest(BaseEstimator): |
| |
| def fit(self, X, y, prop=None, **kwargs): |
| return self |
|
|
| test_cases[ImplicitRequest] = {"prop": None} |
|
|
| class ImplicitRequestRemoval(BaseEstimator): |
| |
| |
| __metadata_request__fit = {"prop": metadata_routing.UNUSED} |
|
|
| def fit(self, X, y, prop=None, **kwargs): |
| return self |
|
|
| test_cases[ImplicitRequestRemoval] = {} |
|
|
| for Klass, requests in test_cases.items(): |
| assert get_routing_for_object(Klass()).fit.requests == requests |
| assert_request_is_empty(Klass().get_metadata_routing(), exclude="fit") |
| Klass().fit(None, None) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_removing_non_existing_param_raises(): |
| """Test that removing a metadata using UNUSED which doesn't exist raises.""" |
|
|
| class InvalidRequestRemoval(BaseEstimator): |
| |
| |
| __metadata_request__fit = {"prop": metadata_routing.UNUSED} |
|
|
| def fit(self, X, y, **kwargs): |
| return self |
|
|
| with pytest.raises(ValueError, match="Trying to remove parameter"): |
| InvalidRequestRemoval().get_metadata_routing() |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_method_metadata_request(): |
| mmr = MethodMetadataRequest(owner="test", method="fit") |
|
|
| with pytest.raises(ValueError, match="The alias you're setting for"): |
| mmr.add_request(param="foo", alias=1.4) |
|
|
| mmr.add_request(param="foo", alias=None) |
| assert mmr.requests == {"foo": None} |
| mmr.add_request(param="foo", alias=False) |
| assert mmr.requests == {"foo": False} |
| mmr.add_request(param="foo", alias=True) |
| assert mmr.requests == {"foo": True} |
| mmr.add_request(param="foo", alias="foo") |
| assert mmr.requests == {"foo": True} |
| mmr.add_request(param="foo", alias="bar") |
| assert mmr.requests == {"foo": "bar"} |
| assert mmr._get_param_names(return_alias=False) == {"foo"} |
| assert mmr._get_param_names(return_alias=True) == {"bar"} |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_get_routing_for_object(): |
| class Consumer(BaseEstimator): |
| __metadata_request__fit = {"prop": None} |
|
|
| assert_request_is_empty(get_routing_for_object(None)) |
| assert_request_is_empty(get_routing_for_object(object())) |
|
|
| mr = MetadataRequest(owner="test") |
| mr.fit.add_request(param="foo", alias="bar") |
| mr_factory = get_routing_for_object(mr) |
| assert_request_is_empty(mr_factory, exclude="fit") |
| assert mr_factory.fit.requests == {"foo": "bar"} |
|
|
| mr = get_routing_for_object(Consumer()) |
| assert_request_is_empty(mr, exclude="fit") |
| assert mr.fit.requests == {"prop": None} |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metadata_request_consumes_method(): |
| """Test that MetadataRequest().consumes() method works as expected.""" |
| request = MetadataRouter(owner="test") |
| assert request.consumes(method="fit", params={"foo"}) == set() |
|
|
| request = MetadataRequest(owner="test") |
| request.fit.add_request(param="foo", alias=True) |
| assert request.consumes(method="fit", params={"foo"}) == {"foo"} |
|
|
| request = MetadataRequest(owner="test") |
| request.fit.add_request(param="foo", alias="bar") |
| assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"} |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metadata_router_consumes_method(): |
| """Test that MetadataRouter().consumes method works as expected.""" |
| |
| |
| cases = [ |
| ( |
| WeightedMetaRegressor( |
| estimator=ConsumingRegressor().set_fit_request(sample_weight=True) |
| ), |
| {"sample_weight"}, |
| {"sample_weight"}, |
| ), |
| ( |
| WeightedMetaRegressor( |
| estimator=ConsumingRegressor().set_fit_request( |
| sample_weight="my_weights" |
| ) |
| ), |
| {"my_weights", "sample_weight"}, |
| {"my_weights"}, |
| ), |
| ] |
|
|
| for obj, input, output in cases: |
| assert obj.get_metadata_routing().consumes(method="fit", params=input) == output |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metaestimator_warnings(): |
| class WeightedMetaRegressorWarn(WeightedMetaRegressor): |
| __metadata_request__fit = {"sample_weight": metadata_routing.WARN} |
|
|
| with pytest.warns( |
| UserWarning, match="Support for .* has recently been added to this class" |
| ): |
| WeightedMetaRegressorWarn( |
| estimator=LinearRegression().set_fit_request(sample_weight=False) |
| ).fit(X, y, sample_weight=my_weights) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_estimator_warnings(): |
| class ConsumingRegressorWarn(ConsumingRegressor): |
| __metadata_request__fit = {"sample_weight": metadata_routing.WARN} |
|
|
| with pytest.warns( |
| UserWarning, match="Support for .* has recently been added to this class" |
| ): |
| MetaRegressor(estimator=ConsumingRegressorWarn()).fit( |
| X, y, sample_weight=my_weights |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| @pytest.mark.parametrize( |
| "obj, string", |
| [ |
| ( |
| MethodMetadataRequest(owner="test", method="fit").add_request( |
| param="foo", alias="bar" |
| ), |
| "{'foo': 'bar'}", |
| ), |
| ( |
| MetadataRequest(owner="test"), |
| "{}", |
| ), |
| ( |
| MetadataRouter(owner="test").add( |
| estimator=ConsumingRegressor(), |
| method_mapping=MethodMapping().add(caller="predict", callee="predict"), |
| ), |
| ( |
| "{'estimator': {'mapping': [{'caller': 'predict', 'callee':" |
| " 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':" |
| " None}, 'partial_fit': {'sample_weight': None, 'metadata': None}," |
| " 'predict': {'sample_weight': None, 'metadata': None}, 'score':" |
| " {'sample_weight': None, 'metadata': None}}}}" |
| ), |
| ), |
| ], |
| ) |
| @config_context(enable_metadata_routing=True) |
| def test_string_representations(obj, string): |
| assert str(obj) == string |
|
|
|
|
| @pytest.mark.parametrize( |
| "obj, method, inputs, err_cls, err_msg", |
| [ |
| ( |
| MethodMapping(), |
| "add", |
| {"caller": "fit", "callee": "invalid"}, |
| ValueError, |
| "Given callee", |
| ), |
| ( |
| MethodMapping(), |
| "add", |
| {"caller": "invalid", "callee": "fit"}, |
| ValueError, |
| "Given caller", |
| ), |
| ( |
| MetadataRouter(owner="test"), |
| "add_self_request", |
| {"obj": MetadataRouter(owner="test")}, |
| ValueError, |
| "Given `obj` is neither a `MetadataRequest` nor does it implement", |
| ), |
| ( |
| ConsumingClassifier(), |
| "set_fit_request", |
| {"invalid": True}, |
| TypeError, |
| "Unexpected args", |
| ), |
| ], |
| ) |
| @config_context(enable_metadata_routing=True) |
| def test_validations(obj, method, inputs, err_cls, err_msg): |
| with pytest.raises(err_cls, match=err_msg): |
| getattr(obj, method)(**inputs) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_methodmapping(): |
| mm = ( |
| MethodMapping() |
| .add(caller="fit", callee="transform") |
| .add(caller="fit", callee="fit") |
| ) |
|
|
| mm_list = list(mm) |
| assert mm_list[0] == ("fit", "transform") |
| assert mm_list[1] == ("fit", "fit") |
|
|
| mm = MethodMapping() |
| for method in METHODS: |
| mm.add(caller=method, callee=method) |
| assert MethodPair(method, method) in mm._routes |
| assert len(mm._routes) == len(METHODS) |
|
|
| mm = MethodMapping().add(caller="score", callee="score") |
| assert repr(mm) == "[{'caller': 'score', 'callee': 'score'}]" |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metadatarouter_add_self_request(): |
| |
| request = MetadataRequest(owner="nested") |
| request.fit.add_request(param="param", alias=True) |
| router = MetadataRouter(owner="test").add_self_request(request) |
| assert str(router._self_request) == str(request) |
| |
| assert router._self_request is not request |
|
|
| |
| est = ConsumingRegressor().set_fit_request(sample_weight="my_weights") |
| router = MetadataRouter(owner="test").add_self_request(obj=est) |
| assert str(router._self_request) == str(est.get_metadata_routing()) |
| assert router._self_request is not est.get_metadata_routing() |
|
|
| |
| est = WeightedMetaRegressor( |
| estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights") |
| ) |
| router = MetadataRouter(owner="test").add_self_request(obj=est) |
| |
| assert str(router._self_request) == str(est._get_metadata_request()) |
| |
| |
| assert str(router._self_request) != str(est.get_metadata_routing()) |
| |
| assert router._self_request is not est._get_metadata_request() |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metadata_routing_add(): |
| |
| router = MetadataRouter(owner="test").add( |
| est=ConsumingRegressor().set_fit_request(sample_weight="weights"), |
| method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
| ) |
| assert ( |
| str(router) |
| == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'fit'}], 'router': {'fit':" |
| " {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':" |
| " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" |
| " None, 'metadata': None}, 'score': {'sample_weight': None, 'metadata':" |
| " None}}}}" |
| ) |
|
|
| |
| router = MetadataRouter(owner="test").add( |
| method_mapping=MethodMapping().add(caller="fit", callee="score"), |
| est=ConsumingRegressor().set_score_request(sample_weight=True), |
| ) |
| assert ( |
| str(router) |
| == "{'est': {'mapping': [{'caller': 'fit', 'callee': 'score'}], 'router':" |
| " {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':" |
| " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" |
| " None, 'metadata': None}, 'score': {'sample_weight': True, 'metadata':" |
| " None}}}}" |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_metadata_routing_get_param_names(): |
| router = ( |
| MetadataRouter(owner="test") |
| .add_self_request( |
| WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request( |
| sample_weight="self_weights" |
| ) |
| ) |
| .add( |
| trs=ConsumingTransformer().set_fit_request( |
| sample_weight="transform_weights" |
| ), |
| method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
| ) |
| ) |
|
|
| assert ( |
| str(router) |
| == "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':" |
| " {'sample_weight': None}}, 'trs': {'mapping': [{'caller': 'fit', 'callee':" |
| " 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights'," |
| " 'metadata': None}, 'transform': {'sample_weight': None, 'metadata': None}," |
| " 'inverse_transform': {'sample_weight': None, 'metadata': None}}}}" |
| ) |
|
|
| assert router._get_param_names( |
| method="fit", return_alias=True, ignore_self_request=False |
| ) == {"transform_weights", "metadata", "self_weights"} |
| |
| assert router._get_param_names( |
| method="fit", return_alias=False, ignore_self_request=False |
| ) == {"sample_weight", "metadata", "transform_weights"} |
| |
| assert router._get_param_names( |
| method="fit", return_alias=False, ignore_self_request=True |
| ) == {"metadata", "transform_weights"} |
| |
| assert router._get_param_names( |
| method="fit", return_alias=True, ignore_self_request=True |
| ) == router._get_param_names( |
| method="fit", return_alias=False, ignore_self_request=True |
| ) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_method_generation(): |
| |
|
|
| |
| |
| class SimpleEstimator(BaseEstimator): |
| |
| def fit(self, X, y): |
| pass |
|
|
| def fit_transform(self, X, y): |
| pass |
|
|
| def fit_predict(self, X, y): |
| pass |
|
|
| def partial_fit(self, X, y): |
| pass |
|
|
| def predict(self, X): |
| pass |
|
|
| def predict_proba(self, X): |
| pass |
|
|
| def predict_log_proba(self, X): |
| pass |
|
|
| def decision_function(self, X): |
| pass |
|
|
| def score(self, X, y): |
| pass |
|
|
| def split(self, X, y=None): |
| pass |
|
|
| def transform(self, X): |
| pass |
|
|
| def inverse_transform(self, X): |
| pass |
|
|
| for method in METHODS: |
| assert not hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
| class SimpleEstimator(BaseEstimator): |
| |
| def fit(self, X, y, sample_weight=None): |
| pass |
|
|
| def fit_transform(self, X, y, sample_weight=None): |
| pass |
|
|
| def fit_predict(self, X, y, sample_weight=None): |
| pass |
|
|
| def partial_fit(self, X, y, sample_weight=None): |
| pass |
|
|
| def predict(self, X, sample_weight=None): |
| pass |
|
|
| def predict_proba(self, X, sample_weight=None): |
| pass |
|
|
| def predict_log_proba(self, X, sample_weight=None): |
| pass |
|
|
| def decision_function(self, X, sample_weight=None): |
| pass |
|
|
| def score(self, X, y, sample_weight=None): |
| pass |
|
|
| def split(self, X, y=None, sample_weight=None): |
| pass |
|
|
| def transform(self, X, sample_weight=None): |
| pass |
|
|
| def inverse_transform(self, X, sample_weight=None): |
| pass |
|
|
| |
| for method in COMPOSITE_METHODS: |
| assert not hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
| |
| for method in SIMPLE_METHODS: |
| assert hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_composite_methods(): |
| |
| |
| |
| |
| class SimpleEstimator(BaseEstimator): |
| |
| def fit(self, X, y, foo=None, bar=None): |
| pass |
|
|
| def predict(self, X, foo=None, bar=None): |
| pass |
|
|
| def transform(self, X, other_param=None): |
| pass |
|
|
| est = SimpleEstimator() |
| |
| |
| assert est.get_metadata_routing().fit_transform.requests == { |
| "bar": None, |
| "foo": None, |
| "other_param": None, |
| } |
| assert est.get_metadata_routing().fit_predict.requests == {"bar": None, "foo": None} |
|
|
| |
| est.set_fit_request(foo=True, bar="test") |
| with pytest.raises(ValueError, match="Conflicting metadata requests for"): |
| est.get_metadata_routing().fit_predict |
|
|
| |
| |
| est.set_predict_request(bar=True) |
| with pytest.raises(ValueError, match="Conflicting metadata requests for"): |
| est.get_metadata_routing().fit_predict |
|
|
| |
| |
| est.set_predict_request(foo=True, bar="test") |
| est.get_metadata_routing().fit_predict |
|
|
| |
| |
| est.set_transform_request(other_param=True) |
| assert est.get_metadata_routing().fit_transform.requests == { |
| "bar": "test", |
| "foo": True, |
| "other_param": True, |
| } |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_no_feature_flag_raises_error(): |
| """Test that when feature flag disabled, set_{method}_requests raises.""" |
| with config_context(enable_metadata_routing=False): |
| with pytest.raises(RuntimeError, match="This method is only available"): |
| ConsumingClassifier().set_fit_request(sample_weight=True) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_none_metadata_passed(): |
| """Test that passing None as metadata when not requested doesn't raise""" |
| MetaRegressor(estimator=ConsumingRegressor()).fit(X, y, sample_weight=None) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_no_metadata_always_works(): |
| """Test that when no metadata is passed, having a meta-estimator which does |
| not yet support metadata routing works. |
| |
| Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28246 |
| """ |
|
|
| class Estimator(_RoutingNotSupportedMixin, BaseEstimator): |
| def fit(self, X, y, metadata=None): |
| return self |
|
|
| |
| MetaRegressor(estimator=Estimator()).fit(X, y) |
| |
| with pytest.raises( |
| NotImplementedError, match="Estimator has not implemented metadata routing yet." |
| ): |
| MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups) |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_unsetmetadatapassederror_correct(): |
| """Test that UnsetMetadataPassedError raises the correct error message when |
| set_{method}_request is not set in nested cases.""" |
| weighted_meta = WeightedMetaClassifier(estimator=ConsumingClassifier()) |
| pipe = SimplePipeline([weighted_meta]) |
| msg = re.escape( |
| "[metadata] are passed but are not explicitly set as requested or not requested" |
| " for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit." |
| " Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each" |
| " metadata you want to request/ignore." |
| ) |
|
|
| with pytest.raises(UnsetMetadataPassedError, match=msg): |
| pipe.fit(X, y, metadata="blah") |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_unsetmetadatapassederror_correct_for_composite_methods(): |
| """Test that UnsetMetadataPassedError raises the correct error message when |
| composite metadata request methods are not set in nested cases.""" |
| consuming_transformer = ConsumingTransformer() |
| pipe = Pipeline([("consuming_transformer", consuming_transformer)]) |
|
|
| msg = re.escape( |
| "[metadata] are passed but are not explicitly set as requested or not requested" |
| " for ConsumingTransformer.fit_transform, which is used within" |
| " Pipeline.fit_transform. Call" |
| " `ConsumingTransformer.set_fit_request({metadata}=True/False)" |
| ".set_transform_request({metadata}=True/False)`" |
| " for each metadata you want to request/ignore." |
| ) |
| with pytest.raises(UnsetMetadataPassedError, match=msg): |
| pipe.fit_transform(X, y, metadata="blah") |
|
|
|
|
| @config_context(enable_metadata_routing=True) |
| def test_unbound_set_methods_work(): |
| """Tests that if the set_{method}_request is unbound, it still works. |
| |
| Also test that passing positional arguments to the set_{method}_request fails |
| with the right TypeError message. |
| |
| Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28632 |
| """ |
|
|
| class A(BaseEstimator): |
| def fit(self, X, y, sample_weight=None): |
| return self |
|
|
| error_message = re.escape( |
| "set_fit_request() takes 0 positional argument but 1 were given" |
| ) |
|
|
| |
| with pytest.raises(TypeError, match=error_message): |
| A().set_fit_request(True) |
|
|
| |
| |
| |
| A.set_fit_request = A.set_fit_request |
|
|
| |
| A().set_fit_request(sample_weight=True) |
|
|
| |
| with pytest.raises(TypeError, match=error_message): |
| A().set_fit_request(True) |
|
|