Spaces:
Sleeping
Sleeping
| """ nn_utils.py | |
| """ | |
| import math | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
def build_lr_scheduler(
    optimizer, lr_decay_rate: float, decay_steps: int = 5000, warmup: int = 100
):
    """Build a LambdaLR schedule with exponential warmup and stepwise decay.

    For steps < ``warmup`` the LR multiplier is ``1 - exp(-step / warmup)``
    (0 at step 0, rising toward ~0.63 by the end of warmup); afterwards it
    is ``lr_decay_rate ** ((step - warmup) // decay_steps)``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        lr_decay_rate (float): Multiplicative decay factor applied once per
            ``decay_steps`` steps after warmup ends.
        decay_steps (int): Number of steps between successive decays.
        warmup (int): Number of warmup steps.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: the configured scheduler.
    """

    def lr_lambda(step):
        if step >= warmup:
            # Shift so the decay counter starts at the end of warmup.
            step = step - warmup
            rate = lr_decay_rate ** (step // decay_steps)
        else:
            # Exponential ramp-up; evaluates to 0 at step 0.
            rate = 1 - math.exp(-step / warmup)
        return rate

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return scheduler
class MLPBlocks(nn.Module):
    """MLP stack with residual connections between successive hidden layers.

    The input is projected to ``hidden_size`` by an input layer; each of the
    ``num_layers - 1`` hidden layers then computes
    ``activation(dropout(linear(h))) + h_prev``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float,
        num_layers: int,
    ):
        """
        Args:
            input_size (int): Size of the input features.
            hidden_size (int): Width of every hidden layer.
            dropout (float): Dropout probability applied after each linear.
            num_layers (int): Total layer count; ``num_layers - 1`` cloned
                hidden layers are created after the input projection.
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.input_layer = nn.Linear(input_size, hidden_size)
        middle_layer = nn.Linear(hidden_size, hidden_size)
        self.layers = get_clones(middle_layer, num_layers - 1)

    def forward(self, x):
        """Run the stack; returns a tensor with last dim ``hidden_size``."""
        # NOTE: removed dead `output = x` — it was immediately overwritten.
        output = self.input_layer(x)
        output = self.dropout_layer(output)
        output = self.activation(output)
        old_output = output
        for layer in self.layers:
            output = layer(output)
            output = self.dropout_layer(output)
            # Residual connection to the previous layer's output.
            output = self.activation(output) + old_output
            old_output = output
        return output
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of *module*."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
def pad_packed_tensor(input, lengths, value):
    """Unpack a packed (concatenated) tensor into a padded batch.

    Args:
        input: Tensor of shape ``(sum(lengths), *trailing)`` holding the
            sequences concatenated along dim 0.
        lengths: Per-sequence lengths; a Python sequence or a 1-D tensor.
        value: Fill value for the padding positions.

    Returns:
        Tensor of shape ``(len(lengths), max(lengths), *trailing)`` with each
        sequence left-aligned in its row and the remainder set to ``value``.
    """
    trailing = input.shape[1:]
    device = input.device
    if isinstance(lengths, torch.Tensor):
        lengths = lengths.to(device)
    else:
        lengths = torch.tensor(lengths, dtype=torch.int64, device=device)
    batch_size = len(lengths)
    max_len = int(lengths.max().item())
    # new_full replaces the legacy `input.new(...)` + `fill_` pair and keeps
    # input's dtype/device.
    out = input.new_full((batch_size, max_len, *trailing), value)
    # Boolean mask marking the first ``lengths[i]`` slots of row i; masked
    # assignment fills positions in row-major order, which matches the packed
    # layout of `input` — this replaces the original hand-rolled
    # cumulative-shift index arithmetic.
    valid = torch.arange(max_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)
    out[valid] = input
    return out