# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import warp as wp
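
# Adam gradient step (Kingma & Ba, 2014), applied element-wise by the kernels
# below:
#   m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
#   m_hat = m_t / (1 - beta1^t)
#   v_hat = v_t / (1 - beta2^t)
#   theta = theta - lr * m_hat / (sqrt(v_hat) + eps)
# Note: `t` is passed in zero-based, hence the (t + 1.0) in the bias correction.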
@wp.kernel
def adam_step_kernel_vec3(
    g: wp.array(dtype=wp.vec3),
    m: wp.array(dtype=wp.vec3),
    v: wp.array(dtype=wp.vec3),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=wp.vec3),
):
    i = wp.tid()
    # update biased first and second moment estimates
    m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
    v[i] = beta2 * v[i] + (1.0 - beta2) * wp.cw_mul(g[i], g[i])
    # bias-corrected moments (t is zero-based, hence t + 1.0)
    mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
    vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
    # take the square root per component, since wp.sqrt operates on scalars
    sqrt_vhat = wp.vec3(wp.sqrt(vhat[0]), wp.sqrt(vhat[1]), wp.sqrt(vhat[2]))
    eps_vec3 = wp.vec3(eps, eps, eps)
    params[i] = params[i] - lr * wp.cw_div(mhat, (sqrt_vhat + eps_vec3))


@wp.kernel
def adam_step_kernel_float(
    g: wp.array(dtype=float),
    m: wp.array(dtype=float),
    v: wp.array(dtype=float),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=float),
):
    i = wp.tid()
    # update biased first and second moment estimates
    m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
    v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
    # bias-corrected moments (t is zero-based, hence t + 1.0)
    mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
    vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
    params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)


class Adam:
    """An implementation of the Adam optimizer.

    It is designed to mimic PyTorch's version:
    https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
    """

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
        self.m = []  # first moment
        self.v = []  # second moment
        self.set_params(params)
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.t = 0
    def set_params(self, params):
        self.params = params
        if params is not None and isinstance(params, list) and len(params) > 0:
            if len(self.m) != len(params):
                self.m = [None] * len(params)  # reset first moment
            if len(self.v) != len(params):
                self.v = [None] * len(params)  # reset second moment
            for i in range(len(params)):
                param = params[i]
                # (re)allocate moment buffers when a parameter's shape or dtype has changed
                if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != param.dtype:
                    self.m[i] = wp.zeros_like(param)
                if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
                    self.v[i] = wp.zeros_like(param)
    def reset_internal_state(self):
        for m_i in self.m:
            m_i.zero_()
        for v_i in self.v:
            v_i.zero_()
        self.t = 0
    def step(self, grad):
        assert self.params is not None
        for i in range(len(self.params)):
            Adam.step_detail(
                grad[i], self.m[i], self.v[i], self.lr, self.beta1, self.beta2, self.t, self.eps, self.params[i]
            )
        self.t = self.t + 1
    @staticmethod
    def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
        assert params.dtype == g.dtype
        assert params.dtype == m.dtype
        assert params.dtype == v.dtype
        assert params.shape == g.shape
        kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
        # dispatch to the kernel matching the parameter dtype
        if params.dtype == wp.types.float32:
            wp.launch(
                kernel=adam_step_kernel_float,
                dim=len(params),
                inputs=kernel_inputs,
                device=params.device,
            )
        elif params.dtype == wp.types.vec3:
            wp.launch(
                kernel=adam_step_kernel_vec3,
                dim=len(params),
                inputs=kernel_inputs,
                device=params.device,
            )
        else:
            raise RuntimeError("Params data type not supported in Adam step kernels.")
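

# A minimal usage sketch (not part of the original file): minimizes a toy
# quadratic loss 0.5 * (x - target)^2 with the optimizer above. The gradient
# kernel and names here are illustrative assumptions; in practice gradients
# usually come from wp.Tape.
@wp.kernel
def _toy_grad_kernel(x: wp.array(dtype=float), target: float, grad: wp.array(dtype=float)):
    i = wp.tid()
    grad[i] = x[i] - target  # d/dx of 0.5 * (x - target)^2


if __name__ == "__main__":
    wp.init()
    x = wp.zeros(8, dtype=float)
    grad = wp.zeros(8, dtype=float)
    opt = Adam([x], lr=0.1)
    for _ in range(100):
        wp.launch(_toy_grad_kernel, dim=len(x), inputs=[x, 3.0, grad])
        opt.step([grad])
    print(x.numpy())  # each entry should be close to the target value 3.0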