| import inspect |
| from collections import defaultdict |
| from functools import partial |
|
|
| import numpy as np |
| from numpy.testing import assert_array_equal |
|
|
| from sklearn.base import ( |
| BaseEstimator, |
| ClassifierMixin, |
| MetaEstimatorMixin, |
| RegressorMixin, |
| TransformerMixin, |
| clone, |
| ) |
| from sklearn.metrics._scorer import _Scorer, mean_squared_error |
| from sklearn.model_selection import BaseCrossValidator |
| from sklearn.model_selection._split import GroupsConsumerMixin |
| from sklearn.utils._metadata_requests import ( |
| SIMPLE_METHODS, |
| ) |
| from sklearn.utils.metadata_routing import ( |
| MetadataRouter, |
| MethodMapping, |
| process_routing, |
| ) |
| from sklearn.utils.multiclass import _check_partial_fit_first_call |
|
|
|
|
def record_metadata(obj, record_default=True, **kwargs):
    """Utility function to store passed metadata to a method of obj.

    If record_default is False, kwargs whose values are "default" are skipped.
    This is so that checks on keyword arguments whose default was not changed
    are skipped.

    """
    # Walk two frames up from here: the method that received the metadata
    # (the "callee") and whatever invoked that method (the "caller").
    frame = inspect.currentframe().f_back
    callee = frame.f_code.co_name
    caller = frame.f_back.f_code.co_name
    if not hasattr(obj, "_records"):
        # Nested mapping: callee -> caller -> list of recorded kwargs dicts.
        obj._records = defaultdict(lambda: defaultdict(list))
    if not record_default:
        # Drop entries still carrying the "default" sentinel string.
        kwargs = {
            key: val
            for key, val in kwargs.items()
            if not (isinstance(val, str) and val == "default")
        }
    obj._records[callee][caller].append(kwargs)
|
|
|
|
def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
    """Check whether the expected metadata is passed to the object's method.

    Parameters
    ----------
    obj : estimator object
        sub-estimator to check routed params for
    method : str
        sub-estimator's method where metadata is routed to, or otherwise in
        the context of metadata routing referred to as 'callee'
    parent : str
        the parent method which should have called `method`, or otherwise in
        the context of metadata routing referred to as 'caller'
    split_params : tuple, default=empty
        specifies any parameters which are to be checked as being a subset
        of the original values
    **kwargs : dict
        passed metadata
    """
    # No records for this (method, parent) pair means nothing to check.
    records = getattr(obj, "_records", dict())
    all_records = records.get(method, dict()).get(parent, list())
    for record in all_records:
        # Exactly the expected metadata keys must have been recorded.
        assert set(kwargs.keys()) == set(record.keys()), (
            f"Expected {kwargs.keys()} vs {record.keys()}"
        )
        for key, value in kwargs.items():
            recorded_value = record[key]
            if key in split_params and recorded_value is not None:
                # Split parameters only need to be a subset of the originals.
                assert np.isin(recorded_value, value).all()
            elif isinstance(recorded_value, np.ndarray):
                assert_array_equal(recorded_value, value)
            else:
                # Identity, not equality: the very same object must have
                # been routed through.
                assert recorded_value is value, (
                    f"Expected {recorded_value} vs {value}. Method: {method}"
                )
|
|
|
|
| record_metadata_not_default = partial(record_metadata, record_default=False) |
|
|
|
|
def assert_request_is_empty(metadata_request, exclude=None):
    """Assert that a metadata request carries no requested properties.

    One can exclude a method or a list of methods from the check using the
    ``exclude`` parameter. If metadata_request is a MetadataRouter, then
    ``exclude`` can be of the form ``{"object" : [method, ...]}``.
    """
    if isinstance(metadata_request, MetadataRouter):
        # Recurse into every routed object, threading through any
        # per-object exclusions.
        for name, route_mapping in metadata_request:
            nested_exclude = (
                exclude[name] if exclude is not None and name in exclude else None
            )
            assert_request_is_empty(route_mapping.router, exclude=nested_exclude)
        return

    excluded_methods = exclude if exclude is not None else []
    for method in SIMPLE_METHODS:
        if method in excluded_methods:
            continue
        mmr = getattr(metadata_request, method)
        # Any alias that is a string or otherwise non-None counts as a
        # requested property.
        requested = [
            prop
            for prop, alias in mmr.requests.items()
            if isinstance(alias, str) or alias is not None
        ]
        assert not requested
|
|
|
|
def assert_request_equal(request, dictionary):
    """Assert that ``request`` matches ``dictionary`` exactly.

    Methods present in ``dictionary`` must carry precisely the given
    requests; every other simple method must carry none at all.
    """
    for method, expected in dictionary.items():
        assert getattr(request, method).requests == expected

    # All simple methods not mentioned must be empty.
    for method in SIMPLE_METHODS:
        if method in dictionary:
            continue
        assert not len(getattr(request, method).requests)
|
|
|
|
| class _Registry(list): |
| |
| |
| |
| |
| |
| def __deepcopy__(self, memo): |
| return self |
|
|
| def __copy__(self): |
| return self |
|
|
|
|
class ConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def partial_fit(self, X, y, sample_weight="default", metadata="default"):
        """Record received metadata; no actual fitting happens."""
        if self.registry is not None:
            self.registry.append(self)

        # Only metadata explicitly passed by the caller (i.e. not left at the
        # "default" sentinel) is recorded.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        """Record received metadata; no actual fitting happens."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self

    def predict(self, X, y=None, sample_weight="default", metadata="default"):
        """Record received metadata; always predicts all zeros."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return np.zeros(shape=(len(X),))

    def score(self, X, y, sample_weight="default", metadata="default"):
        """Record received metadata; always returns the constant score 1."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return 1
|
|
|
|
class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier which accepts no metadata on any method."""

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def fit(self, X, y):
        """Set dummy fitted attributes; accepts no metadata."""
        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def partial_fit(self, X, y, classes=None):
        """No-op partial fit; accepts no metadata."""
        return self

    def decision_function(self, X):
        """Delegate to :meth:`predict`."""
        return self.predict(X)

    def predict(self, X):
        """Deterministic dummy prediction: class 0 for the first half of the
        samples and class 1 for the rest."""
        half = len(X) // 2
        y_pred = np.ones(len(X))
        y_pred[:half] = 0
        return y_pred

    def predict_proba(self, X):
        """Hard 0/1 probabilities consistent with :meth:`predict`."""
        half = len(X) // 2
        y_proba = np.zeros((len(X), 2))
        y_proba[:half, 0] = 1.0
        y_proba[half:, 1] = 1.0
        return y_proba

    def predict_log_proba(self, X):
        """Return :meth:`predict_proba` output unchanged (log not applied)."""
        return self.predict_proba(X)
|
|
|
|
class NonConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor which accepts no metadata on any method."""

    def fit(self, X, y):
        """No-op fit; accepts no metadata."""
        return self

    def partial_fit(self, X, y):
        """No-op partial fit; accepts no metadata."""
        return self

    def predict(self, X):
        """Constant dummy prediction of all ones."""
        return np.ones(len(X))
|
|
|
|
class ConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier consuming metadata.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    """

    def __init__(self, registry=None, alpha=0.0):
        self.alpha = alpha
        self.registry = registry

    def partial_fit(
        self, X, y, classes=None, sample_weight="default", metadata="default"
    ):
        """Record received metadata and track seen classes."""
        if self.registry is not None:
            self.registry.append(self)

        # Only metadata explicitly passed by the caller (i.e. not left at the
        # "default" sentinel) is recorded.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        # Sets self.classes_ on the first call; validates it on later calls.
        _check_partial_fit_first_call(self, classes)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        """Record received metadata and set dummy fitted attributes."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )

        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def predict(self, X, sample_weight="default", metadata="default"):
        """Record metadata; predicts 1 for the first half of the samples and
        0 for the rest."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),), dtype="int8")
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def predict_proba(self, X, sample_weight="default", metadata="default"):
        """Record metadata; returns hard 0/1 probabilities with the first
        half of the samples assigned to class 0."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_proba = np.empty(shape=(len(X), 2))
        y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
        y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
        return y_proba

    def predict_log_proba(self, X, sample_weight="default", metadata="default"):
        """Record metadata; returns all zeros."""
        # NOTE(review): the zeros are not the log of predict_proba's output;
        # presumably only the metadata recording matters here — confirm.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return np.zeros(shape=(len(X), 2))

    def decision_function(self, X, sample_weight="default", metadata="default"):
        """Record metadata; scores 1 for the first half of the samples and 0
        for the rest."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        y_score = np.empty(shape=(len(X),))
        y_score[len(X) // 2 :] = 0
        y_score[: len(X) // 2] = 1
        return y_score

    def score(self, X, y, sample_weight="default", metadata="default"):
        """Record metadata; always returns the constant score 1."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return 1
|
|
|
|
class ConsumingTransformer(TransformerMixin, BaseEstimator):
    """A transformer which accepts metadata on fit and transform.

    Parameters
    ----------
    registry : list, default=None
        If a list, the estimator will append itself to the list in order to have
        a reference to the estimator later on. Since that reference is not
        required in all tests, registration can be skipped by leaving this value
        as None.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight="default", metadata="default"):
        """Record received metadata; only sets a dummy fitted flag."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        self.fitted_ = True
        return self

    def transform(self, X, sample_weight="default", metadata="default"):
        """Record received metadata and shift the data by +1."""
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return X + 1

    def fit_transform(self, X, y, sample_weight="default", metadata="default"):
        # Overridden (instead of relying on TransformerMixin) so the metadata
        # reaches both `fit` and `transform`. As a consequence, the metadata
        # is recorded three times: once here and once in each delegated call.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return self.fit(X, y, sample_weight=sample_weight, metadata=metadata).transform(
            X, sample_weight=sample_weight, metadata=metadata
        )

    def inverse_transform(self, X, sample_weight=None, metadata=None):
        """Record received metadata and undo `transform`'s +1 shift."""
        # NOTE(review): defaults here are None instead of the "default"
        # sentinel used by the other methods — confirm this is intentional.
        record_metadata_not_default(
            self, sample_weight=sample_weight, metadata=metadata
        )
        return X - 1
|
|
|
|
class ConsumingNoFitTransformTransformer(BaseEstimator):
    """A metadata consuming transformer that doesn't inherit from
    TransformerMixin, and thus doesn't implement `fit_transform`. Note that
    TransformerMixin's `fit_transform` doesn't route metadata to `transform`."""

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        """Record received metadata (defaults included, since plain
        `record_metadata` is used here); no actual fitting."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight, metadata=metadata)

        return self

    def transform(self, X, sample_weight=None, metadata=None):
        """Record received metadata and return X unchanged."""
        record_metadata(self, sample_weight=sample_weight, metadata=metadata)
        return X
|
|
|
|
class ConsumingScorer(_Scorer):
    """A mean-squared-error scorer which records any metadata routed to it.

    If ``registry`` is a list, the scorer appends itself to it when scoring
    so that tests can keep a reference to the instance.
    """

    def __init__(self, registry=None):
        super().__init__(
            score_func=mean_squared_error, sign=1, kwargs={}, response_method="predict"
        )
        self.registry = registry

    def _score(self, method_caller, clf, X, y, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        # Record everything explicitly routed in (the "default" sentinel
        # filtering happens inside record_metadata_not_default).
        record_metadata_not_default(self, **kwargs)

        # Only sample_weight is forwarded to the underlying scorer; any
        # other metadata is recorded but otherwise ignored.
        sample_weight = kwargs.get("sample_weight", None)
        return super()._score(method_caller, clf, X, y, sample_weight=sample_weight)
|
|
|
|
class ConsumingSplitter(GroupsConsumerMixin, BaseCrossValidator):
    """A two-split cross-validator which records `groups` and `metadata`.

    If ``registry`` is a list, the splitter appends itself to it in `split`
    so that tests can keep a reference to the instance.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def split(self, X, y=None, groups="default", metadata="default"):
        """Record received metadata and yield two complementary half splits."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, groups=groups, metadata=metadata)

        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        # The two yielded splits swap the roles of the two halves.
        # NOTE(review): the first yield puts the second half of the samples
        # in the train position — confirm this ordering is intentional.
        yield test_indices, train_indices
        yield train_indices, test_indices

    def get_n_splits(self, X=None, y=None, groups=None, metadata=None):
        # Matches the two splits yielded by `split`.
        return 2

    def _iter_test_indices(self, X=None, y=None, groups=None):
        # Mirrors `split` for the BaseCrossValidator machinery: each half of
        # the data serves once as the test set.
        split_index = len(X) // 2
        train_indices = list(range(0, split_index))
        test_indices = list(range(split_index, len(X)))
        yield test_indices
        yield train_indices
|
|
|
|
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is only a router.

    It consumes no metadata itself; anything passed to ``fit`` is routed to
    the sub-estimator's ``fit``.

    Parameters
    ----------
    estimator : estimator object
        The sub-estimator which metadata is routed to.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        """Route ``fit_params`` to the sub-estimator and fit a clone of it."""
        params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        # Return self per the scikit-learn convention (and consistent with
        # the other meta-estimators in this module), so calls can be chained.
        return self

    def get_metadata_routing(self):
        """Declare routing: this estimator's ``fit`` calls the
        sub-estimator's ``fit``."""
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router
|
|
|
|
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is also a consumer.

    It consumes ``sample_weight`` itself in ``fit`` and routes metadata to
    the sub-estimator's ``fit`` and ``predict``.

    Parameters
    ----------
    estimator : estimator object
        The sub-estimator which metadata is routed to.

    registry : list, default=None
        If a list, the estimator appends itself to it in ``fit`` so tests
        can keep a reference to the instance.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Record ``sample_weight``, route metadata, and fit a clone of the
        sub-estimator."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        # sample_weight is consumed here and also passed through the router
        # so the sub-estimator can receive it if it requested it.
        params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def predict(self, X, **predict_params):
        """Route ``predict_params`` and delegate prediction to the fitted
        sub-estimator."""
        params = process_routing(self, "predict", **predict_params)
        return self.estimator_.predict(X, **params.estimator.predict)

    def get_metadata_routing(self):
        """Declare this estimator's own requests plus fit->fit and
        predict->predict routing to the sub-estimator."""
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router
|
|
|
|
class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """A meta-estimator which also consumes sample_weight itself in ``fit``.

    Parameters
    ----------
    estimator : estimator object
        The sub-estimator which metadata is routed to.

    registry : list, default=None
        If a list, the estimator appends itself to it in ``fit`` so tests
        can keep a reference to the instance.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **kwargs):
        """Record ``sample_weight``, route metadata, and fit a clone of the
        sub-estimator."""
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        # sample_weight is consumed here and also passed through the router
        # so the sub-estimator can receive it if it requested it.
        params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        return self

    def get_metadata_routing(self):
        """Declare this estimator's own requests plus fit->fit routing to
        the sub-estimator."""
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router
|
|
|
|
class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
    """A simple meta-transformer.

    It consumes no metadata itself; anything passed to ``fit`` or
    ``transform`` is routed to the sub-transformer's method of the same name.

    Parameters
    ----------
    transformer : estimator object
        The sub-transformer which metadata is routed to.
    """

    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X, y=None, **fit_params):
        """Route ``fit_params`` and fit a clone of the sub-transformer."""
        params = process_routing(self, "fit", **fit_params)
        self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
        return self

    def transform(self, X, y=None, **transform_params):
        """Route ``transform_params`` and delegate to the fitted
        sub-transformer."""
        params = process_routing(self, "transform", **transform_params)
        return self.transformer_.transform(X, **params.transformer.transform)

    def get_metadata_routing(self):
        """Declare fit->fit and transform->transform routing to the
        sub-transformer."""
        return MetadataRouter(owner=self.__class__.__name__).add(
            transformer=self.transformer,
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="transform", callee="transform"),
        )
|
|