File size: 4,239 Bytes
66c9c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) 2022 NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import warp as wp


@wp.kernel
def adam_step_kernel_vec3(
    g: wp.array(dtype=wp.vec3),
    m: wp.array(dtype=wp.vec3),
    v: wp.array(dtype=wp.vec3),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=wp.vec3),
):
    # One Adam update for vec3-valued parameters, applied per component.
    #   g      - gradients (read only)
    #   m, v   - first / second moment estimates, updated in place
    #   lr     - learning rate
    #   beta1  - decay rate of the first moment
    #   beta2  - decay rate of the second moment
    #   t      - zero-based step counter (bias correction uses t + 1)
    #   eps    - term added to the denominator
    #   params - parameters, updated in place
    # Each thread processes the element at its own thread index.
    tid = wp.tid()
    grad = g[tid]

    # Exponential moving averages of the gradient and of its component-wise square.
    first = beta1 * m[tid] + (1.0 - beta1) * grad
    m[tid] = first
    second = beta2 * v[tid] + (1.0 - beta2) * wp.cw_mul(grad, grad)
    v[tid] = second

    # Bias-corrected moment estimates.
    step = t + 1.0
    first_hat = first / (1.0 - wp.pow(beta1, step))
    second_hat = second / (1.0 - wp.pow(beta2, step))

    # Square root is taken one component at a time, then eps is added to each component.
    root = wp.vec3(wp.sqrt(second_hat[0]), wp.sqrt(second_hat[1]), wp.sqrt(second_hat[2]))
    denom = root + wp.vec3(eps, eps, eps)
    params[tid] = params[tid] - lr * wp.cw_div(first_hat, denom)


@wp.kernel
def adam_step_kernel_float(
    g: wp.array(dtype=float),
    m: wp.array(dtype=float),
    v: wp.array(dtype=float),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=float),
):
    # One Adam update for scalar float parameters.
    #   g      - gradients (read only)
    #   m, v   - first / second moment estimates, updated in place
    #   lr     - learning rate
    #   beta1  - decay rate of the first moment
    #   beta2  - decay rate of the second moment
    #   t      - zero-based step counter (bias correction uses t + 1)
    #   eps    - term added to the denominator
    #   params - parameters, updated in place
    # Each thread processes the element at its own thread index.
    tid = wp.tid()
    grad = g[tid]

    # Exponential moving averages of the gradient and of the squared gradient.
    first = beta1 * m[tid] + (1.0 - beta1) * grad
    m[tid] = first
    second = beta2 * v[tid] + (1.0 - beta2) * grad * grad
    v[tid] = second

    # Bias-corrected moment estimates.
    step = t + 1.0
    first_hat = first / (1.0 - wp.pow(beta1, step))
    second_hat = second / (1.0 - wp.pow(beta2, step))

    denom = wp.sqrt(second_hat) + eps
    params[tid] = params[tid] - lr * first_hat / denom


class Adam:
    """Adam optimizer operating on a sequence of Warp arrays.

    It is designed to mimic PyTorch's version:
    https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam

    Only arrays of dtype ``float32`` and ``vec3`` are supported by the step
    kernels; any other dtype makes :meth:`step_detail` raise ``RuntimeError``.

    Args:
        params: List (or tuple) of arrays to optimize, or ``None`` to supply
            them later through :meth:`set_params`.
        lr: Learning rate.
        betas: Pair ``(beta1, beta2)`` of decay rates for the first and second
            moment estimates.
        eps: Term added to the denominator of the update.
    """

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
        self.m = []  # first moment, one array per parameter
        self.v = []  # second moment, one array per parameter
        self.set_params(params)
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.t = 0  # number of completed steps

    @staticmethod
    def _needs_alloc(state, param):
        """Return True when ``state`` cannot be reused as a moment buffer for ``param``."""
        # Identity test on purpose: ``state == None`` would dispatch to the
        # array type's own ``__eq__``.
        return state is None or state.shape != param.shape or state.dtype != param.dtype

    def set_params(self, params):
        """Set the parameters to optimize and (re)allocate the moment buffers.

        Existing moment buffers are kept when their shape and dtype still match
        the corresponding parameter; otherwise they are replaced by zero-filled
        arrays. If ``params`` is ``None``, empty, or not a list/tuple, only
        ``self.params`` is updated and the moments are left untouched.
        """
        self.params = params
        if not isinstance(params, (list, tuple)) or not params:
            return

        count = len(params)
        if len(self.m) != count:
            self.m = [None] * count  # reset first moment
        if len(self.v) != count:
            self.v = [None] * count  # reset second moment

        for i, param in enumerate(params):
            if Adam._needs_alloc(self.m[i], param):
                self.m[i] = wp.zeros_like(param)
            if Adam._needs_alloc(self.v[i], param):
                self.v[i] = wp.zeros_like(param)

    def reset_internal_state(self):
        """Zero all moment buffers in place and reset the step counter."""
        for m_i in self.m:
            m_i.zero_()
        for v_i in self.v:
            v_i.zero_()
        self.t = 0

    def step(self, grad):
        """Apply one Adam update to every parameter.

        Args:
            grad: Sequence of gradient arrays, indexed in the same order as the
                parameters.

        Raises:
            AssertionError: If no parameters have been set.
        """
        assert self.params is not None, "Adam.step() called before parameters were set"
        for i, param in enumerate(self.params):
            Adam.step_detail(
                grad[i], self.m[i], self.v[i], self.lr, self.beta1, self.beta2, self.t, self.eps, param
            )
        self.t = self.t + 1

    @staticmethod
    def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
        """Launch the Adam step kernel matching the dtype of ``params``.

        The kernel is launched with ``len(params)`` threads on ``params.device``
        and updates ``m``, ``v`` and ``params`` in place.

        Args:
            g: Gradient array; must match ``params`` in dtype and shape.
            m: First moment array; must match ``params`` in dtype.
            v: Second moment array; must match ``params`` in dtype.
            lr: Learning rate.
            beta1: Decay rate of the first moment.
            beta2: Decay rate of the second moment.
            t: Zero-based step counter (the kernels use ``t + 1`` for bias correction).
            eps: Term added to the denominator of the update.
            params: Parameter array to update.

        Raises:
            AssertionError: If the dtypes or shapes listed above do not match.
            RuntimeError: If the dtype of ``params`` is neither float32 nor vec3.
        """
        assert params.dtype == g.dtype
        assert params.dtype == m.dtype
        assert params.dtype == v.dtype
        assert params.shape == g.shape

        if params.dtype == wp.types.float32:
            kernel = adam_step_kernel_float
        elif params.dtype == wp.types.vec3:
            kernel = adam_step_kernel_vec3
        else:
            raise RuntimeError("Params data type not supported in Adam step kernels.")

        wp.launch(
            kernel=kernel,
            dim=len(params),
            inputs=[g, m, v, lr, beta1, beta2, t, eps, params],
            device=params.device,
        )