| |
|
|
| import collections |
| import copy |
| import functools |
| import logging |
| import numpy as np |
| import os |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| from unittest import mock |
| import caffe2.python.utils as putils |
| import torch |
| import torch.nn.functional as F |
| from caffe2.proto import caffe2_pb2 |
| from caffe2.python import core, net_drawer, workspace |
| from torch.nn.functional import interpolate as interp |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
|
|
|
|
def to_device(t, device_str):
    """
    Replacement for ``.to(another_device)`` that keeps device casts traceable
    by emitting the underlying Caffe2 copy ops explicitly, and that inserts
    no op at all when source and destination devices already match.
    """
    source = t.device
    target = torch.device(device_str)
    if source == target:
        # Same device: nothing to do.
        return t
    direction = (source.type, target.type)
    if direction == ("cuda", "cpu"):
        return torch.ops._caffe2.CopyGPUToCPU(t)
    if direction == ("cpu", "cuda"):
        return torch.ops._caffe2.CopyCPUToGPU(t)
    raise RuntimeError("Can't cast tensor from device {} to device {}".format(source, target))
|
|
|
|
| |
|
|
|
|
| |
def BilinearInterpolation(tensor_in, up_scale):
    """
    Upsample ``tensor_in`` (NCHW) by factor ``up_scale`` using a transposed
    convolution whose weights form a fixed per-channel bilinear kernel.
    """
    assert up_scale % 2 == 0, "Scale should be even"

    def _bilinear_weights(size):
        # Classic bilinear upsampling filter of the given square size.
        factor = (size + 1) // 2
        center = factor - 1 if size % 2 == 1 else factor - 0.5
        grid = np.ogrid[:size, :size]
        return (1 - abs(grid[0] - center) / factor) * (1 - abs(grid[1] - center) / factor)

    kernel_size = int(up_scale) * 2
    num_channels = int(tensor_in.shape[1])
    # Block-diagonal weight tensor: channel i is upsampled only from channel i.
    weights = np.zeros((num_channels, num_channels, kernel_size, kernel_size), dtype=np.float32)
    weights[range(num_channels), range(num_channels), :, :] = _bilinear_weights(kernel_size)

    return F.conv_transpose2d(
        tensor_in,
        weight=to_device(torch.Tensor(weights), tensor_in.device),
        bias=None,
        stride=int(up_scale),
        padding=int(up_scale / 2),
    )
|
|
|
|
| |
| |
| |
def onnx_compatibale_interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """
    Drop-in for F.interpolate: during ONNX export, dynamic (scale-factor
    based) 4D resizing is lowered to Caffe2-friendly ops; every other case
    is forwarded to torch.nn.functional.interpolate unchanged.
    NOTE: the name typo ("compatibale") is preserved since this function is
    referenced by name when patching F.interpolate.
    """
    if size is None and scale_factor is not None:
        if input.dim() == 4:
            if isinstance(scale_factor, (int, float)):
                height_scale = width_scale = scale_factor
            else:
                assert isinstance(scale_factor, (tuple, list))
                assert len(scale_factor) == 2
                height_scale, width_scale = scale_factor

            assert not align_corners, "No matching C2 op for align_corners == True"
            if mode == "nearest":
                return torch.ops._caffe2.ResizeNearest(
                    input, order="NCHW", width_scale=width_scale, height_scale=height_scale
                )
            elif mode == "bilinear":
                logger.warning(
                    "Use F.conv_transpose2d for bilinear interpolate"
                    " because there's no such C2 op, this may cause significant"
                    " slowdown and the boundary pixels won't be as same as"
                    " using F.interpolate due to padding."
                )
                assert height_scale == width_scale
                return BilinearInterpolation(input, up_scale=height_scale)
        logger.warning("Output size is not static, it might cause ONNX conversion issue")

    return interp(input, size, scale_factor, mode, align_corners)
|
|
|
|
def mock_torch_nn_functional_interpolate():
    """
    Decorator factory: while tracing for ONNX export, run the wrapped
    function with F.interpolate patched to the Caffe2-compatible
    implementation; outside of export the function runs untouched.
    """

    def decorator(func):
        @functools.wraps(func)
        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
            if not torch.onnx.is_in_onnx_export():
                return func(*args, **kwargs)
            # Patch only for the duration of this call.
            patcher = mock.patch(
                "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
            )
            with patcher:
                return func(*args, **kwargs)

        return _mock_torch_nn_functional_interpolate

    return decorator
|
|
|
|
| |
|
|
|
|
class ScopedWS:
    """
    Context manager that switches to a named Caffe2 workspace on entry and
    restores the previously-current workspace on exit.

    Args:
        ws_name: workspace to switch to; None means stay on the current one.
        is_reset: if True, reset the workspace right after entering.
        is_cleanup: if True, reset the workspace right before exiting.
    """

    def __init__(self, ws_name, is_reset, is_cleanup=False):
        self.ws_name = ws_name
        self.is_reset = is_reset
        self.is_cleanup = is_cleanup
        self.org_ws = ""

    def __enter__(self):
        # Remember where we came from so __exit__ can switch back.
        self.org_ws = workspace.CurrentWorkspace()
        if self.ws_name is not None:
            # Second argument True: create the workspace if it doesn't exist.
            workspace.SwitchWorkspace(self.ws_name, True)
        if self.is_reset:
            workspace.ResetWorkspace()
        return workspace

    def __exit__(self, *args):
        if self.is_cleanup:
            workspace.ResetWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.org_ws)
|
|
|
|
def fetch_any_blob(name):
    """
    Fetch a blob from the current workspace, falling back to the Int8
    fetcher for quantized blobs. Returns None when fetching fails entirely.
    """
    try:
        return workspace.FetchBlob(name)
    except TypeError:
        # Quantized (Int8) blobs make FetchBlob raise TypeError.
        return workspace.FetchInt8Blob(name)
    except Exception as e:
        # Best-effort: log and fall through to returning None.
        logger.error("Get blob {} error: {}".format(name, e))
    return None
|
|
|
|
| |
|
|
|
|
def get_pb_arg(pb, arg_name):
    """Return the Argument named ``arg_name`` from ``pb.arg``, or None if absent."""
    return next((arg for arg in pb.arg if arg.name == arg_name), None)
|
|
|
|
def get_pb_arg_valf(pb, arg_name, default_val):
    """Return the float field ``f`` of argument ``arg_name``, or ``default_val`` if absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.f
|
|
|
|
def get_pb_arg_floats(pb, arg_name, default_val):
    """Return argument ``arg_name``'s repeated ``floats`` as a list of float, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [float(x) for x in arg.floats]
|
|
|
|
def get_pb_arg_ints(pb, arg_name, default_val):
    """Return argument ``arg_name``'s repeated ``ints`` as a list of int, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [int(x) for x in arg.ints]
|
|
|
|
def get_pb_arg_vali(pb, arg_name, default_val):
    """Return the int field ``i`` of argument ``arg_name``, or ``default_val`` if absent."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.i
|
|
|
|
def get_pb_arg_vals(pb, arg_name, default_val):
    """Return the string field ``s`` (bytes) of argument ``arg_name``, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.s
|
|
|
|
def get_pb_arg_valstrings(pb, arg_name, default_val):
    """Return argument ``arg_name``'s repeated ``strings`` as a list, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return list(arg.strings)
|
|
|
|
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
    """
    Ensure ``pb`` has an argument ``arg_name`` whose field ``arg_attr`` equals
    ``arg_value``; create the argument when missing. With ``allow_override`` a
    differing existing value is overwritten (with a warning); otherwise an
    existing value must already match, or an AssertionError is raised.
    """
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        arg = putils.MakeArgument(arg_name, arg_value)
        assert hasattr(arg, arg_attr)
        pb.arg.extend([arg])
    current = getattr(arg, arg_attr)
    if allow_override and current != arg_value:
        logger.warning(
            "Override argument {}: {} -> {}".format(arg_name, current, arg_value)
        )
        setattr(arg, arg_attr, arg_value)
    else:
        assert arg is not None
        assert current == arg_value, "Existing value {}, new value {}".format(
            current, arg_value
        )
|
|
|
|
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
    """Build the GivenTensor*Fill operator that materializes ``tensor`` as blob ``name``."""
    assert type(tensor) == np.ndarray
    # Fill op per supported dtype; uint8 payloads become a single string blob.
    kTypeNameMapper = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }

    if tensor.dtype == np.dtype("uint8"):
        args_dict = {"values": [str(tensor.data)], "shape": [1]}
    else:
        args_dict = {"values": tensor, "shape": tensor.shape}
    if device_option is not None:
        args_dict["device_option"] = device_option

    return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
|
|
|
|
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
    """Build the Int8Given*TensorFill operator that materializes ``int8_tensor`` as blob ``name``."""
    assert type(int8_tensor) == workspace.Int8Tensor
    kTypeNameMapper = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    tensor = int8_tensor.data
    assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
    # uint8 payloads are serialized as raw bytes; int32 stays a plain array.
    if tensor.dtype == np.dtype("uint8"):
        values = tensor.tobytes()
    else:
        values = tensor

    return core.CreateOperator(
        kTypeNameMapper[tensor.dtype],
        [],
        [name],
        values=values,
        shape=tensor.shape,
        Y_scale=int8_tensor.scale,
        Y_zero_point=int8_tensor.zero_point,
    )
|
|
|
|
def create_const_fill_op(
    name: str,
    blob: Union[np.ndarray, workspace.Int8Tensor],
    device_option: Optional[caffe2_pb2.DeviceOption] = None,
) -> caffe2_pb2.OperatorDef:
    """
    Given a blob object, return the Caffe2 operator that creates this blob
    as constant. Currently support NumPy tensor and Caffe2 Int8Tensor.
    """
    tensor_type = type(blob)
    assert tensor_type in [
        np.ndarray,
        workspace.Int8Tensor,
    ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
        name, type(blob)
    )

    if tensor_type == np.ndarray:
        return _create_const_fill_op_from_numpy(name, blob, device_option)
    # Must be an Int8Tensor here (the assert above rules out everything else).
    assert device_option is None
    return _create_const_fill_op_from_c2_int8_tensor(name, blob)
|
|
|
|
def construct_init_net_from_params(
    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
) -> caffe2_pb2.NetDef:
    """
    Construct the init_net from params dictionary
    """
    init_net = caffe2_pb2.NetDef()
    device_options = device_options or {}
    for name, blob in params.items():
        if isinstance(blob, str):
            # String-valued params can't be expressed as fill ops; skip them.
            logger.warning(
                "Blob {} with type {} is not supported in generating init net,"
                " skipped.".format(name, type(blob))
            )
            continue
        fill_op = create_const_fill_op(
            name, blob, device_option=device_options.get(name, None)
        )
        init_net.op.extend([fill_op])
        init_net.external_output.append(name)
    return init_net
|
|
|
|
def get_producer_map(ssa):
    """
    Return dict from versioned blob to (i, j),
    where i is index of producer op, j is the index of output of that op.
    """
    return {
        outp: (op_id, j)
        for op_id, (_, outputs) in enumerate(ssa)
        for j, outp in enumerate(outputs)
    }
|
|
|
|
def get_consumer_map(ssa):
    """
    Return dict from versioned blob to list of (i, j),
    where i is index of consumer op, j is the index of input of that op.
    """
    consumer_map = collections.defaultdict(list)
    for op_id, (inputs, _) in enumerate(ssa):
        for j, inp in enumerate(inputs):
            consumer_map[inp].append((op_id, j))
    return consumer_map
|
|
|
|
def get_params_from_init_net(
    init_net: caffe2_pb2.NetDef,
) -> Tuple[Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
    # BUGFIX: the return annotation was the list literal
    # `[Dict[...], Dict[...]]`, which is not a valid type; a two-element
    # return is a Tuple (Tuple is already imported at the top of the file).
    """
    Take the output blobs from init_net by running it.
    Outputs:
        params: dict from blob name to numpy array
        device_options: dict from blob name to the device option of its creating op
    """

    def _get_device_option(producer_op):
        # A CopyGPUToCPU op carries a GPU device option, but its output blob
        # actually lives on CPU — report the default (CPU) option for it.
        if producer_op.type == "CopyGPUToCPU":
            return caffe2_pb2.DeviceOption()
        else:
            return producer_op.device_option

    # Run the init net in a throwaway workspace so the current one is untouched.
    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
        ws.RunNetOnce(init_net)
        params = {b: fetch_any_blob(b) for b in init_net.external_output}
    ssa, versions = core.get_ssa(init_net)
    producer_map = get_producer_map(ssa)
    device_options = {
        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
        for b in init_net.external_output
    }
    return params, device_options
|
|
|
|
| def _updater_raise(op, input_types, output_types): |
| raise RuntimeError( |
| "Failed to apply updater for op {} given input_types {} and" |
| " output_types {}".format(op, input_types, output_types) |
| ) |
|
|
|
|
def _generic_status_identifier(
    predict_net: caffe2_pb2.NetDef,
    status_updater: Callable,
    known_status: Dict[Tuple[str, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """
    Statically infer the status of each blob, the status can be such as device type
    (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
    is versioned blob (Tuple[str, int]) in the format compatible with ssa.
    Inputs:
        predict_net: the caffe2 network
        status_updater: a callable, given an op and the status of its input/output,
            it returns the updated status of input/output. `None` is used for
            representing unknown status.
        known_status: a dict containing known status, used as initialization.
    Outputs:
        A dict mapping from versioned blob to its status
    """
    ssa, versions = core.get_ssa(predict_net)
    versioned_ext_input = [(b, 0) for b in predict_net.external_input]
    versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
    all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])

    # The seed statuses must refer to blobs that actually exist in the net.
    allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
    assert all(k in allowed_vbs for k in known_status)
    assert all(v is not None for v in known_status.values())
    _known_status = copy.deepcopy(known_status)

    def _check_and_update(key, value):
        assert value is not None
        if key in _known_status:
            if not _known_status[key] == value:
                # BUGFIX: error message previously read "Confilict".
                raise RuntimeError(
                    "Conflict status for {}, existing status {}, new status {}".format(
                        key, _known_status[key], value
                    )
                )
        _known_status[key] = value

    def _update_i(op, ssa_i):
        # Run the updater for one op and record any newly-determined statuses.
        versioned_inputs = ssa_i[0]
        versioned_outputs = ssa_i[1]

        inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
        outputs_status = [_known_status.get(b, None) for b in versioned_outputs]

        new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)

        for versioned_blob, status in zip(
            versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
        ):
            if status is not None:
                _check_and_update(versioned_blob, status)

    # One forward sweep then one backward sweep; for the supported cases this
    # is enough to propagate status to every blob.
    for op, ssa_i in zip(predict_net.op, ssa):
        _update_i(op, ssa_i)
    for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
        _update_i(op, ssa_i)

    # Anything still unknown after the two sweeps is unsupported.
    for k in all_versioned_blobs:
        if k not in _known_status:
            raise NotImplementedError(
                "Can not infer the status for {}. Currently only support the case where"
                " a single forward and backward pass can identify status for all blobs.".format(k)
            )

    return _known_status
|
|
|
|
def infer_device_type(
    predict_net: caffe2_pb2.NetDef,
    known_status: Dict[Tuple[str, int], Any],
    device_name_style: str = "caffe2",
) -> Dict[Tuple[str, int], str]:
    """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""
    assert device_name_style in ["caffe2", "pytorch"]
    _CPU_STR = "cpu"
    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"

    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
        # CopyCPUToGPU must consume a CPU blob and produce a GPU blob.
        if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_CPU_STR], [_GPU_STR])

    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
        # CopyGPUToCPU must consume a GPU blob and produce a CPU blob.
        if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_GPU_STR], [_CPU_STR])

    def _other_ops_updater(op, input_types, output_types):
        # Non-copy ops keep all their blobs on one device: propagate the
        # first known type everywhere and reject disagreements.
        non_none_types = [x for x in input_types + output_types if x is not None]
        the_type = non_none_types[0] if non_none_types else None
        if the_type is not None and not all(x == the_type for x in non_none_types):
            _updater_raise(op, input_types, output_types)
        return ([the_type] * len(op.input), [the_type] * len(op.output))

    _SPECIAL_UPDATERS = {
        "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
        "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
    }

    def _device_updater(op, *args, **kwargs):
        updater = _SPECIAL_UPDATERS.get(op.type, _other_ops_updater)
        return updater(op, *args, **kwargs)

    return _generic_status_identifier(predict_net, _device_updater, known_status)
|
|
|
|
| |
|
|
|
|
| def _modify_blob_names(ops, blob_rename_f): |
| ret = [] |
|
|
| def _replace_list(blob_list, replaced_list): |
| del blob_list[:] |
| blob_list.extend(replaced_list) |
|
|
| for x in ops: |
| cur = copy.deepcopy(x) |
| _replace_list(cur.input, list(map(blob_rename_f, cur.input))) |
| _replace_list(cur.output, list(map(blob_rename_f, cur.output))) |
| ret.append(cur) |
|
|
| return ret |
|
|
|
|
| def _rename_blob(name, blob_sizes, blob_ranges): |
| def _list_to_str(bsize): |
| ret = ", ".join([str(x) for x in bsize]) |
| ret = "[" + ret + "]" |
| return ret |
|
|
| ret = name |
| if blob_sizes is not None and name in blob_sizes: |
| ret += "\n" + _list_to_str(blob_sizes[name]) |
| if blob_ranges is not None and name in blob_ranges: |
| ret += "\n" + _list_to_str(blob_ranges[name]) |
|
|
| return ret |
|
|
|
|
| |
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
    """Draw ``net`` with blob names annotated by sizes/ranges and write it to ``file_name``."""

    def _blob_rename(name):
        return _rename_blob(name, blob_sizes=blob_sizes, blob_ranges=blob_ranges)

    return save_graph_base(net, file_name, graph_name, op_only, _blob_rename)
|
|
|
|
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
    """
    Draw a caffe2 NetDef with net_drawer and write the image to ``file_name``.

    Args:
        net: caffe2_pb2.NetDef to draw.
        file_name: output path; the extension selects the format (.png/.pdf/.svg).
        graph_name: title of the pydot graph.
        op_only: if True, draw the minimal operators-only graph.
        blob_rename_func: optional callable used to rewrite blob names
            (e.g. annotate them with shapes) before drawing.

    Returns:
        The pydot graph object (returned even if writing the image failed).
    """
    ops = net.op
    if blob_rename_func is not None:
        ops = _modify_blob_names(ops, blob_rename_func)
    if not op_only:
        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
    else:
        graph = net_drawer.GetPydotGraphMinimal(
            ops, graph_name, rankdir="TB", minimal_dependency=True
        )

    try:
        par_dir = os.path.dirname(file_name)
        # BUGFIX: dirname is "" for cwd-relative paths — the old
        # `exists()+makedirs()` pair then raised and the image was never
        # written; exist_ok also removes the check-then-create race.
        if par_dir:
            os.makedirs(par_dir, exist_ok=True)

        # Renamed from `format`, which shadowed the builtin.
        ext = os.path.splitext(os.path.basename(file_name))[-1]
        if ext == ".png":
            graph.write_png(file_name)
        elif ext == ".pdf":
            graph.write_pdf(file_name)
        elif ext == ".svg":
            graph.write_svg(file_name)
        else:
            print("Incorrect format {}".format(ext))
    except Exception as e:
        # Best-effort: drawing is diagnostic only, never fail the caller.
        print("Error when writing graph to image {}".format(e))

    return graph
|
|
|
|
| |
|
|
|
|
def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For ONNX exported model, GroupNorm will be represented as ATen op,
    this can be a drop in replacement from ATen to GroupNorm
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            # The ATen op stores the aten operator name as a bytes argument.
            op_name = get_pb_arg_vals(op, "operator", None)
            if op_name and op_name.decode() == "group_norm":
                op.arg.remove(get_pb_arg(op, "operator"))

                # cudnn_enabled has no GroupNorm equivalent; drop it if set.
                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                # GroupNorm expects the group count under the name "group".
                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    # BUGFIX: the condition was `count > 1`, which silently skipped the log
    # when exactly one op was replaced.
    if count > 0:
        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
|
|
|
|
| |
|
|
|
|
def alias(x, name, is_backward=False):
    """
    Tag tensor ``x`` with ``name`` via AliasWithName during ONNX export so
    the exported graph exposes a stable blob name; a no-op outside export.
    """
    if torch.onnx.is_in_onnx_export():
        assert isinstance(x, torch.Tensor)
        return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
    return x
|
|
|
|
def fuse_alias_placeholder(predict_net, init_net):
    """Remove AliasWithName placeholder and rename the input/output of it"""
    # Pass 1: for each AliasWithName op, rename its input and output to the
    # alias name carried in the op's "name" argument. For backward aliases the
    # rename is routed through the producer so all consumers stay consistent.
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Pass 2: after the renames every AliasWithName is an identity
    # (input == output == alias name), so it can be dropped entirely.
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # Sanity-check the op really became a no-op before removing it.
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
|
|
|
|
| |
|
|
|
|
class IllegalGraphTransformError(ValueError):
    """Raised when a graph transform function call can't be executed."""
|
|
|
|
def _rename_versioned_blob_in_proto(
    proto: caffe2_pb2.NetDef,
    old_name: str,
    new_name: str,
    version: int,
    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
    start_versions: Dict[str, int],
    end_versions: Dict[str, int],
):
    """In given proto, rename all blobs with matched version"""
    old_versioned = (old_name, version)

    # Rewrite operator inputs/outputs whose ssa version matches.
    for op, (versioned_inputs, versioned_outputs) in zip(proto.op, ssa):
        for i, vin in enumerate(versioned_inputs):
            if vin == old_versioned:
                op.input[i] = new_name
        for i, vout in enumerate(versioned_outputs):
            if vout == old_versioned:
                op.output[i] = new_name

    # external_input refers to the blob's initial version.
    if start_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_input):
            if blob == old_name:
                proto.external_input[i] = new_name

    # external_output refers to the blob's final version.
    if end_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_output):
            if blob == old_name:
                proto.external_output[i] = new_name
|
|
|
|
def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, change it's input_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_input and init_net if necessary.
    - It requires the input is only consumed by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enable, this also updates other operators that consumes
      the same input. Be cautious because may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    # Version predict_net's blobs continuing from init_net's versions so the
    # two nets agree on versioned blob identities.
    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        # Delegate to rename_op_output on the producer of this blob, which
        # re-routes every consumer of the same versioned blob at once.
        producer_map = get_producer_map(predict_net_ssa)
        if not (old_name, version) in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        # True when the op consumes the (old_name, version) blob.
        return (old_name, version) in op_ssa[0]

    # Renaming the input of one op while others consume the same blob would
    # silently re-route them too; refuse and point the caller at the producer.
    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) are consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # Apply the rename in init_net (in case the blob originates there)...
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )
    # ...and in predict_net (including external_input/external_output).
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )
|
|
|
|
def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the op_id-th operator in predict_net, change it's output_id-th
    output's name to the new_name. It also does automatic re-route and
    updates external_output if necessary.
    - It allows multiple consumers of its output.
    - This function modifies predict_net in-place, doesn't need init_net.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)
    _, versioned_outputs = ssa[op_id]
    old_name, version = versioned_outputs[output_id]
    # Re-route every occurrence of (old_name, version), including in
    # external_output when it refers to this final version.
    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )
|
|
|
|
def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the list of external input/output of sub-graph,
    each element is tuple of the name and corresponding version in predict_net.

    external input/output is defined the same way as caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    # Collect the sub-graph's inputs (order-preserving, de-duplicated) and
    # outputs. PERF FIX: the previous `inp not in all_inputs` list scan made
    # this quadratic; a mirror set keeps lookups O(1) without changing order.
    all_inputs = []
    seen_inputs = set()
    all_outputs = []
    for op_id in sub_graph_op_indices:
        for inp in ssa[op_id][0]:
            if inp not in seen_inputs:
                seen_inputs.add(inp)
                all_inputs.append(inp)
        all_outputs += list(ssa[op_id][1])

    # External inputs: consumed inside the sub-graph but not produced by it.
    all_outputs_set = set(all_outputs)
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs_set]

    # External outputs: produced inside the sub-graph and consumed outside of
    # it, or listed as net-level external_output.
    sub_graph_index_set = set(sub_graph_op_indices)
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_index_set),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    # PERF FIX: `set(all_other_inputs)` used to be rebuilt for every element
    # of the comprehension; build it once.
    other_inputs_set = set(all_other_inputs)
    ext_outputs = [outp for outp in all_outputs if outp in other_inputs_set]

    return ext_inputs, ext_outputs
|
|
|
|
class DiGraph:
    """A DAG representation of caffe2 graph, each vertice is a versioned blob."""

    def __init__(self):
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        """Add the directed edge u -> v, registering both endpoints as vertices."""
        self.graph[u].append(v)
        self.vertices.update((u, v))

    def get_all_paths(self, s, d):
        """Return every simple path from s to d, found by depth-first search."""
        visited = dict.fromkeys(self.vertices, False)
        path = []
        all_paths = []

        def _dfs(u):
            visited[u] = True
            path.append(u)
            if u == d:
                all_paths.append(copy.deepcopy(path))
            else:
                for nxt in self.graph[u]:
                    if not visited[nxt]:
                        _dfs(nxt)
            # Backtrack so other branches can revisit this vertex.
            path.pop()
            visited[u] = False

        _dfs(s)
        return all_paths

    @staticmethod
    def from_ssa(ssa):
        """Build a DiGraph connecting every input to every output of each op."""
        graph = DiGraph()
        for versioned_inputs, versioned_outputs in ssa:
            for vin in versioned_inputs:
                for vout in versioned_outputs:
                    graph.add_edge(vin, vout)
        return graph
|
|
|
|
def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of relevant operator to produce target blob from source blob,
    if there's no dependency, return empty list.
    """

    # Finding all paths can be O(N!), so restrict the search to a window of
    # ops around [first consumer of source, producer of target] with a
    # 15-op margin on each side.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    # BUGFIX: the margin could drive start_op negative, making the slice
    # below wrap around from the end of ssa and search the wrong subgraph;
    # clamp it at 0.
    start_op = max(min(x[0] for x in consumer_map[versioned_source]) - 15, 0)
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        logger.warning(
            "Subgraph between {} and {} is large (from op#{} to op#{}), it"
            " might take non-trival time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)
    # Map every blob on every path back to the (global) index of its producer.
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
|
|
|
|
def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Identify the reshape sub-graphs in a protobuf.
    A reshape sub-graph matches the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
        └-------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, each sub-graph is represented as a list of indices
        of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """
    ssa, _ = core.get_ssa(predict_net)

    sub_graphs = []
    for op_id, op in enumerate(predict_net.op):
        if op.type != "Reshape":
            continue
        assert len(op.input) == 2
        # Reshape inputs: [data, target shape].
        data_source, shape_source = ssa[op_id][0]
        shape_producers = _get_dependency_chain(ssa, shape_source, data_source)
        sub_graphs.append(shape_producers + [op_id])
    return sub_graphs
|
|
|
|
def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape
    a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and cause using extra
    ops (eg. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
    after exporting ONNX model.
    """
    from caffe2.python import core

    # Phase 1: find removable reshape sub-graphs — those whose Reshape output
    # is consumed exclusively by FC ops (which accept 4D input in caffe2).
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # Safety check: the sub-graph must be isolated — exactly one
            # non-parameter external input and one external output.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # Phase 2: neutralize each sub-graph by renaming the Reshape's output back
    # to its input (turning it into an in-place identity), then collect the
    # op ids and parameter blobs (e.g. the target shape) to drop.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshap_output = predict_net.op[reshape_op_id].input[0]
        rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        # After the rename the sub-graph's external output must be version
        # input+1 of the same blob (the identity created above).
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    # Phase 3: work on a copy for the final surgery and return it along with
    # the pruned params dict. NOTE(review): the renames in phase 2 already
    # mutated the caller's predict_net in-place before this deepcopy.
    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params
|
|
|
|
def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
                      -CopyBToA> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """
    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        # Attempt a single fusion; returns True if the net was modified
        # (the ssa/consumer maps are then stale and must be recomputed).
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                # The opposite-direction copy op (CPU->GPU vs GPU->CPU).
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                # Fusable only if the copy's output feeds solely
                # reverse-copies and no blob involved is an external output.
                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    # Re-route each reverse-copy's consumers straight to the
                    # forward copy's input, then drop both copy layers.
                    for rv_copy_op_id in consumer_ids:
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # Iterate to a fixed point: each fusion may enable further fusions.
    while _fuse_once(predict_net):
        pass
|
|
|
|
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """remove ops if its output is not used or not in external_output"""
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        # A blob is live iff it's an external output, or it has consumers
        # none of which have been marked for removal.
        if versioned_blob in versioned_external_output:
            return False
        consumers = consumer_map[versioned_blob]
        all_consumers_alive = len(consumers) > 0 and all(
            c[0] not in removed_op_ids for c in consumers
        )
        return not all_consumers_alive

    # Walk ops back-to-front so removals cascade to their upstream producers.
    for op_id, (_, versioned_outputs) in reversed(list(enumerate(ssa))):
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(op_id)

    surviving_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(surviving_ops)
|
|