# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import warp as wp
from typing import Any
@wp.kernel
def sgd_step_kernel(
    g: wp.array(dtype=Any),       # gradient, one element per parameter entry
    b: wp.array(dtype=Any),       # momentum buffer, updated in place
    lr: float,                    # learning rate
    weight_decay: float,          # L2 penalty coefficient (0.0 disables)
    momentum: float,              # momentum factor (0.0 disables the buffer)
    damping: float,               # dampening applied to the incoming gradient
    nesterov: int,                # 1 to apply the Nesterov look-ahead, else 0
    t: int,                       # step counter; t == 0 seeds the buffer
    params: wp.array(dtype=Any),  # parameters, updated in place
):
    # One thread per parameter element; mirrors the torch.optim.SGD
    # update rule (see the SGD class docstring below).
    i = wp.tid()
    gt = g[i]
    # Weight decay folds an L2 penalty into the effective gradient.
    if weight_decay != 0.0:
        gt += weight_decay * params[i]
    if momentum != 0.0:
        bt = b[i]
        if t > 0:
            # Running average of gradients after the first step.
            bt = momentum * bt + (1.0 - damping) * gt
        else:
            # First step: initialize the buffer with the raw gradient.
            bt = gt
        if nesterov == 1:
            # Look-ahead step: use the gradient plus the decayed buffer.
            gt += momentum * bt
        else:
            gt = bt
        b[i] = bt
    params[i] = params[i] - lr * gt
class SGD:
    """An implementation of the Stochastic Gradient Descent Optimizer
    It is designed to mimic Pytorch's version.
    https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
    """

    def __init__(self, params=None, lr=0.001, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False):
        """Create an SGD optimizer.

        Args:
            params: optional list of ``wp.array`` parameters to optimize.
            lr: learning rate.
            momentum: momentum factor (0.0 disables the momentum buffer).
            dampening: dampening applied to the incoming gradient.
            weight_decay: L2 penalty coefficient (0.0 disables).
            nesterov: whether to apply the Nesterov look-ahead update.
        """
        self.b = []  # per-parameter momentum buffers, kept in sync with params
        self.set_params(params)
        self.lr = lr
        self.momentum = momentum
        self.dampening = dampening
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.t = 0  # step counter; step 0 seeds the momentum buffers

    def set_params(self, params):
        """Set the parameter list, (re)allocating momentum buffers to match.

        A buffer is reused only if its shape and dtype still match the
        corresponding parameter; otherwise a zeroed buffer is allocated.
        """
        self.params = params
        if params is not None and isinstance(params, list) and len(params) > 0:
            if len(self.b) != len(params):
                self.b = [None] * len(params)
            for i, param in enumerate(params):
                if self.b[i] is None or self.b[i].shape != param.shape or self.b[i].dtype != param.dtype:
                    self.b[i] = wp.zeros_like(param)

    def reset_internal_state(self):
        """Zero all momentum buffers and restart the step counter."""
        for b_i in self.b:
            b_i.zero_()
        self.t = 0

    def step(self, grad):
        """Apply one SGD update to every parameter array.

        Args:
            grad: list of gradient arrays, parallel to ``self.params``.
        """
        assert self.params is not None
        for i in range(len(self.params)):
            SGD.step_detail(
                grad[i], self.b[i], self.lr, self.momentum, self.dampening, self.weight_decay, self.nesterov, self.t, self.params[i]
            )
        self.t = self.t + 1

    @staticmethod
    def step_detail(g, b, lr, momentum, dampening, weight_decay, nesterov, t, params):
        """Launch the SGD kernel for a single parameter array.

        ``params`` and ``b`` are updated in place on ``params.device``.
        """
        assert params.dtype == g.dtype
        assert params.dtype == b.dtype
        assert params.shape == g.shape
        # The kernel signature is (g, b, lr, weight_decay, momentum, damping,
        # nesterov, t, params) and wp.launch binds inputs positionally, so the
        # hyperparameters must be ordered accordingly. (Previously momentum,
        # dampening and weight_decay were passed in each other's positions,
        # silently applying the wrong coefficients.)
        kernel_inputs = [g, b, lr, weight_decay, momentum, dampening, int(nesterov), t, params]
        wp.launch(
            kernel=sgd_step_kernel,
            dim=len(params),
            inputs=kernel_inputs,
            device=params.device,
        )
|