File size: 1,856 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import random
import warnings

import numpy as np
import torch

from easy_tpp.utils.import_utils import is_torch_mps_available


def set_seed(seed=1029):
    """Seed every random number generator used by the project.

    Seeds Python's ``random``, NumPy's global RNG, and torch on CPU and all
    CUDA devices. Also configures cuDNN for reproducible kernel selection.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up. Setting it
        here affects only Python processes launched afterwards. It does not
        change hash randomisation of the current process.

    Args:
        seed (int, optional): random seed. Defaults to 1029.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. This is silently
    # ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True restricts cuDNN to deterministic algorithms.
    # benchmark=False stops the autotuner from choosing different algorithms
    # from run to run. Both are needed for reproducible results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_device(gpu=-1):
    """Select the torch device to run on.

    Args:
        gpu (int, optional): index of the GPU to use. Any negative value,
            including the default -1, selects the CPU.

    Returns:
        torch.device: the selected device, chosen in this order:

        * ``cpu`` when ``gpu < 0``;
        * ``cuda:<gpu>`` when CUDA is available;
        * ``mps`` when MPS is available;
        * otherwise ``cpu``, after emitting a ``UserWarning`` that the
          requested GPU could not be used.
    """
    if gpu < 0:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda:{}".format(gpu))
    if is_torch_mps_available():
        # The requested GPU index is not used for the MPS backend.
        return torch.device("mps")
    # Without this fallback no device would be selected on this path, and the
    # caller would get an UnboundLocalError instead of a usable device.
    warnings.warn(
        "gpu={} was requested but neither CUDA nor MPS is available; "
        "falling back to CPU.".format(gpu)
    )
    return torch.device("cpu")


def set_optimizer(optimizer, params, lr):
    """Construct a torch optimizer by name.

    The name is first looked up exactly in ``torch.optim``, then
    case-insensitively. For example, "adam", "sgd" and "adamw" resolve to
    ``Adam``, ``SGD`` and ``AdamW``. Only concrete ``torch.optim.Optimizer``
    subclasses are accepted.

    Args:
        optimizer (str): name of the optimizer.
        params (iterable): parameters, or parameter-group dicts, to optimize.
            Passed straight to the optimizer constructor.
        lr (float): learning rate.

    Raises:
        NotImplementedError: if ``optimizer`` is not a string or does not name
            an optimizer class in ``torch.optim``. Errors raised by the
            optimizer constructor itself propagate unchanged. Examples are an
            empty parameter list or an invalid learning rate.

    Returns:
        torch.optim.Optimizer: the constructed optimizer.
    """
    optimizer_cls = None
    if isinstance(optimizer, str):
        wanted = optimizer.lower()
        # Exact attribute name first, then any case-insensitive match.
        candidates = [optimizer] + [
            name for name in dir(torch.optim) if name.lower() == wanted
        ]
        for name in candidates:
            cls = getattr(torch.optim, name, None)
            # Skip submodules, helpers and the abstract base class.
            if (
                isinstance(cls, type)
                and issubclass(cls, torch.optim.Optimizer)
                and cls is not torch.optim.Optimizer
            ):
                optimizer_cls = cls
                break
    if optimizer_cls is None:
        raise NotImplementedError("optimizer={} is not supported.".format(optimizer))
    # Do not wrap construction: relabelling e.g. an invalid lr as "not supported"
    # hides the real cause from the caller.
    return optimizer_cls(params, lr=lr)


def count_model_params(model, *, trainable_only=False):
    """Count the number of parameters of a model.

    Args:
        model (torch.nn.Module): a torch model.
        trainable_only (bool, optional): if True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts all
            parameters.

    Returns:
        int: total number of scalar elements across the counted parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )