import time
import warnings
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
from captum._utils.models.linear_model.model import LinearModel
from torch.utils.data import DataLoader


def l2_loss(x1, x2, weights=None):
    if weights is None:
        return torch.mean((x1 - x2) ** 2) / 2.0
    else:
        return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
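
# A quick sanity check of the weighting (hypothetical values, not part of the
# module): with `weights`, each squared error is scaled by its L1-normalized
# weight before the sum, e.g.
#
#     l2_loss(torch.zeros(2), torch.tensor([2.0, 0.0]))
#     # -> mean([4., 0.]) / 2 = 1.0
#     l2_loss(torch.zeros(2), torch.tensor([2.0, 0.0]), torch.tensor([3.0, 1.0]))
#     # -> sum([0.75 * 4., 0.25 * 0.]) / 2 = 1.5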


def sgd_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    max_epoch: int = 100,
    reduce_lr: bool = True,
    initial_lr: float = 0.01,
    alpha: float = 1.0,
    loss_fn: Callable = l2_loss,
    reg_term: Optional[int] = 1,
    patience: int = 10,
    threshold: float = 1e-4,
    running_loss_window: Optional[int] = None,
    device: Optional[str] = None,
    init_scheme: str = "zeros",
    debug: bool = False,
) -> Dict[str, float]:
| r""" | |
| Trains a linear model with SGD. This will continue to iterate your | |
| dataloader until we converged to a solution or alternatively until we have | |
| exhausted `max_epoch`. | |
| Convergence is defined by the loss not changing by `threshold` amount for | |
| `patience` number of iterations. | |
| Args: | |
| model | |
| The model to train | |
| dataloader | |
| The data to train it with. We will assume the dataloader produces | |
| either pairs or triples of the form (x, y) or (x, y, w). Where x and | |
| y are typical pairs for supervised learning and w is a weight | |
| vector. | |
| We will call `model._construct_model_params` with construct_kwargs | |
| and the input features set to `x.shape[1]` (`x.shape[0]` corresponds | |
| to the batch size). We assume that `len(x.shape) == 2`, i.e. the | |
| tensor is flat. The number of output features will be set to | |
| y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape) | |
| <= 2`. | |
| max_epoch | |
| The maximum number of epochs to exhaust | |
| reduce_lr | |
| Whether or not to reduce the learning rate as iterations progress. | |
| Halves the learning rate when the training loss does not move. This | |
| uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the | |
| parameters `patience` and `threshold` | |
| initial_lr | |
| The initial learning rate to use. | |
| alpha | |
| A constant for the regularization term. | |
| loss_fn | |
| The loss to optimise for. This must accept three parameters: | |
| x1 (predicted), x2 (labels) and a weight vector | |
| reg_term | |
| Regularization is defined by the `reg_term` norm of the weights. | |
| Please use `None` if you do not wish to use regularization. | |
| patience | |
| Defines the number of iterations in a row the loss must remain | |
| within `threshold` in order to be classified as converged. | |
| threshold | |
| Threshold for convergence detection. | |
| running_loss_window | |
| Used to report the training loss once we have finished training and | |
| to determine when we have converged (along with reducing the | |
| learning rate). | |
| The reported training loss will take the last `running_loss_window` | |
| iterations and average them. | |
| If `None` we will approximate this to be the number of examples in | |
| an epoch. | |
| init_scheme | |
| Initialization to use prior to training the linear model. | |
| device | |
| The device to send the model and data to. If None then no `.to` call | |
| will be used. | |
| debug | |
| Whether to print the loss, learning rate per iteration | |
| Returns | |
| This will return the final training loss (averaged with | |
| `running_loss_window`) | |
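
    Example (a minimal sketch; `model` is assumed to be a `LinearModel`
    subclass instance and the tensors are hypothetical)::

        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> xs, ys = torch.randn(100, 5), torch.randn(100, 1)
        >>> loader = DataLoader(TensorDataset(xs, ys), batch_size=10)
        >>> stats = sgd_train_linear_model(model, loader, construct_kwargs={})
        >>> stats["train_loss"]  # averaged over the running loss window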
| """ | |
    loss_window: List[torch.Tensor] = []
    min_avg_loss = None
    convergence_counter = 0
    converged = False

    def get_point(datapoint):
        if len(datapoint) == 2:
            x, y = datapoint
            w = None
        else:
            x, y, w = datapoint

        if device is not None:
            x = x.to(device)
            y = y.to(device)
            if w is not None:
                w = w.to(device)

        return x, y, w

    # get a point and construct the model
    data_iter = iter(dataloader)
    x, y, w = get_point(next(data_iter))

    model._construct_model_params(
        in_features=x.shape[1],
        out_features=y.shape[1] if len(y.shape) == 2 else 1,
        **construct_kwargs,
    )
    model.train()

    assert model.linear is not None

    if init_scheme is not None:
        assert init_scheme in ["xavier", "zeros"]

        with torch.no_grad():
            if init_scheme == "xavier":
                torch.nn.init.xavier_uniform_(model.linear.weight)
            else:
                model.linear.weight.zero_()

            if model.linear.bias is not None:
                model.linear.bias.zero_()

    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
    if reduce_lr:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=patience, threshold=threshold
        )
    else:
        # keep the name bound so the `if scheduler:` check in the training
        # loop cannot raise a NameError when reduce_lr is False
        scheduler = None

    t1 = time.time()
    epoch = 0
    i = 0
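
    # Iterate the dataloader manually (rather than `for x, y, w in dataloader`)
    # so the first batch, already consumed above to size the model, is still
    # trained on before the next one is fetched.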
    while epoch < max_epoch:
        while True:  # for x, y, w in dataloader
            if running_loss_window is None:
                running_loss_window = x.shape[0] * len(dataloader)

            y = y.view(x.shape[0], -1)
            if w is not None:
                w = w.view(x.shape[0], -1)

            i += 1

            out = model(x)

            loss = loss_fn(y, out, w)
            if reg_term is not None:
                reg = torch.norm(model.linear.weight, p=reg_term)
                loss += reg.sum() * alpha

            if len(loss_window) >= running_loss_window:
                loss_window = loss_window[1:]
            loss_window.append(loss.clone().detach())
            assert len(loss_window) <= running_loss_window
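
            # Convergence bookkeeping: average the sliding window of recent
            # losses and count consecutive iterations that fail to improve on
            # the best average by more than `threshold`.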
            average_loss = torch.mean(torch.stack(loss_window))
            if min_avg_loss is not None:
                # if we haven't improved by at least `threshold`
                if average_loss > min_avg_loss or torch.isclose(
                    min_avg_loss, average_loss, atol=threshold
                ):
                    convergence_counter += 1
                    if convergence_counter >= patience:
                        converged = True
                        break
                else:
                    convergence_counter = 0

            if min_avg_loss is None or min_avg_loss >= average_loss:
                min_avg_loss = average_loss.clone()

            if debug:
                print(
                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
                    f" Aloss={average_loss}, min_avg_loss={min_avg_loss}"
                )

            loss.backward()
            optim.step()
            model.zero_grad()
            if scheduler:
                scheduler.step(average_loss)

            temp = next(data_iter, None)
            if temp is None:
                break

            x, y, w = get_point(temp)

        if converged:
            break

        epoch += 1
        data_iter = iter(dataloader)
        x, y, w = get_point(next(data_iter))

    t2 = time.time()
    return {
        "train_time": t2 - t1,
        "train_loss": torch.mean(torch.stack(loss_window)).item(),
        "train_iter": i,
        "train_epoch": epoch,
    }


class NormLayer(nn.Module):
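    # Fixed-statistics normalization: standardizes x with the given mean/std;
    # `eps` guards against zero entries in `std` (the `n` argument is unused
    # here).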
    def __init__(self, mean, std, n=None, eps=1e-8) -> None:
        super().__init__()
        self.mean = mean
        self.std = std
        self.eps = eps

    def forward(self, x):
        return (x - self.mean) / (self.std + self.eps)


def sklearn_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    sklearn_trainer: str = "linear_model.Lasso",
    norm_input: bool = False,
    **fit_kwargs,
):
| r""" | |
| Alternative method to train with sklearn. This does introduce some slight | |
| overhead as we convert the tensors to numpy and then convert the resulting | |
| trained model to a `LinearModel` object. However, this conversion | |
| should be negligible. | |
| Please note that this assumes: | |
| 0. You have sklearn and numpy installed | |
| 1. The dataset can fit into memory | |
| Args | |
| model | |
| The model to train. | |
| dataloader | |
| The data to use. This will be exhausted and converted to numpy | |
| arrays. Therefore please do not feed an infinite dataloader. | |
| norm_input | |
| Whether or not to normalize the input | |
| sklearn_trainer | |
| The sklearn model to use to train the model. Please refer to | |
| sklearn.linear_model for a list of modules to use. | |
| construct_kwargs | |
| Additional arguments provided to the `sklearn_trainer` constructor | |
| fit_kwargs | |
| Other arguments to send to `sklearn_trainer`'s `.fit` method | |
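
    Example (a minimal sketch; `model` and `loader` are a hypothetical
    `LinearModel` subclass instance and a finite DataLoader)::

        >>> stats = sklearn_train_linear_model(
        ...     model,
        ...     loader,
        ...     construct_kwargs={"alpha": 0.1},
        ...     sklearn_trainer="linear_model.Lasso",
        ... )
        >>> stats["train_time"]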
| """ | |
    from functools import reduce

    try:
        import numpy as np
    except ImportError:
        raise ValueError("numpy is not available. Please install numpy.")

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm
    except ImportError:
        raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

    if not sklearn.__version__ >= "0.23.0":
        warnings.warn(
            "Must have sklearn version 0.23.0 or higher to use "
            "sample_weight in Lasso regression."
        )

    num_batches = 0
    xs, ys, ws = [], [], []
    for data in dataloader:
        if len(data) == 3:
            x, y, w = data
        else:
            assert len(data) == 2
            x, y = data
            w = None

        xs.append(x.cpu().numpy())
        ys.append(y.cpu().numpy())
        if w is not None:
            ws.append(w.cpu().numpy())

        num_batches += 1

    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    if len(ws) > 0:
        w = np.concatenate(ws, axis=0)
    else:
        w = None

    if norm_input:
        mean, std = x.mean(0), x.std(0)
        x -= mean
        x /= std

    t1 = time.time()
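    # Resolve the dotted path relative to the `sklearn` package (e.g.
    # "linear_model.Lasso" -> sklearn.linear_model.Lasso) and construct it.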
    sklearn_model = reduce(
        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
    )(**construct_kwargs)
    try:
        sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
    except TypeError:
        sklearn_model.fit(x, y, **fit_kwargs)
        warnings.warn(
            "Sample weight is not supported for the provided linear model!"
            " Trained model without weighting inputs. For Lasso, please"
            " upgrade sklearn to a version >= 0.23.0."
        )
    t2 = time.time()

    # Convert weights to pytorch
    classes = (
        torch.IntTensor(sklearn_model.classes_)  # type: ignore
        if hasattr(sklearn_model, "classes_")
        else None
    )

    # extract model device
    device = model.device if hasattr(model, "device") else "cpu"

    num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
    weight_values = torch.FloatTensor(sklearn_model.coef_).to(device)  # type: ignore
    bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore
        device
    )
    model._construct_model_params(
        norm_type=None,
        weight_values=weight_values.view(num_outputs, -1),
        bias_value=bias_values.squeeze().unsqueeze(0),
        classes=classes,
    )

    if norm_input:
        model.norm = NormLayer(mean, std)

    return {"train_time": t2 - t1}