from typing import Callable, cast, List, Optional

import torch.nn as nn
from captum._utils.models.model import Model
from torch import Tensor
from torch.utils.data import DataLoader


class LinearModel(nn.Module, Model):
    SUPPORTED_NORMS: List[Optional[str]] = [None, "batch_norm", "layer_norm"]

    def __init__(self, train_fn: Callable, **kwargs) -> None:
        r"""
        Constructs a linear model with a training function and additional
        construction arguments that will be sent to
        `self._construct_model_params` after `self.fit` is called. Please note
        that this assumes `self.train_fn` will call
        `self._construct_model_params`.

        Please note that this is an experimental feature.

        Args:
            train_fn (Callable):
                The function to train with. See
                `captum._utils.models.linear_model.train.sgd_train_linear_model`
                and
                `captum._utils.models.linear_model.train.sklearn_train_linear_model`
                for examples.
            kwargs:
                Any additional keyword arguments to send to
                `self._construct_model_params` once `self.fit` is called.
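
        A `train_fn` receives the model, the dataloader, and the stored
        construction kwargs, and is responsible for calling
        `_construct_model_params` itself. A minimal sketch of that contract
        (illustrative only; the shipped trainers live in
        `captum._utils.models.linear_model.train`)::

            def my_train_fn(model, dataloader, construct_kwargs, **kwargs):
                model._construct_model_params(**construct_kwargs)
                for x, y in dataloader:
                    ...  # update model.linear.weight / model.linear.bias here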
| """ | |
| super().__init__() | |
| self.norm: Optional[nn.Module] = None | |
| self.linear: Optional[nn.Linear] = None | |
| self.train_fn = train_fn | |
| self.construct_kwargs = kwargs | |

    def _construct_model_params(
        self,
        in_features: Optional[int] = None,
        out_features: Optional[int] = None,
        norm_type: Optional[str] = None,
        affine_norm: bool = False,
        bias: bool = True,
        weight_values: Optional[Tensor] = None,
        bias_value: Optional[Tensor] = None,
        classes: Optional[Tensor] = None,
    ):
        r"""
        Lazily initializes a linear model. This will be called for you in a
        train method.

        Args:
            in_features (int):
                The number of input features.
            out_features (int):
                The number of output features.
            norm_type (str, optional):
                The type of normalization to apply. Please assign this
                to one of `LinearModel.SUPPORTED_NORMS`.
            affine_norm (bool):
                Whether or not to learn an affine transformation of the
                normalization parameters used.
            bias (bool):
                Whether to add a bias term. Not needed if the input is
                normalized.
            weight_values (Tensor, optional):
                The values to initialize the linear model with. This must be a
                1D or 2D tensor of shape `(num_outputs, num_features)` or
                `(num_features,)`. If this is provided, you need not provide
                `in_features` or `out_features`.
            bias_value (Tensor, optional):
                The bias value to initialize the model with.
            classes (Tensor, optional):
                The list of prediction classes supported by the model in case
                it performs classification. In case of regression it is set
                to None.
                Default: None
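
        Example (a hypothetical sketch; the weight and bias values below are
        illustrative)::

            import torch

            model = SGDLinearModel()
            # in_features=3 and out_features=1 are inferred from the 1D weights
            model._construct_model_params(
                weight_values=torch.tensor([0.5, -1.0, 2.0]),
                bias_value=torch.tensor([0.1]),
            )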
| """ | |
| if norm_type not in LinearModel.SUPPORTED_NORMS: | |
| raise ValueError( | |
| f"{norm_type} not supported. Please use {LinearModel.SUPPORTED_NORMS}" | |
| ) | |
| if weight_values is not None: | |
| in_features = weight_values.shape[-1] | |
| out_features = ( | |
| 1 if len(weight_values.shape) == 1 else weight_values.shape[0] | |
| ) | |
| if in_features is None or out_features is None: | |
| raise ValueError( | |
| "Please provide `in_features` and `out_features` or `weight_values`" | |
| ) | |
| if norm_type == "batch_norm": | |
| self.norm = nn.BatchNorm1d(in_features, eps=1e-8, affine=affine_norm) | |
| elif norm_type == "layer_norm": | |
| self.norm = nn.LayerNorm( | |
| in_features, eps=1e-8, elementwise_affine=affine_norm | |
| ) | |
| else: | |
| self.norm = None | |
| self.linear = nn.Linear(in_features, out_features, bias=bias) | |
| if weight_values is not None: | |
| self.linear.weight.data = weight_values | |
| if bias_value is not None: | |
| if not bias: | |
| raise ValueError("`bias_value` is not None and bias is False") | |
| self.linear.bias.data = bias_value | |
| if classes is not None: | |
| self.linear.classes = classes | |

    def fit(self, train_data: DataLoader, **kwargs):
        r"""
        Calls `self.train_fn`.
        """
        return self.train_fn(
            self,
            dataloader=train_data,
            construct_kwargs=self.construct_kwargs,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        assert self.linear is not None
        if self.norm is not None:
            x = self.norm(x)
        return self.linear(x)

    def representation(self) -> Tensor:
        r"""
        Returns a tensor which describes the hyperplane of the input space.
        This does not include the bias. For the bias/intercept, please use
        `self.bias`.
        """
        assert self.linear is not None
        return self.linear.weight.detach()

    def bias(self) -> Optional[Tensor]:
        r"""
        Returns the bias of the linear model.
        """
        if self.linear is None or self.linear.bias is None:
            return None
        return self.linear.bias.detach()

    def classes(self) -> Optional[Tensor]:
        # `classes` is only set on the linear layer for classification models;
        # use getattr so that regression models return None instead of raising
        # an AttributeError.
        if self.linear is None or getattr(self.linear, "classes", None) is None:
            return None
        return cast(Tensor, self.linear.classes).detach()


class SGDLinearModel(LinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Constructs a `LinearModel` with
        `sgd_train_linear_model` as the train method.

        Args:
            kwargs:
                Arguments sent to `self._construct_model_params` after
                `self.fit` is called. Please refer to that method for
                parameter documentation.
        """
        # avoid cycles
        from captum._utils.models.linear_model.train import sgd_train_linear_model

        super().__init__(train_fn=sgd_train_linear_model, **kwargs)


class SGDLasso(SGDLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class to train a `LinearModel` with SGD
        (`sgd_train_linear_model`) whilst setting appropriate parameters to
        optimize for lasso regression loss. This optimizes L2 loss + alpha *
        L1 regularization.

        Please note that with SGD it is not guaranteed that weights will
        converge to 0.
        """
        super().__init__(**kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        # avoid cycles
        from captum._utils.models.linear_model.train import l2_loss

        return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=1, **kwargs)


class SGDRidge(SGDLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class to train a `LinearModel` with SGD
        (`sgd_train_linear_model`) whilst setting appropriate parameters to
        optimize for ridge regression loss. This optimizes L2 loss + alpha *
        L2 regularization.
        """
        super().__init__(**kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        # avoid cycles
        from captum._utils.models.linear_model.train import l2_loss

        return super().fit(train_data=train_data, loss_fn=l2_loss, reg_term=2, **kwargs)


class SGDLinearRegression(SGDLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class to train a `LinearModel` with SGD
        (`sgd_train_linear_model`). For linear regression this uses an L2
        loss with no regularization.
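
        Example (a minimal sketch; the synthetic data below is illustrative,
        and the input/output sizes are inferred from the dataloader)::

            import torch
            from torch.utils.data import DataLoader, TensorDataset

            x = torch.randn(128, 4)
            y = x @ torch.randn(4, 1)
            data = DataLoader(TensorDataset(x, y), batch_size=32)

            model = SGDLinearRegression()
            model.fit(data)
            weights = model.representation()  # shape (1, 4)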
| """ | |
| super().__init__(**kwargs) | |
| def fit(self, train_data: DataLoader, **kwargs): | |
| # avoid cycles | |
| from captum._utils.models.linear_model.train import l2_loss | |
| return super().fit( | |
| train_data=train_data, loss_fn=l2_loss, reg_term=None, **kwargs | |
| ) | |


class SkLearnLinearModel(LinearModel):
    def __init__(self, sklearn_module: str, **kwargs) -> None:
        r"""
        Factory class to construct a `LinearModel` with an sklearn training
        method.

        Please note that this assumes:

        0. You have sklearn and numpy installed
        1. The dataset can fit into memory

        SkLearn support does introduce some slight overhead as we convert the
        tensors to numpy and then convert the resulting trained model to a
        `LinearModel` object. However, this conversion should be negligible.

        Args:
            sklearn_module:
                The module under sklearn to construct and use for training,
                e.g. use "svm.LinearSVC" for an SVM or "linear_model.Lasso"
                for Lasso. There are factory classes defined for you for
                common use cases, such as `SkLearnLasso`.
            kwargs:
                The kwargs to pass to the construction of the sklearn model.
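
        Example (a hypothetical sketch; the data below is illustrative)::

            import torch
            from torch.utils.data import DataLoader, TensorDataset

            x, y = torch.randn(64, 3), torch.randn(64)
            dataloader = DataLoader(TensorDataset(x, y), batch_size=64)

            # equivalent to SkLearnLasso(alpha=0.1)
            model = SkLearnLinearModel("linear_model.Lasso", alpha=0.1)
            model.fit(dataloader)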
| """ | |
| # avoid cycles | |
| from captum._utils.models.linear_model.train import sklearn_train_linear_model | |
| super().__init__(train_fn=sklearn_train_linear_model, **kwargs) | |
| self.sklearn_module = sklearn_module | |
| def fit(self, train_data: DataLoader, **kwargs): | |
| r""" | |
| Args: | |
| train_data | |
| Train data to use | |
| kwargs | |
| Arguments to feed to `.fit` method for sklearn | |
| """ | |
| return super().fit( | |
| train_data=train_data, sklearn_trainer=self.sklearn_module, **kwargs | |
| ) | |


class SkLearnLasso(SkLearnLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Trains a `LinearModel` with
        `sklearn.linear_model.Lasso`. You will need sklearn version >= 0.23
        to support sample weights.
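
        Example (a minimal sketch; the alpha value below is illustrative, and
        `dataloader` is any DataLoader yielding `(x, y)` batches)::

            model = SkLearnLasso(alpha=0.01)
            model.fit(dataloader)
            coefs = model.representation()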
| """ | |
| super().__init__(sklearn_module="linear_model.Lasso", **kwargs) | |
| def fit(self, train_data: DataLoader, **kwargs): | |
| return super().fit(train_data=train_data, **kwargs) | |


class SkLearnRidge(SkLearnLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Trains a model with `sklearn.linear_model.Ridge`.

        Any arguments provided to the sklearn constructor can be provided
        as kwargs here.
        """
        super().__init__(sklearn_module="linear_model.Ridge", **kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        return super().fit(train_data=train_data, **kwargs)


class SkLearnLinearRegression(SkLearnLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Trains a model with
        `sklearn.linear_model.LinearRegression`.

        Any arguments provided to the sklearn constructor can be provided
        as kwargs here.
        """
        super().__init__(sklearn_module="linear_model.LinearRegression", **kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        return super().fit(train_data=train_data, **kwargs)


class SkLearnLogisticRegression(SkLearnLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Trains a model with
        `sklearn.linear_model.LogisticRegression`.

        Any arguments provided to the sklearn constructor can be provided
        as kwargs here.
        """
        super().__init__(sklearn_module="linear_model.LogisticRegression", **kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        return super().fit(train_data=train_data, **kwargs)


class SkLearnSGDClassifier(SkLearnLinearModel):
    def __init__(self, **kwargs) -> None:
        r"""
        Factory class. Trains a model with
        `sklearn.linear_model.SGDClassifier`.

        Any arguments provided to the sklearn constructor can be provided
        as kwargs here.
        """
        super().__init__(sklearn_module="linear_model.SGDClassifier", **kwargs)

    def fit(self, train_data: DataLoader, **kwargs):
        return super().fit(train_data=train_data, **kwargs)