import argparse
import copy
import json
import logging

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, utils
import caffe2.python._import_c_extension as C

# BUG FIX: fuse_first_mul_add() logs through `log`, which was never defined
# anywhere in this file and raised NameError as soon as a Mul+Add pair fused.
logging.basicConfig()
log = logging.getLogger("optimize")
log.setLevel(logging.INFO)


def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), (s2, s3), ..."""
    from itertools import tee
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def last_producer(ops, blob):
    """Return the index of the last op in `ops` that lists `blob` as an output.

    Raises:
        ValueError: if no op in `ops` produces `blob`.
    """
    for i, op in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    # BUG FIX: the message used logging-style "%s" args on an exception,
    # so the blob name was never interpolated into the text.
    raise ValueError("Failed to find last producer of blob, {}".format(blob))


def blob_uses(net, blob):
    """Return indices of every op in `net` that consumes `blob`,
    either as a data input or as a control input."""
    return [
        i for i, op in enumerate(net.op)
        if blob in op.input or blob in op.control_input
    ]


def GetArgumentParser():
    """Build the command-line parser for this optimization tool."""
    parser = argparse.ArgumentParser(description="Caffe2 optimization")
    parser.add_argument("--init_net",
                        type=argparse.FileType('rb'),
                        help="init net")
    parser.add_argument("--pred_net",
                        type=argparse.FileType('rb'),
                        help="predict net")
    parser.add_argument("--verify_input",
                        type=argparse.FileType('r'),
                        help="input dims for verification")
    parser.add_argument("--fuse_bn", default=False, action='store_true')
    parser.add_argument("--fuse_mul_add", default=False, action='store_true')
    parser.add_argument("--fuse_conv_relu", default=False, action='store_true')
    return parser


def fuse_first_bn(net, params, removed_tensors):
    """Fuse the first Conv/ConvTranspose + SpatialBN pair found in `net`.

    Folds the batch-norm affine transform into the convolution's weight and
    bias. Operates on deep copies and applies at most one fusion per call;
    callers loop until the op count stops shrinking.

    Returns:
        (net, params, removed_tensors) — updated copies of the inputs.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue

        if current.type not in ("Conv", "ConvTranspose") \
           or next_.type != "SpatialBN":
            continue
        if len(blob_uses(net, current.output[0])) != 1:
            # Can't fuse if more than one user
            continue

        # else, can fuse
        conv = current
        bn = next_
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]

        # Fix fused_conv to ensure we have a bias passed.
        if len(fused_conv.input) != 3:
            bias_name = "{}_bias".format(conv.input[1])
            net.external_input.extend([bias_name])
            fused_conv.input.extend([bias_name])
            for arg in fused_conv.arg:
                if arg.name == "no_bias":
                    arg.i = 0

        conv_weight = params[conv.input[1]]
        conv_bias = params[conv.input[2]] if len(conv.input) == 3 \
            else np.zeros(shape=(conv_weight.shape[0])).astype(np.float32)

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        # First, BN computation can be phrased as follows:
        #   (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
        #       bn_scale + bias
        # Thus, we can rewrite bn_scale as:
        #   X * bn_scale * 1.0 / (sqrt(running_var + eps)) + (bias -
        #       running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale)
        # Thus, can just have the affine transform
        #   X * A + B
        # where
        #   A = bn_scale * 1.0 / (sqrt(running_var + eps))
        #   B = (bias - running_mean * (1.0 / sqrt(running_var + eps))
        #        * bn_scale)
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # Now, we have that the computation made is the following:
        #   ((X `conv` W) + b) * A + B
        # Then, we can simply fuse this as follows:
        #   (X `conv` (W * A)) + b * A + B
        # which is simply
        #   (X `conv` Q) + C
        # where
        #   Q = W * A
        #   C = b * A + B
        # For ConvTranspose, from the view of convolutions as a
        # Toepeliz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # BN scale 'B' is of size (S,).
        A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
            A.reshape(1, -1, 1, 1)
        C = conv_bias * A + B
        Q = conv_weight * A_

        params[fused_conv.input[1]] = Q
        params[fused_conv.input[2]] = C

        # Splice the fused op in place of the conv+bn pair.
        new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
        del net.op[:]
        removed_tensors.append(bn.input[1])
        removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_bn(net, params, ignore_failure):
    """Fuse Conv+SpatialBN pairs until a fixed point is reached.

    Raises:
        Exception: if SpatialBN ops remain after fusion and
            `ignore_failure` is False.
    """
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_bn(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            if (
                any(op.type == "SpatialBN" for op in next_net.op) and
                not ignore_failure
            ):
                # BUG FIX: was logging-style "%s" args; format explicitly so
                # the offending net actually appears in the message.
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(
                        next_net))
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def fuse_first_mul_add(net, params, removed_tensors):
    """Fuse the first Mul + Add pair into an equivalent is_test SpatialBN.

    The Mul weight becomes the BN scale and the Add weight the BN bias;
    fake zero-mean / unit-variance statistics are synthesized so the BN
    reduces to the same affine transform.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if current.type != "Mul" or next_.type != "Add":
            continue

        if next_.input[0] != current.output[0]:
            raise Exception("Failure to fuse")

        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception("Failure to fuse")

        log.info("Fusing at index %s", i)
        mul_ = current
        add_ = next_
        batch_norm = copy.deepcopy(mul_)
        batch_norm.type = "SpatialBN"
        batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
        batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])

        def s(x):
            return "{}{}".format(add_.output[0], x)

        fake_mean = s("_mean")
        fake_var = s("_var")

        del batch_norm.input[:]
        batch_norm.input.extend([mul_.input[0],
                                 mul_.input[1],
                                 add_.input[1],
                                 fake_mean,
                                 fake_var])
        params[fake_mean] = np.zeros_like(params[mul_.input[1]])
        params[fake_var] = np.ones_like(params[mul_.input[1]])
        net.external_input.extend([fake_mean, fake_var])

        batch_norm.output[0] = add_.output[0]
        new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_mul_add(net, params):
    """Fuse Mul+Add pairs into SpatialBN ops until a fixed point is reached."""
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_mul_add(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def add_tensor(net, name, blob):
    ''' Create an operator to store the tensor 'blob',
        run the operator to put the blob to workspace.
        uint8 is stored as an array of string with one element.
    '''
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    shape = blob.shape
    values = blob
    # pass array of uint8 as a string to save storage
    # storing uint8_t has a large overhead for now
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        # BUG FIX: str(blob.data) stringifies a memoryview on Python 3
        # ("<memory at 0x...>"); tobytes() yields the raw buffer contents.
        values = [blob.tobytes()]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])


def gen_init_net_from_blobs(blobs):
    ''' Generate an initialization net based on a blob dict '''
    ret = caffe2_pb2.NetDef()
    for name, blob in blobs.items():
        add_tensor(ret, name, blob)
    return ret


def fuse_conv_relu(net):
    """Move the net onto the IDEEP device and run the MKLDNN graph
    transform, which fuses Conv+ReLU (among other rewrites)."""
    net = copy.deepcopy(net)
    device_option = core.DeviceOption(caffe2_pb2.IDEEP)
    for op in net.op:
        op.device_option.CopyFrom(device_option)

    new_net = caffe2_pb2.NetDef()
    new_net.ParseFromString(
        C.transform_optimizeForMKLDNN(net.SerializeToString()))
    return new_net


def Optimize(args):
    """Load init/predict nets, apply the requested fusions, optionally
    verify outputs against the unfused model on random inputs, and write
    'init_net.pb' / 'predict_net.pb' to the current directory."""
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        # Capture reference outputs from the unfused model on random inputs.
        value_info = json.load(args.verify_input)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        # Re-run the fused model on the same inputs and compare.
        workspace.ResetWorkspace()
        device_option = core.DeviceOption(caffe2_pb2.IDEEP) \
            if args.fuse_conv_relu else core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                assert np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3,
                                   rtol=1e-3)

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())


if __name__ == '__main__':
    args = GetArgumentParser().parse_args()
    Optimize(args)