# File size: 588 bytes (revision a35137b)
import numpy as np
import os
import random
import torch
import torch.nn as nn
def get_activation_function(activation_str):
    """Return the activation layer named by ``activation_str``.

    Args:
        activation_str: Case-insensitive activation name; one of
            ``"relu"``, ``"linear"`` or ``"gelu"``.

    Returns:
        An ``nn.Module`` applying the activation. ``"linear"`` maps to
        ``nn.Identity``, which returns its input unchanged.

    Raises:
        ValueError: If ``activation_str`` is not a supported name.
    """
    name = activation_str.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "linear":
        # nn.Identity rather than a lambda: it is picklable and can be
        # registered in nn.Sequential / as a submodule like the others.
        return nn.Identity()
    if name == "gelu":
        return nn.GELU()
    # Fail loudly here instead of returning None, which would only
    # surface later as a confusing "NoneType is not callable".
    raise ValueError(
        f"Unsupported activation {activation_str!r}; "
        "expected one of 'relu', 'linear', 'gelu'."
    )
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU and all CUDA devices), and exports the seed in the
    ``PL_GLOBAL_SEED`` environment variable.

    Args:
        seed: The seed value. NumPy's legacy global seeder only accepts
            integers in ``[0, 2**32 - 1]`` and raises otherwise.

    Returns:
        The seed that was applied, so callers can log or store it.
    """
    # Same variable PyTorch Lightning's own seed_everything sets, so
    # Lightning helpers that read it see a consistent seed.
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, including the current one,
    # so a separate torch.cuda.manual_seed call is redundant. It is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    print(f"Random seed set as {seed}")
    return seed