Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import TensorDataset, DataLoader | |
| import numpy as np | |
| from .models.encoder import TSEncoder | |
| from .models.losses import hierarchical_contrastive_loss | |
| from .utils import ( | |
| take_per_row, | |
| split_with_nan, | |
| centerize_vary_length_series, | |
| torch_pad_nan, | |
| ) | |
class TS2Vec:
    """The TS2Vec model."""

    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device="cuda",
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None,
    ):
        """Initialize a TS2Vec model.

        Args:
            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (str or torch.device): The device used for training and inference.
            lr (float): The learning rate.
            batch_size (int): The batch size.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. For sequence with a length greater than <max_train_length>, it would be cropped into some sequences, each of which has a length less than <max_train_length>.
            temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
            after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration.
            after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch.
        """
        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit

        self._net = TSEncoder(
            input_dims=input_dims,
            output_dims=output_dims,
            hidden_dims=hidden_dims,
            depth=depth,
        ).to(self.device)
        # `self.net` is an SWA (parameter-averaged) copy of `self._net`; it is
        # the encoder used for inference, so sync it with the initial weights.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback

        # Running counters so fit() can be resumed across calls.
        self.n_epochs = 0
        self.n_iters = 0
| def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): | |
| """Training the TS2Vec model. | |
| Args: | |
| train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN. | |
| n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops. | |
| n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise. | |
| verbose (bool): Whether to print the training loss after each epoch. | |
| Returns: | |
| loss_log: a list containing the training losses on each epoch. | |
| """ | |
| assert train_data.ndim == 3 | |
| if n_iters is None and n_epochs is None: | |
| n_iters = ( | |
| 200 if train_data.size <= 100000 else 600 | |
| ) # default param for n_iters | |
| if self.max_train_length is not None: | |
| sections = train_data.shape[1] // self.max_train_length | |
| if sections >= 2: | |
| train_data = np.concatenate( | |
| split_with_nan(train_data, sections, axis=1), axis=0 | |
| ) | |
| temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0) | |
| if temporal_missing[0] or temporal_missing[-1]: | |
| train_data = centerize_vary_length_series(train_data) | |
| train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)] | |
| train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float)) | |
| train_loader = DataLoader( | |
| train_dataset, | |
| batch_size=min(self.batch_size, len(train_dataset)), | |
| shuffle=True, | |
| drop_last=True, | |
| ) | |
| optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr) | |
| loss_log = [] | |
| while True: | |
| if n_epochs is not None and self.n_epochs >= n_epochs: | |
| break | |
| cum_loss = 0 | |
| n_epoch_iters = 0 | |
| interrupted = False | |
| for batch in train_loader: | |
| if n_iters is not None and self.n_iters >= n_iters: | |
| interrupted = True | |
| break | |
| x = batch[0] | |
| if ( | |
| self.max_train_length is not None | |
| and x.size(1) > self.max_train_length | |
| ): | |
| window_offset = np.random.randint( | |
| x.size(1) - self.max_train_length + 1 | |
| ) | |
| x = x[:, window_offset : window_offset + self.max_train_length] | |
| x = x.to(self.device) | |
| ts_l = x.size(1) | |
| crop_l = np.random.randint( | |
| low=2 ** (self.temporal_unit + 1), high=ts_l + 1 | |
| ) | |
| crop_left = np.random.randint(ts_l - crop_l + 1) | |
| crop_right = crop_left + crop_l | |
| crop_eleft = np.random.randint(crop_left + 1) | |
| crop_eright = np.random.randint(low=crop_right, high=ts_l + 1) | |
| crop_offset = np.random.randint( | |
| low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0) | |
| ) | |
| optimizer.zero_grad() | |
| out1 = self._net( | |
| take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft) | |
| ) | |
| out1 = out1[:, -crop_l:] | |
| out2 = self._net( | |
| take_per_row(x, crop_offset + crop_left, crop_eright - crop_left) | |
| ) | |
| out2 = out2[:, :crop_l] | |
| loss = hierarchical_contrastive_loss( | |
| out1, out2, temporal_unit=self.temporal_unit | |
| ) | |
| loss.backward() | |
| optimizer.step() | |
| self.net.update_parameters(self._net) | |
| cum_loss += loss.item() | |
| n_epoch_iters += 1 | |
| self.n_iters += 1 | |
| if self.after_iter_callback is not None: | |
| self.after_iter_callback(self, loss.item()) | |
| if interrupted: | |
| break | |
| cum_loss /= n_epoch_iters | |
| loss_log.append(cum_loss) | |
| if verbose: | |
| print(f"Epoch #{self.n_epochs}: loss={cum_loss}") | |
| self.n_epochs += 1 | |
| if self.after_epoch_callback is not None: | |
| self.after_epoch_callback(self, cum_loss) | |
| return loss_log | |
| def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None): | |
| out = self.net(x.to(self.device, non_blocking=True), mask) | |
| if encoding_window == "full_series": | |
| if slicing is not None: | |
| out = out[:, slicing] | |
| out = F.max_pool1d( | |
| out.transpose(1, 2), | |
| kernel_size=out.size(1), | |
| ).transpose(1, 2) | |
| elif isinstance(encoding_window, int): | |
| out = F.max_pool1d( | |
| out.transpose(1, 2), | |
| kernel_size=encoding_window, | |
| stride=1, | |
| padding=encoding_window // 2, | |
| ).transpose(1, 2) | |
| if encoding_window % 2 == 0: | |
| out = out[:, :-1] | |
| if slicing is not None: | |
| out = out[:, slicing] | |
| elif encoding_window == "multiscale": | |
| p = 0 | |
| reprs = [] | |
| while (1 << p) + 1 < out.size(1): | |
| t_out = F.max_pool1d( | |
| out.transpose(1, 2), | |
| kernel_size=(1 << (p + 1)) + 1, | |
| stride=1, | |
| padding=1 << p, | |
| ).transpose(1, 2) | |
| if slicing is not None: | |
| t_out = t_out[:, slicing] | |
| reprs.append(t_out) | |
| p += 1 | |
| out = torch.cat(reprs, dim=-1) | |
| else: | |
| if slicing is not None: | |
| out = out[:, slicing] | |
| return out.cpu() | |
| def encode( | |
| self, | |
| data, | |
| mask=None, | |
| encoding_window=None, | |
| casual=False, | |
| sliding_length=None, | |
| sliding_padding=0, | |
| batch_size=None, | |
| ): | |
| """Compute representations using the model. | |
| Args: | |
| data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN. | |
| mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'. | |
| encoding_window (Union[str, int]): When this param is specified, the computed representation would the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size. | |
| casual (bool): When this param is set to True, the future informations would not be encoded into representation of each timestamp. | |
| sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series. | |
| sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows. | |
| batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training. | |
| Returns: | |
| repr: The representations for data. | |
| """ | |
| assert self.net is not None, "please train or load a net first" | |
| assert data.ndim == 3 | |
| if batch_size is None: | |
| batch_size = self.batch_size | |
| n_samples, ts_l, _ = data.shape | |
| org_training = self.net.training | |
| self.net.eval() | |
| dataset = TensorDataset(torch.from_numpy(data).to(torch.float)) | |
| loader = DataLoader(dataset, batch_size=batch_size) | |
| with torch.no_grad(): | |
| output = [] | |
| for batch in loader: | |
| x = batch[0] | |
| if sliding_length is not None: | |
| reprs = [] | |
| if n_samples < batch_size: | |
| calc_buffer = [] | |
| calc_buffer_l = 0 | |
| for i in range(0, ts_l, sliding_length): | |
| l = i - sliding_padding | |
| r = i + sliding_length + (sliding_padding if not casual else 0) | |
| x_sliding = torch_pad_nan( | |
| x[:, max(l, 0) : min(r, ts_l)], | |
| left=-l if l < 0 else 0, | |
| right=r - ts_l if r > ts_l else 0, | |
| dim=1, | |
| ) | |
| if n_samples < batch_size: | |
| if calc_buffer_l + n_samples > batch_size: | |
| out = self._eval_with_pooling( | |
| torch.cat(calc_buffer, dim=0), | |
| mask, | |
| slicing=slice( | |
| sliding_padding, | |
| sliding_padding + sliding_length, | |
| ), | |
| encoding_window=encoding_window, | |
| ) | |
| reprs += torch.split(out, n_samples) | |
| calc_buffer = [] | |
| calc_buffer_l = 0 | |
| calc_buffer.append(x_sliding) | |
| calc_buffer_l += n_samples | |
| else: | |
| out = self._eval_with_pooling( | |
| x_sliding, | |
| mask, | |
| slicing=slice( | |
| sliding_padding, sliding_padding + sliding_length | |
| ), | |
| encoding_window=encoding_window, | |
| ) | |
| reprs.append(out) | |
| if n_samples < batch_size: | |
| if calc_buffer_l > 0: | |
| out = self._eval_with_pooling( | |
| torch.cat(calc_buffer, dim=0), | |
| mask, | |
| slicing=slice( | |
| sliding_padding, sliding_padding + sliding_length | |
| ), | |
| encoding_window=encoding_window, | |
| ) | |
| reprs += torch.split(out, n_samples) | |
| calc_buffer = [] | |
| calc_buffer_l = 0 | |
| out = torch.cat(reprs, dim=1) | |
| if encoding_window == "full_series": | |
| out = F.max_pool1d( | |
| out.transpose(1, 2).contiguous(), | |
| kernel_size=out.size(1), | |
| ).squeeze(1) | |
| else: | |
| out = self._eval_with_pooling( | |
| x, mask, encoding_window=encoding_window | |
| ) | |
| if encoding_window == "full_series": | |
| out = out.squeeze(1) | |
| output.append(out) | |
| output = torch.cat(output, dim=0) | |
| self.net.train(org_training) | |
| return output.numpy() | |
| def save(self, fn): | |
| """Save the model to a file. | |
| Args: | |
| fn (str): filename. | |
| """ | |
| torch.save(self.net.state_dict(), fn) | |
| def load(self, fn): | |
| """Load the model from a file. | |
| Args: | |
| fn (str): filename. | |
| """ | |
| state_dict = torch.load(fn, map_location=self.device) | |
| self.net.load_state_dict(state_dict) | |