# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import warp as wp
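
# The kernels below implement the standard Adam update (Kingma & Ba, 2015),
# with one specialization per supported dtype since Warp kernels are
# statically typed:
#   m_t    = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t    = beta2 * v_{t-1} + (1 - beta2) * g_t^2
#   mhat   = m_t / (1 - beta1^(t+1)),  vhat = v_t / (1 - beta2^(t+1))
#   theta_t = theta_{t-1} - lr * mhat / (sqrt(vhat) + eps)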
@wp.kernel
def adam_step_kernel_vec3(
    g: wp.array(dtype=wp.vec3),
    m: wp.array(dtype=wp.vec3),
    v: wp.array(dtype=wp.vec3),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=wp.vec3),
):
    i = wp.tid()
    # update biased first and second moment estimates
    m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
    v[i] = beta2 * v[i] + (1.0 - beta2) * wp.cw_mul(g[i], g[i])
    # bias correction (t is zero-based, hence t + 1)
    mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
    vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
    # component-wise sqrt and division for the vec3 update
    sqrt_vhat = wp.vec3(wp.sqrt(vhat[0]), wp.sqrt(vhat[1]), wp.sqrt(vhat[2]))
    eps_vec3 = wp.vec3(eps, eps, eps)
    params[i] = params[i] - lr * wp.cw_div(mhat, (sqrt_vhat + eps_vec3))


@wp.kernel
def adam_step_kernel_float(
    g: wp.array(dtype=float),
    m: wp.array(dtype=float),
    v: wp.array(dtype=float),
    lr: float,
    beta1: float,
    beta2: float,
    t: float,
    eps: float,
    params: wp.array(dtype=float),
):
    i = wp.tid()
    # update biased first and second moment estimates
    m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
    v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
    # bias correction (t is zero-based, hence t + 1)
    mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
    vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
    params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)


class Adam:
    """An implementation of the Adam optimizer.

    Designed to mimic the behavior of PyTorch's ``torch.optim.Adam``:
    https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
    """

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
        self.m = []  # first moment
        self.v = []  # second moment
        self.set_params(params)
        self.lr = lr
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps
        self.t = 0  # number of steps taken so far (zero-based)

    def set_params(self, params):
        self.params = params
        if params is not None and isinstance(params, list) and len(params) > 0:
            if len(self.m) != len(params):
                self.m = [None] * len(params)  # reset first moment
            if len(self.v) != len(params):
                self.v = [None] * len(params)  # reset second moment
            for i in range(len(params)):
                param = params[i]
                # (re)allocate moment buffers when the parameter layout changes
                if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != param.dtype:
                    self.m[i] = wp.zeros_like(param)
                if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
                    self.v[i] = wp.zeros_like(param)

    def reset_internal_state(self):
        # zero the moment buffers and restart the step counter
        for m_i in self.m:
            m_i.zero_()
        for v_i in self.v:
            v_i.zero_()
        self.t = 0

    def step(self, grad):
        assert self.params is not None
        # grad must contain one gradient array per parameter, in the same order
        for i in range(len(self.params)):
            Adam.step_detail(
                grad[i], self.m[i], self.v[i], self.lr, self.beta1, self.beta2, self.t, self.eps, self.params[i]
            )
        self.t = self.t + 1

    @staticmethod
    def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
        assert params.dtype == g.dtype
        assert params.dtype == m.dtype
        assert params.dtype == v.dtype
        assert params.shape == g.shape
        kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
        # dispatch to the kernel specialization matching the parameter dtype
        if params.dtype == wp.types.float32:
            wp.launch(
                kernel=adam_step_kernel_float,
                dim=len(params),
                inputs=kernel_inputs,
                device=params.device,
            )
        elif params.dtype == wp.types.vec3:
            wp.launch(
                kernel=adam_step_kernel_vec3,
                dim=len(params),
                inputs=kernel_inputs,
                device=params.device,
            )
        else:
            raise RuntimeError("Params data type not supported in Adam step kernels.")
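

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): drives the Adam class above
# on a toy quadratic loss. The kernel, array names, and hyperparameters here
# are illustrative assumptions; in practice, gradients come from whatever
# differentiable pipeline produces them (e.g. wp.Tape).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    wp.init()

    @wp.kernel
    def _quadratic_loss(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
        # loss = sum(x_i^2), minimized at x = 0
        i = wp.tid()
        wp.atomic_add(loss, 0, x[i] * x[i])

    x = wp.array([5.0, -3.0, 2.0], dtype=float, requires_grad=True)
    loss = wp.zeros(1, dtype=float, requires_grad=True)
    opt = Adam([x], lr=0.1)

    for _ in range(200):
        loss.zero_()
        tape = wp.Tape()
        with tape:
            wp.launch(_quadratic_loss, dim=len(x), inputs=[x, loss])
        tape.backward(loss)
        opt.step([x.grad])  # gradients passed in the same order as params
        tape.zero_()

    print(x.numpy())  # values should be close to [0, 0, 0]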