import unittest import numpy as np import copy from hypothesis import given import hypothesis.strategies as st from caffe2.python.model_helper import ModelHelper from caffe2.python.models import resnet from caffe2.python import workspace, brew import caffe2.python.hypothesis_test_util as hu import caffe2.python.mkl.rewrite_graph as rewrite_graph def deterministic_io(model): model = copy.deepcopy(model) for i, op in enumerate(model.InitProto().op): op.device_option.random_seed = i + 1 if not model.Proto().external_output: model.Proto().external_output.extend([model.Proto().op[-1].output[0]]) return model def simple_fc(): model = ModelHelper(name="r") brew.fc(model, "data", "fc", 10, 10) return model, [(1, 10)] def double_matmul(): model = ModelHelper(name="r") fc0 = brew.fc(model, "data", "fc0", 10, 10) fc1 = brew.fc(model, fc0, "fc1", 10, 10) model.Proto().external_output[:] = [str(fc0), str(fc1)] return model, [(1, 10)] def simple_relu(): model = ModelHelper(name="r") brew.relu(model, "data", "fc") return model, [(1, 10)] def simple_mlp(): model = ModelHelper(name="r") brew.relu( model, brew.fc( model, brew.relu( model, brew.fc( model, "data", "fc1", 10, 10), "rl1"), "fc2", 10, 10), "rl2") return model, [(1, 10)] def simple_cnn(): model = ModelHelper(name="r", arg_scope={"order": "NCHW", "is_test": True}) brew.conv( model, "data", 'conv1', 3, 16, kernel=3, stride=1 ) brew.spatial_bn( model, 'conv1', 'conv1_spatbn', 16, epsilon=1e-3 ) brew.relu(model, 'conv1_spatbn', 'relu1') return model, [(1, 3, 32, 32)] def alexnet(): model = ModelHelper(name="r", arg_scope={"order": "NCHW", "is_test": True}) conv1 = brew.conv( model, "data", "conv1", 3, 64, 11, ('XavierFill', {}), ('ConstantFill', {}), stride=4, pad=2 ) relu1 = brew.relu(model, conv1, "conv1") pool1 = brew.max_pool(model, relu1, "pool1", kernel=3, stride=2, pad=0, legacy_pad=3) lrn1 = brew.lrn( model, pool1, "pool1_lrn", size=5, alpha=1.0e-4, beta=0.75, bias=1.0) conv2 = brew.conv( model, lrn1, "conv2", 64, 192, 5, ('XavierFill', {}), ('ConstantFill', {}), pad=2 ) relu2 = brew.relu(model, conv2, "conv2") pool2 = brew.max_pool(model, relu2, "pool2", kernel=3, stride=2) lrn2 = brew.lrn( model, pool2, "pool2_lrn", size=5, alpha=1.0e-4, beta=0.75, bias=1.0) conv3 = brew.conv( model, lrn2, "conv3", 192, 384, 3, ('XavierFill', {}), ('ConstantFill', {}), pad=1 ) relu3 = brew.relu(model, conv3, "conv3") conv4 = brew.conv( model, relu3, "conv4", 384, 256, 3, ('XavierFill', {}), ('ConstantFill', {}), pad=1 ) relu4 = brew.relu(model, conv4, "conv4") conv5 = brew.conv( model, relu4, "conv5", 256, 256, 3, ('XavierFill', {}), ('ConstantFill', {}), pad=1 ) relu5 = brew.relu(model, conv5, "conv5") pool5 = brew.max_pool(model, relu5, "pool5", kernel=3, stride=2) fc6 = brew.fc( model, pool5, "fc6", 256 * 6 * 6, 4096, ('XavierFill', {}), ('ConstantFill', {}) ) relu6 = brew.relu(model, fc6, "fc6") fc7 = brew.fc( model, relu6, "fc7", 4096, 4096, ('XavierFill', {}), ('ConstantFill', {}) ) relu7 = brew.relu(model, fc7, "fc7") drop7 = brew.dropout(model, relu7, "fc7_dropout", is_test=1, ratio=0.5) fc8 = brew.fc( model, drop7, "fc8", 4096, 1000, ('XavierFill', {}), ('ConstantFill', {}) ) relu8 = brew.relu(model, fc8, "fc8") brew.dropout(model, relu8, "fc8_dropout", is_test=1, ratio=0.5) return model, [(1, 3, 224, 224)] def simple_resnet(): model = ModelHelper(name="r", arg_scope={"order": "NCHW", "is_test": True}) resnet.create_resnet_32x32( model, "data", num_input_channels=1, num_groups=1, num_labels=5, is_test=True) return model, [(1, 1, 32, 32)] def complex_resnet(): model = ModelHelper(name="r", arg_scope={"order": "NCHW", "is_test": True}) resnet.create_resnet50( model, "data", num_input_channels=1, num_labels=5, is_test=True, no_loss=True) return model, [(1, 1, 224, 224)] @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") class MKLRewriteTest(hu.HypothesisTestCase): @given(gen=st.sampled_from([simple_relu, simple_fc, simple_mlp, simple_cnn])) def test_mkl_simple_rewrite(self, gen): cpu_model, (shape,) = gen() cpu_model = deterministic_io(cpu_model) mkl_model = rewrite_graph.rewrite_model_helper_simple(cpu_model) X = np.random.randn(*shape).astype(np.float32) def run(model): self.ws.run(model.InitProto()) self.ws.create_blob(model.Proto().external_input[0]).feed(X) self.ws.run(model.Proto()) return self.ws.blobs[model.Proto().external_output[0]].fetch() np.testing.assert_allclose(run(cpu_model), run(mkl_model), atol=1e-4, rtol=1e-4) def test_mkl_resnet_rewrite(self): cpu_model, (shape,) = complex_resnet() cpu_model = deterministic_io(cpu_model) mkl_model = rewrite_graph.rewrite_model_helper_simple(cpu_model) np.random.seed(1701) X = np.random.randn(*shape).astype(np.float32) def run(model): self.ws.run(model.InitProto()) self.ws.create_blob(model.Proto().external_input[0]).feed(X) self.ws.run(model.Proto()) return self.ws.blobs[model.Proto().external_output[0]].fetch() np.testing.assert_allclose(run(cpu_model), run(mkl_model), atol=1e-4, rtol=1e-4) def test_mkl_multi_output_rewrite(self): cpu_model, shapes = double_matmul() cpu_model = deterministic_io(cpu_model) mkl_model = rewrite_graph.rewrite_model_helper_simple(cpu_model) np.random.seed(1701) Xs = [np.random.randn(*shape).astype(np.float32) for shape in shapes] def run(model): self.ws.run(model.InitProto()) for (name, X) in zip(model.Proto().external_input, Xs): self.ws.create_blob(name).feed(X) print(model.Proto()) self.ws.run(model.Proto()) return [self.ws.blobs[name].fetch() for name in model.Proto().external_output] run(mkl_model) np.testing.assert_allclose(run(cpu_model), run(mkl_model), atol=1e-4, rtol=1e-4) def test_mkl_alexnet_rewrite(self): cpu_model, (shape,) = alexnet() cpu_model = deterministic_io(cpu_model) mkl_model = rewrite_graph.rewrite_model_helper_simple(cpu_model) np.random.seed(1701) X = np.random.randn(*shape).astype(np.float32) def run(model): self.ws.run(model.InitProto()) self.ws.create_blob(model.Proto().external_input[0]).feed(X) self.ws.run(model.Proto()) return self.ws.blobs[model.Proto().external_output[0]].fetch() np.testing.assert_allclose(run(cpu_model), run(mkl_model), atol=1e-4, rtol=1e-4) if __name__ == "__main__": import unittest unittest.main()