import numpy as np
import hypothesis.strategies as st
import unittest
import caffe2.python.hypothesis_test_util as hu
from caffe2.python import core, workspace
from hypothesis import given
import caffe2.python.ideep_test_util as mu


@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
class TestMomentumSGDUpdateOps(hu.HypothesisTestCase):
    """Device checks for the momentum-SGD operators.

    The whole class is skipped when ``workspace.C.use_mkldnn`` is falsy.
    """

    @given(n=st.integers(4, 8), nesterov=st.booleans(), **mu.gcs)
    def test_MomentumSGDUpdate(self, n, nesterov, gc, dc):
        """Run device checks on "MomentumSGDUpdate" and "MomentumSGD".

        Args:
            n: Length of the random float32 vectors (drawn from 4..8).
            nesterov: Boolean forwarded to the operators as an int flag.
            gc: Supplied by ``mu.gcs``; not used in this test body.
            dc: Supplied by ``mu.gcs``; forwarded to ``assertDeviceChecks``
                (presumably the list of device options to compare -- confirm
                against ``hypothesis_test_util``).
        """

        def random_float32(size):
            # Uniform samples from np.random.rand, cast down to float32.
            return np.random.rand(size).astype(np.float32)

        # Draw order is kept as param, grad, lr, param_momentum so the
        # sequence of RNG calls is unchanged.
        param = random_float32(n)
        grad = random_float32(n)
        lr = random_float32(1)
        param_momentum = random_float32(n)

        # Arguments common to both operators under test.
        solver_args = {"momentum": 0.9, "nesterov": int(nesterov)}

        # The 'lr' input is assigned the hu.cpu_do device option; the same
        # settings are reused for both device checks.
        check_settings = {
            "input_device_options": {'lr': hu.cpu_do},
            "threshold": 0.001,
        }

        # (operator type, input blobs, output blobs, input arrays)
        cases = (
            (
                "MomentumSGDUpdate",
                ["grad", "param_momentum", "lr", "param"],
                ["grad", "param_momentum", "param"],
                [grad, param_momentum, lr, param],
            ),
            (
                # Variant without the "param" blob.
                "MomentumSGD",
                ["grad", "param_momentum", "lr"],
                ["grad", "param_momentum"],
                [grad, param_momentum, lr],
            ),
        )

        for op_type, input_blobs, output_blobs, arrays in cases:
            operator = core.CreateOperator(
                op_type, input_blobs, output_blobs, **solver_args
            )
            # Only output index 0 is passed to the device check.
            self.assertDeviceChecks(
                dc, operator, arrays, [0], **check_settings
            )


if __name__ == "__main__":
    unittest.main()