# CSU-MS2-T2 / nn_utils / nn_utils.py
# Tingxie's picture
# Upload 10 files
# c8bfe50
""" nn_utils.py
"""
import math
import copy
import torch
import torch.nn as nn
def build_lr_scheduler(
    optimizer, lr_decay_rate: float, decay_steps: int = 5000, warmup: int = 100
):
    """Build a ``LambdaLR`` scheduler with exponential warmup and step decay.

    While ``step < warmup`` the learning-rate multiplier ramps up as
    ``1 - exp(-step / warmup)`` (note: this is exactly 0 at step 0). From
    ``step >= warmup`` onward, the multiplier is
    ``lr_decay_rate ** ((step - warmup) // decay_steps)``, i.e. it decays by
    ``lr_decay_rate`` once every ``decay_steps`` steps.

    Args:
        optimizer: torch optimizer whose learning rate is scheduled.
        lr_decay_rate (float): multiplicative decay factor applied every
            ``decay_steps`` steps after warmup.
        decay_steps (int): number of steps between successive decays.
        warmup (int): number of warmup steps.

    Returns:
        torch.optim.lr_scheduler.LambdaLR: the configured scheduler.
    """

    def lr_lambda(step):
        if step >= warmup:
            # Count decay intervals from the end of warmup.
            step = step - warmup
            return lr_decay_rate ** (step // decay_steps)
        # Exponential ramp toward 1 during warmup.
        return 1 - math.exp(-step / warmup)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
class MLPBlocks(nn.Module):
    """Stack of residually-connected MLP layers.

    An input projection (``input_size -> hidden_size``) followed by
    ``num_layers - 1`` hidden layers (``hidden_size -> hidden_size``), each
    applying dropout, ReLU, and a residual connection to the previous
    layer's output.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float,
        num_layers: int,
    ):
        """
        Args:
            input_size (int): dimensionality of the input features.
            hidden_size (int): dimensionality of every hidden layer.
            dropout (float): dropout probability applied after each linear.
            num_layers (int): total number of linear layers (the input
                projection plus ``num_layers - 1`` hidden layers).
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.input_layer = nn.Linear(input_size, hidden_size)
        # get_clones deep-copies one layer, so every hidden layer starts with
        # identical initial weights (preserved original behavior).
        middle_layer = nn.Linear(hidden_size, hidden_size)
        self.layers = get_clones(middle_layer, num_layers - 1)

    def forward(self, x):
        """Apply the MLP stack; returns a tensor with last dim hidden_size."""
        output = self.activation(self.dropout_layer(self.input_layer(x)))
        old_output = output
        for layer in self.layers:
            # Residual connection: activated layer output plus previous output.
            output = self.activation(self.dropout_layer(layer(output))) + old_output
            old_output = output
        return output
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(clones)
def pad_packed_tensor(input, lengths, value):
    """Unpack a packed (row-concatenated) tensor into a padded batch tensor.

    Args:
        input: tensor of shape ``(sum(lengths), *rest)`` holding the rows of
            all batch elements concatenated along dim 0.
        lengths: per-element row counts — a sequence or a 1-D integer tensor.
        value: fill value written into the padding positions.

    Returns:
        Tensor of shape ``(batch_size, max(lengths), *rest)`` with ``input``'s
        rows scattered into their (batch, position) slots and ``value``
        everywhere else.
    """
    old_shape = input.shape
    device = input.device
    if not isinstance(lengths, torch.Tensor):
        lengths = torch.tensor(lengths, dtype=torch.int64, device=device)
    else:
        lengths = lengths.to(device)
    max_len = (lengths.max()).item()
    batch_size = len(lengths)
    # new_full preserves `input`'s dtype/device and fills in one call
    # (replaces the legacy Tensor.new(...) + fill_ pair).
    x = input.new_full((batch_size * max_len, *old_shape[1:]), value)
    # Flat destination index for every packed row: start from 0..N-1 ...
    index = torch.ones(len(input), dtype=torch.int64, device=device)
    # ... then shift each batch element's entries right by the padding
    # accumulated in all previous elements (element i pads max_len - lengths[i]).
    row_shifts = torch.cumsum(max_len - lengths, 0)
    # row_shifts[:-1][i] is the cumulative padding before element i+1; repeat
    # it once per entry of element i+1 so it can be added elementwise.
    row_shifts_expanded = row_shifts[:-1].repeat_interleave(lengths[1:])
    cumsum_inds = torch.cumsum(index, 0) - 1
    # Entries of the first batch element need no shift; everything after does.
    cumsum_inds[lengths[0] :] += row_shifts_expanded
    x[cumsum_inds] = input
    return x.view(batch_size, max_len, *old_shape[1:])