SATA / src /sata /utils /network_utils.py
zzysteve
Initial commit
5221c8c
Raw
History Blame Contribute Delete
4.09 kB
from torch import nn
import torch, copy
import numpy as np
def mlp(
    sizes,
    activation,
    output_activation=nn.Identity(),
    use_batchnorm=False,
    dropout_p=None,
):
    """Build a fully connected network as an ``nn.Sequential``.

    For every consecutive pair ``(sizes[j], sizes[j + 1])`` the stack
    receives, in this order: an ``nn.Linear``, an optional
    ``nn.BatchNorm1d``, an activation, and an optional ``nn.Dropout``.

    Args:
        sizes: Sequence of layer widths; ``len(sizes) - 1`` linear layers
            are created. Fewer than two entries yields an empty
            ``nn.Sequential``.
        activation: Module appended after every linear layer except the
            last. The same instance is reused for each of those layers.
        output_activation: Module appended after the last linear layer.
        use_batchnorm: When true, insert ``nn.BatchNorm1d`` after every
            linear layer (the last one included).
        dropout_p: When not ``None``, append ``nn.Dropout(p=dropout_p)``
            after every activation (the output activation included).

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = []
    last = len(sizes) - 2  # index of the final linear layer
    for j, (in_features, out_features) in enumerate(zip(sizes, sizes[1:])):
        layers.append(nn.Linear(in_features, out_features))
        if use_batchnorm:
            # Must be appended as a single module: extending the list with
            # the module itself raises TypeError (modules are not iterable).
            layers.append(nn.BatchNorm1d(out_features))
        layers.append(activation if j < last else output_activation)
        if dropout_p is not None:
            layers.append(nn.Dropout(p=dropout_p))
    return nn.Sequential(*layers)
def get_scheduler(optimizer, scheduler_dict=None, epoch_num=None) -> torch.optim.lr_scheduler.LRScheduler:
    """Create a learning-rate scheduler for ``optimizer`` from a config dict.

    ``scheduler_dict["type"]` selects the scheduler; the remaining entries
    configure it:

    * ``"exponential"``: ``LambdaLR`` with factor
      ``max(min, gamma ** epoch)``; reads ``"min"`` and ``"gamma"``.
    * ``"linear"``: ``LambdaLR`` with factor
      ``max(min, 1 - epoch * slope)``; reads ``"min"`` and ``"slope"``.
      Requires ``epoch_num`` to be given.
    * ``"Step_LR"``: ``StepLR``; all remaining entries are forwarded as
      keyword arguments.
    * ``"Step_LR_with_warmup"``: ``LinearLR`` warm-up (factor 0.1 -> 1.0
      over ``"warmup_epochs"``, default 5) followed by ``StepLR`` built from
      the remaining entries.
    * ``"exponential_with_warmup"``: the same warm-up followed by the
      exponential ``LambdaLR`` described above.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        scheduler_dict: Configuration mapping. It is not modified.
        epoch_num: Only checked for presence by the ``"linear"`` type.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``scheduler_dict`` is ``None``.
        KeyError: If ``"type"`` or a key needed by the chosen rule is missing.
        AssertionError: If the type is ``"linear"`` and ``epoch_num`` is
            ``None``.
        NotImplementedError: If the type is not one of the above.
    """
    if scheduler_dict is None:
        raise ValueError("No scheduler_dict provided!")

    # Work on a private shallow copy: the pops below must not strip keys
    # from the caller's dict, otherwise reusing the same config for a second
    # scheduler fails with KeyError on "type".
    options = dict(scheduler_dict)
    scheduler_type = options.pop("type")
    lr_sched = torch.optim.lr_scheduler

    def exponential_rule(epoch):
        # Multiplicative factor on the base lr, floored at options["min"].
        return max(options["min"], options["gamma"] ** epoch)

    def linear_rule(epoch):
        # Factor decreases by options["slope"] per epoch, floored at "min".
        return max(options["min"], 1 - epoch * options["slope"])

    def make_warmup(warmup_epochs):
        # Ramp the lr factor linearly from 0.1 to 1.0 over the warm-up span.
        return lr_sched.LinearLR(
            optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
        )

    if scheduler_type == "exponential":
        return lr_sched.LambdaLR(optimizer, lr_lambda=exponential_rule)

    if scheduler_type == "linear":
        assert epoch_num is not None, "linear schedule but epoch is None"
        return lr_sched.LambdaLR(optimizer, lr_lambda=linear_rule)

    if scheduler_type == "Step_LR":
        print("Step_LR scheduler set")
        return lr_sched.StepLR(optimizer, **options)

    if scheduler_type == "Step_LR_with_warmup":
        print("Step_LR_with_warmup scheduler set")
        warmup_epochs = options.pop("warmup_epochs", 5)
        # The warm-up scheduler is created before the main one, as before.
        warmup_scheduler = make_warmup(warmup_epochs)
        step_scheduler = lr_sched.StepLR(optimizer, **options)
        return lr_sched.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, step_scheduler],
            milestones=[warmup_epochs],
        )

    if scheduler_type == "exponential_with_warmup":
        print("exponential_with_warmup scheduler set")
        warmup_epochs = options.pop("warmup_epochs", 5)
        warmup_scheduler = make_warmup(warmup_epochs)
        exponential_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=exponential_rule)
        return lr_sched.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, exponential_scheduler],
            milestones=[warmup_epochs],
        )

    raise NotImplementedError(f"Scheduler {scheduler_type} is not implemented!")
##################################################################################
# return torch.optim.lr_scheduler.StepLR(optimizer, 600, 0.5)
# if scheduler_type == 'Plateau':
# print('Plateau_LR shceduler set')
# return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_dict)
# return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5, verbose=True)
def count_param(model):
    """Return the number of trainable scalar values in ``model``.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        The total element count. This is ``0`` (a plain int) when there are
        no trainable parameters, otherwise a NumPy 64-bit integer.
    """
    # dtype is pinned to int64 because np.prod of an empty shape (a 0-dim
    # parameter) defaults to the float 1.0, which would turn the whole count
    # into a float; it also avoids overflow on platforms whose default
    # integer is 32 bits wide.
    return sum(
        np.prod(p.size(), dtype=np.int64)
        for p in model.parameters()
        if p.requires_grad
    )