# NOTE(review): removed non-Python web-page artifacts (Spaces UI chrome,
# file-size line, commit hash, and line-number gutter) captured during
# extraction — they were not part of the module and broke parsing.
""" nn_utils.py
"""
import math
import copy
import torch
import torch.nn as nn
def build_lr_scheduler(
    optimizer, lr_decay_rate: float, decay_steps: int = 5000, warmup: int = 100
):
    """Build a LambdaLR schedule with exponential warmup then step decay.

    For the first ``warmup`` steps the learning-rate multiplier ramps up as
    ``1 - exp(-step / warmup)``; from step ``warmup`` onward it decays by a
    factor of ``lr_decay_rate`` once every ``decay_steps`` steps.

    Args:
        optimizer: torch optimizer whose learning rate is scheduled.
        lr_decay_rate (float): multiplicative decay factor per decay period.
        decay_steps (int): number of steps between successive decays.
        warmup (int): number of warmup steps.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: the configured scheduler.
    """

    def lr_lambda(step):
        if step >= warmup:
            # Shift so the decay counter starts at the end of warmup.
            step = step - warmup
            rate = lr_decay_rate ** (step // decay_steps)
        else:
            # Exponential ramp: 0 at step 0, approaching 1 - 1/e at `warmup`.
            rate = 1 - math.exp(-step / warmup)
        return rate

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return scheduler
class MLPBlocks(nn.Module):
    """Stack of linear layers with dropout, ReLU, and residual connections.

    The input is projected to ``hidden_size``; each subsequent hidden layer
    is followed by dropout and a ReLU whose output is added to the previous
    block's output (residual connection).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float,
        num_layers: int,
    ):
        """
        Args:
            input_size (int): dimensionality of the input features.
            hidden_size (int): width of every hidden layer.
            dropout (float): dropout probability applied after each linear layer.
            num_layers (int): total number of linear layers, input layer included
                (so ``num_layers - 1`` hidden layers are created).
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.input_layer = nn.Linear(input_size, hidden_size)
        # Deep-copied so each hidden layer has its own parameters;
        # they all start from the same initial weights.
        middle_layer = nn.Linear(hidden_size, hidden_size)
        self.layers = nn.ModuleList(
            [copy.deepcopy(middle_layer) for _ in range(num_layers - 1)]
        )

    def forward(self, x):
        """Apply the MLP stack.

        Args:
            x: tensor of shape ``(..., input_size)``.

        Returns:
            Tensor of shape ``(..., hidden_size)``.
        """
        output = self.input_layer(x)
        output = self.dropout_layer(output)
        output = self.activation(output)
        old_output = output
        for layer in self.layers:
            output = layer(output)
            output = self.dropout_layer(output)
            # Residual connection to the previous block's output.
            output = self.activation(output) + old_output
            old_output = output
        return output
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    copies = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(copies)
def pad_packed_tensor(input, lengths, value):
    """Unpack a packed (concatenated) batch into a padded dense tensor.

    ``input`` holds ``sum(lengths)`` rows: the sequences of a batch laid out
    back to back along dim 0. The result is a dense tensor in which sequence
    ``b`` occupies ``out[b, :lengths[b]]`` and every remaining slot is filled
    with ``value``.

    Args:
        input: packed tensor of shape ``(sum(lengths), *feature_dims)``.
        lengths: per-sequence lengths (python sequence or 1-D int tensor).
        value: fill value for the padded positions.

    Returns:
        Tensor of shape ``(batch_size, max_len, *feature_dims)``.
    """
    old_shape = input.shape
    device = input.device
    if not isinstance(lengths, torch.Tensor):
        lengths = torch.tensor(lengths, dtype=torch.int64, device=device)
    else:
        lengths = lengths.to(device)
    max_len = (lengths.max()).item()
    batch_size = len(lengths)
    # Flat output buffer pre-filled with the pad value. `new_full` replaces
    # the legacy `input.new(...)` + `fill_` pair with identical semantics.
    x = input.new_full((batch_size * max_len, *old_shape[1:]), value)
    # Initialize a tensor with an index for every value in the array
    index = torch.ones(len(input), dtype=torch.int64, device=device)
    # Row shifts: cumulative padding inserted before each row's end in the
    # flat buffer.
    row_shifts = torch.cumsum(max_len - lengths, 0)
    # Calculate shifts for second row, third row... nth row (not the n+1th row)
    # Expand this out to match the shape of all entries after the first row
    row_shifts_expanded = row_shifts[:-1].repeat_interleave(lengths[1:])
    # Add this to the list of inds _after_ the first row
    cumsum_inds = torch.cumsum(index, 0) - 1
    cumsum_inds[lengths[0] :] += row_shifts_expanded
    # Scatter the packed rows into their padded positions.
    x[cumsum_inds] = input
    return x.view(batch_size, max_len, *old_shape[1:])
# NOTE(review): trailing "|" gutter artifact from page extraction removed.