## @package store_ops_test_util # Module caffe2.distributed.store_ops_test_util from multiprocessing import Process, Queue import numpy as np from caffe2.python import core, workspace class StoreOpsTests(object): @classmethod def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs): store_handler = create_store_handler_fn() blob = "blob" value = np.full(1, 1, np.float32) # Use last process to set blob to make sure other processes # are waiting for the blob before it is set. if index == (num_procs - 1): workspace.FeedBlob(blob, value) workspace.RunOperatorOnce( core.CreateOperator( "StoreSet", [store_handler, blob], [], blob_name=blob)) output_blob = "output_blob" workspace.RunOperatorOnce( core.CreateOperator( "StoreGet", [store_handler], [output_blob], blob_name=blob)) try: np.testing.assert_array_equal(workspace.FetchBlob(output_blob), 1) except AssertionError as err: queue.put(err) workspace.ResetWorkspace() @classmethod def test_set_get(cls, create_store_handler_fn): # Queue for assertion errors on subprocesses queue = Queue() # Start N processes in the background num_procs = 4 procs = [] for index in range(num_procs): proc = Process( target=cls._test_set_get, args=(queue, create_store_handler_fn, index, num_procs, )) proc.start() procs.append(proc) # Test complete, join background processes for proc in procs: proc.join() # Raise first error we find, if any if not queue.empty(): raise queue.get() @classmethod def test_get_timeout(cls, create_store_handler_fn): store_handler = create_store_handler_fn() net = core.Net('get_missing_blob') net.StoreGet([store_handler], 1, blob_name='blob') workspace.RunNetOnce(net)