slahmr-test / lib /python3.9 /site-packages /caffe2 /python /serialized_test /serialized_test_util.py
| import inspect | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import threading | |
| from contextlib import contextmanager | |
| from zipfile import ZipFile | |
| import argparse | |
| import hypothesis as hy | |
| import numpy as np | |
| import caffe2.python.hypothesis_test_util as hu | |
| from caffe2.proto import caffe2_pb2 | |
| from caffe2.python import gradient_checker | |
| from caffe2.python.serialized_test import coverage | |
# Sub-directory name under the data directory that holds operator-test zips.
operator_test_type = 'operator_test'
# Name of the data directory, used both next to this module and in the
# cwd-relative fallback path.
DATA_SUFFIX = 'data'
# Absolute directory containing this module; the default data dir sits here.
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
# Thread-local bag of command-line options (populated by testWithArgs and
# read by SerializedTestCase via getattr with defaults).
_output_context = threading.local()
def given(*given_args, **given_kwargs):
    """Serialized-test replacement for ``hypothesis.given``.

    Returns a decorator whose wrapped test runs twice. Each run goes through
    its own hypothesis wrapper configured with ``seed(0)`` and
    ``max_examples=1``. The first run executes with
    ``self.should_serialize = True`` and the second with it set back to
    False.

    NOTE(review): both wrappers use identical settings, so the second run
    presumably repeats the same seeded example without serialization.
    Confirm that this duplicated run is still wanted.
    """
    def wrapper(f):
        def _single_seeded_example():
            # A fresh hypothesis wrapper: deterministic seed, one example.
            return hy.seed(0)(hy.settings(max_examples=1)(
                hy.given(*given_args, **given_kwargs)(f)))

        # Built in the same order as before: plain run first, then the
        # serializing one. Each is an independent wrapper object.
        plain_run = _single_seeded_example()
        serializing_run = _single_seeded_example()

        def func(self, *args, **kwargs):
            self.should_serialize = True
            serializing_run(self, *args, **kwargs)
            self.should_serialize = False
            plain_run(self, *args, **kwargs)
        return func
    return wrapper
def _getGradientOrNone(op_proto):
    """Return the gradient ops for ``op_proto``, or ``[]`` if unavailable.

    Calls ``gradient_checker.getGradientForOp`` and keeps the first element
    of the returned pair. Despite the function's name, the fallback value
    is an empty list, not None.
    """
    try:
        grad_ops, _unused = gradient_checker.getGradientForOp(op_proto)
    except Exception:
        # Best effort: any failure (e.g. an op with no registered gradient)
        # is treated as "no gradient ops".
        return []
    return grad_ops
def _transformList(l):
    """Pack a list of arrays into a 1-D object ndarray, one slot per item.

    This is needed to store jagged (ragged) lists as a numpy array.
    ``np.array(l)`` would fail on ragged input or silently stack
    equal-shaped arrays into a higher-rank array.

    Uses the builtin ``object`` as dtype: the ``np.object`` alias was
    deprecated in NumPy 1.20 and removed in 1.24.
    """
    ret = np.empty(len(l), dtype=object)
    # Element-wise assignment keeps each array intact; a slice assignment
    # (ret[:] = l) would try to broadcast equal-length arrays.
    for i, arr in enumerate(l):
        ret[i] = arr
    return ret
def _prepare_dir(path):
    """Make ``path`` a freshly created, empty directory.

    Any existing tree at ``path`` is deleted first. Missing parent
    directories are created.
    """
    already_present = os.path.exists(path)
    if already_present:
        shutil.rmtree(path)
    os.makedirs(path)
class SerializedTestCase(hu.HypothesisTestCase):
    """Hypothesis test case that can serialize operator runs and compare them.

    When ``should_serialize`` is True (set by this module's ``given`` for its
    first run), ``assertReferenceChecks`` does one of two things, depending
    on the thread-local options:

    * write the op, its gradient ops and inputs/outputs to
      ``<output dir>/<test name>.zip``; or
    * compare the current run against that previously written zip.
    """

    # Toggled by ``given``: True for the run whose results are serialized or
    # compared, False otherwise.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory holding this suite's serialized zips.

        The primary location is ``<output_dir>/operator_test``, where
        ``output_dir`` comes from the thread-local options and defaults to
        ``DATA_DIR``. If that directory does not exist, the method returns
        ``<cwd>/<package path of this module>/data/operator_test``. It does
        not check whether this fallback exists.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)
        if os.path.exists(output_dir):
            return output_dir

        # Fall back to a cwd-relative path mirroring this module's package
        # (its dotted __name__ minus the last component). os.path.join keeps
        # the separator portable, unlike the previous '/'.join.
        package_components = __name__.split('.')[:-1]
        return os.path.join(
            os.getcwd(), *package_components, DATA_SUFFIX, operator_test_type)

    def get_output_filename(self):
        """Return ``'<test file stem>.<test method name>'`` for this test.

        The stem is the basename of the file defining the test class, cut at
        its first '.'. The method name is the last dotted component of
        ``self.id()``.
        """
        class_path = inspect.getfile(self.__class__)
        test_file = os.path.basename(class_path).split('.')[0]
        test_function = self.id().split('.')[-1]
        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write ``op``, its gradient ops and inputs/outputs to a zip archive.

        The archive is ``<output dir>/<test name>.zip`` and contains:

        * ``op.pb``;
        * ``inout.npz``, with keys ``inputs``, ``outputs`` and
          ``device_type``;
        * one ``grad_<i>.pb`` per entry of ``grad_ops``.

        Files are staged in ``<output dir>/<test name>/``, which is removed
        afterwards, also when writing fails.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)
        try:
            # Object arrays so that ragged input/output lists can be saved.
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            inout_path = os.path.join(full_dir, 'inout')
            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())

            grad_paths = []
            for i, grad in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            with ZipFile(os.path.join(output_dir, test_name + '.zip'), 'w') as z:
                z.write(op_path, 'op.pb')
                # savez_compressed appended the '.npz' extension itself.
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # The staging files are only needed to build the archive. Do not
            # let a cleanup error mask an exception raised above.
            shutil.rmtree(full_dir, ignore_errors=True)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Check the current run against the serialized ``<test name>.zip``.

        If the current ``inputs`` differ from the serialized ones (a
        different count, or any array not equal), the serialized op is re-run
        on the serialized inputs. Its outputs and gradient ops then replace
        the caller's. Outputs are checked with
        ``np.testing.assert_allclose(atol, rtol)``. Each gradient op is
        checked against ``grad_<i>.pb`` with argument order ignored.
        """
        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        # TemporaryDirectory is removed even when an assertion below fails.
        # The previous mkdtemp()/rmtree() pair leaked it in that case.
        with tempfile.TemporaryDirectory() as temp_dir:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)
            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # Load the serialized inputs and outputs. allow_pickle is required
            # for the object arrays written by serialize_test. The NpzFile is
            # closed when the with-block exits.
            with np.load(inout_path, encoding='bytes', allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                # zip() truncates, so compare the counts explicitly.
                inputs_equal = len(inputs) == len(loaded_inputs) and all(
                    np.array_equal(x, y) for x, y in zip(inputs, loaded_inputs))

                # If the inputs are not the same, run the serialized inputs
                # through the serialized op.
                if not inputs_equal:
                    with open(op_path, 'rb') as f:
                        op_proto = parse_proto(f.read())
                    device_option = caffe2_pb2.DeviceOption(
                        device_type=int(loaded['device_type']))
                    outputs = hu.runOpOnInput(device_option, op_proto, loaded_inputs)
                    grad_ops = _getGradientOrNone(op_proto)

            # Assert that the outputs are equal (within tolerance).
            for x, y in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # Assert that each gradient op matches its serialized counterpart.
            for i, grad_op in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self._assertSameOps(grad_proto, grad_op)

    def _assertSameOps(self, op1, op2):
        """Assert that two OperatorDefs are equal, ignoring ``arg`` order.

        Copies of both protos are sorted by argument name before comparing,
        so the inputs themselves are not mutated.
        """
        op1_ = caffe2_pb2.OperatorDef()
        op1_.CopyFrom(op1)
        op1_.arg.sort(key=lambda arg: arg.name)
        op2_ = caffe2_pb2.OperatorDef()
        op2_.CopyFrom(op2)
        op2_.arg.sort(key=lambda arg: arg.name)
        self.assertEqual(op1_, op2_)

    def assertSerializedOperatorChecks(
        self,
        inputs,
        outputs,
        gradient_operator,
        op,
        device_option,
        atol=1e-7,
        rtol=1e-7,
    ):
        """Serialize or compare this run. This is a no-op unless ``should_serialize`` is set.

        When the thread-local ``should_generate_output`` option is set, the
        run is written with ``serialize_test``. The coverage file is also
        regenerated unless ``disable_gen_coverage`` is set. Otherwise the run
        is checked with ``compare_test``.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
                if not getattr(_output_context, 'disable_gen_coverage', False):
                    coverage.gen_serialized_test_coverage(
                        self.get_output_dir(), TOP_DIR)
            else:
                self.compare_test(
                    inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
        ensure_outputs_are_inferred=False,
    ):
        """Run the parent reference checks, then the serialized checks.

        The serialized checks are skipped when the thread-local
        ``disable_serialized_check`` option is set. ``threshold`` is used as
        ``rtol``, and also as ``atol`` when ``atol`` is None.

        Returns the outputs produced by the parent ``assertReferenceChecks``.
        Previously this override discarded them.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
            ensure_outputs_are_inferred,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )
        return outs

    @contextmanager
    def set_disable_serialized_check(self, val: bool):
        """Temporarily set the thread-local ``disable_serialized_check`` flag.

        Use as ``with self.set_disable_serialized_check(True): ...``. The
        previous value (default False) is restored on exit, including when
        the body raises. Without ``@contextmanager`` this generator could not
        be used in a ``with`` statement.
        """
        orig = getattr(_output_context, 'disable_serialized_check', False)
        try:
            # pyre-fixme[16]: `local` has no attribute `disable_serialized_check`.
            _output_context.disable_serialized_check = val
            yield
        finally:
            _output_context.disable_serialized_check = orig
def testWithArgs():
    """Parse the serialized-test flags, store them, then run unittest.

    Supported flags:

    * ``-G/--generate-serialized``: regenerate the serialized output files.
    * ``-O/--output DIR``: root directory for serialized data (default
      ``DATA_DIR``).
    * ``-D/--disable-serialized_check``: skip the serialized checks.
    * ``-C/--disable-gen-coverage``: do not regenerate the coverage file.

    Remaining positional arguments are forwarded to ``unittest.main`` via
    ``sys.argv``. The parsed options are stored on the thread-local
    ``_output_context``, where SerializedTestCase reads them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    parser.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    parser.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    parser.add_argument(
        '-C', '--disable-gen-coverage', action='store_true',
        dest='disable_coverage',
        help='disable generating coverage markdown file')
    parser.add_argument('unittest_args', nargs='*')
    parsed = parser.parse_args()

    # Leave only the pass-through arguments for unittest.main() to parse.
    sys.argv[1:] = parsed.unittest_args

    context_values = {
        'should_generate_output': parsed.generate,
        'output_dir': parsed.output,
        'disable_serialized_check': parsed.disable,
        'disable_gen_coverage': parsed.disable_coverage,
    }
    for attr_name, attr_value in context_values.items():
        setattr(_output_context, attr_name, attr_value)

    import unittest
    unittest.main()