import collections
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import torch
from torch import nn

from detectron2.structures import Boxes, Instances, ROIMasks
from detectron2.utils.registry import _convert_target_to_string, locate

from .torchscript_patch import patch_builtin_len


@dataclass
class Schema:
    """
    A Schema defines how to flatten a possibly hierarchical object into a tuple of
    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.

    PyTorch does not support tracing a function that produces rich output
    structures (e.g. dict, Instances, Boxes). To trace such a function, we
    flatten the rich object into a tuple of tensors, and return this tuple of
    tensors instead. Meanwhile, we also need to know how to "rebuild" the original
    object from the flattened results, so the flattened results can be evaluated
    the same way as the original object.
    A Schema defines how to flatten an object and, while flattening it, records the
    schemas necessary to rebuild the object from the flattened outputs.

    The flattened object and the schema are returned by the ``.flatten`` classmethod.
    The original object can then be rebuilt with the schema's ``__call__`` method.

    A Schema is a dataclass that can be serialized easily.
    """

    @classmethod
    def flatten(cls, obj):
        raise NotImplementedError

    def __call__(self, values):
        raise NotImplementedError

    @staticmethod
    def _concat(values):
        ret = ()
        sizes = []
        for v in values:
            assert isinstance(v, tuple), "Flattened results must be a tuple"
            ret = ret + v
            sizes.append(len(v))
        return ret, sizes

    @staticmethod
    def _split(values, sizes):
        if len(sizes):
            expected_len = sum(sizes)
            assert (
                len(values) == expected_len
            ), f"Values has length {len(values)} but expect length {expected_len}."
        ret = []
        for k in range(len(sizes)):
            begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
            ret.append(values[begin:end])
        return ret
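
# Note on the helpers above: ``_concat`` and ``_split`` are exact inverses.
# For hypothetical tensors a, b, c: ``Schema._concat([(a,), (b, c)])`` returns
# ``((a, b, c), [1, 2])``, and ``Schema._split((a, b, c), [1, 2])`` recovers
# ``[(a,), (b, c)]``.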


@dataclass
class ListSchema(Schema):
    schemas: List[Schema]  # how to flatten/rebuild each element in the list
    sizes: List[int]  # flattened length of each element

    def __call__(self, values):
        values = self._split(values, self.sizes)
        if len(values) != len(self.schemas):
            raise ValueError(
                f"Values has length {len(values)} but schemas "
                f"has length {len(self.schemas)}!"
            )
        values = [m(v) for m, v in zip(self.schemas, values)]
        return list(values)

    @classmethod
    def flatten(cls, obj):
        res = [flatten_to_tuple(k) for k in obj]
        values, sizes = cls._concat([k[0] for k in res])
        return values, cls([k[1] for k in res], sizes)
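
# For example (hypothetical tensors a, b): flattening the nested list
# [[a], [a, b]] yields the values (a, a, b) and a ListSchema with
# sizes == [1, 2]; calling that schema on (a, a, b) rebuilds [[a], [a, b]].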


@dataclass
class TupleSchema(ListSchema):
    def __call__(self, values):
        return tuple(super().__call__(values))


@dataclass
class IdentitySchema(Schema):
    def __call__(self, values):
        return values[0]

    @classmethod
    def flatten(cls, obj):
        return (obj,), cls()


@dataclass
class DictSchema(ListSchema):
    keys: List[str]

    def __call__(self, values):
        values = super().__call__(values)
        return dict(zip(self.keys, values))

    @classmethod
    def flatten(cls, obj):
        for k in obj.keys():
            if not isinstance(k, str):
                raise KeyError("Only support flattening dictionaries if keys are str.")
        keys = sorted(obj.keys())
        values = [obj[k] for k in keys]
        ret, schema = ListSchema.flatten(values)
        return ret, cls(schema.schemas, schema.sizes, keys)
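
# Keys are sorted during ``flatten``, so the flattened order is deterministic
# and independent of the dict's insertion order; the stored ``keys`` preserve
# the order needed to rebuild the dict.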


@dataclass
class InstancesSchema(DictSchema):
    def __call__(self, values):
        # ``flatten`` appends the image_size tensor last
        image_size, fields = values[-1], values[:-1]
        fields = super().__call__(fields)
        return Instances(image_size, **fields)

    @classmethod
    def flatten(cls, obj):
        ret, schema = super().flatten(obj.get_fields())
        size = obj.image_size
        if not isinstance(size, torch.Tensor):
            size = torch.tensor(size)
        return ret + (size,), schema


@dataclass
class TensorWrapSchema(Schema):
    """
    For classes that are simple wrappers of a single tensor, e.g.
    Boxes, RotatedBoxes, BitMasks
    """

    class_name: str

    def __call__(self, values):
        return locate(self.class_name)(values[0])

    @classmethod
    def flatten(cls, obj):
        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
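
# E.g. flattening ``Boxes(t)`` (for a hypothetical Nx4 tensor t) yields
# ``(t,)`` and a TensorWrapSchema recording the qualified class name
# (e.g. "detectron2.structures.Boxes"), which ``locate`` resolves back to the
# class so that ``Boxes(t)`` can be rebuilt.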


def flatten_to_tuple(obj):
    """
    Flatten an object so it can be used for PyTorch tracing.
    Also returns how to rebuild the original object from the flattened outputs.

    Returns:
        res (tuple): the flattened results that can be used as tracing outputs
        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
            It is a pure dataclass that can be serialized.
    """
    schemas = [
        ((str, bytes), IdentitySchema),
        (list, ListSchema),
        (tuple, TupleSchema),
        (collections.abc.Mapping, DictSchema),
        (Instances, InstancesSchema),
        ((Boxes, ROIMasks), TensorWrapSchema),
    ]
    # the first matching (class, schema) pair wins; objects that match none of
    # the classes are passed through unchanged by IdentitySchema
    for klass, schema in schemas:
        if isinstance(obj, klass):
            F = schema
            break
    else:
        F = IdentitySchema

    return F.flatten(obj)
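
# A minimal usage sketch (``t`` is a hypothetical Nx4 tensor):
#
#   res, schema = flatten_to_tuple({"scores": t, "boxes": Boxes(t)})
#   # res is a tuple of tensors; schema is a serializable DictSchema
#   rebuilt = schema(res)  # same structure: {"boxes": Boxes(t), "scores": t}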


class TracingAdapter(nn.Module):
    """
    A model may take rich input/output formats (e.g. dict or custom classes),
    but ``torch.jit.trace`` requires tuples of tensors as input/output.
    This adapter flattens the input/output format of a model so it becomes traceable.

    It also records the necessary schema to rebuild the model's inputs/outputs
    from the flattened inputs/outputs.

    Example:
    ::
        outputs = model(inputs)  # inputs/outputs may be rich structure
        adapter = TracingAdapter(model, inputs)

        # can now trace the model, with adapter.flattened_inputs, or another
        # tuple of tensors with the same length and meaning
        traced = torch.jit.trace(adapter, adapter.flattened_inputs)

        # traced model can only produce flattened outputs (tuple of tensors)
        flattened_outputs = traced(*adapter.flattened_inputs)
        # adapter knows the schema to convert it back (new_outputs == outputs)
        new_outputs = adapter.outputs_schema(flattened_outputs)
    """

    flattened_inputs: Tuple[torch.Tensor, ...] = None
    """
    Flattened version of inputs given to this class's constructor.
    """

    inputs_schema: Schema = None
    """
    Schema of the inputs given to this class's constructor.
    """

    outputs_schema: Schema = None
    """
    Schema of the output produced by calling the given model with inputs.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs,
        inference_func: Optional[Callable] = None,
        allow_non_tensor: bool = False,
    ):
        """
        Args:
            model: an nn.Module
            inputs: An input argument or a tuple of input arguments used to call model.
                After flattening, it has to only consist of tensors.
            inference_func: a callable that takes (model, *inputs), calls the
                model with inputs, and returns outputs. By default it
                is ``lambda model, *inputs: model(*inputs)``. Can be overridden
                if you need to call the model differently.
            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
                This option will filter out non-tensor objects to make the
                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
                This is useful when you're only interested in the single trace of
                execution (e.g. for flop count), but not interested in
                generalizing the traced graph to new inputs.
        """
        super().__init__()
        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
            # unwrap DDP/DataParallel so the underlying module is traced
            model = model.module
        self.model = model
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        self.inputs = inputs
        self.allow_non_tensor = allow_non_tensor

        if inference_func is None:
            inference_func = lambda model, *inputs: model(*inputs)
        self.inference_func = inference_func

        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)

        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
            return
        if self.allow_non_tensor:
            # keep only the tensors; the schema can no longer rebuild the inputs
            self.flattened_inputs = tuple(
                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
            )
            self.inputs_schema = None
        else:
            for input in self.flattened_inputs:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Inputs for tracing must only contain tensors. "
                        f"Got a {type(input)} instead."
                    )
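
    # A hedged example of a custom ``inference_func`` (assuming a hypothetical
    # model exposing a ``model.inference`` entry point):
    #
    #   adapter = TracingAdapter(
    #       model, inputs, inference_func=lambda m, *args: m.inference(*args)
    #   )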

    def forward(self, *args: torch.Tensor):
        with torch.no_grad(), patch_builtin_len():
            if self.inputs_schema is not None:
                inputs_orig_format = self.inputs_schema(args)
            else:
                if len(args) != len(self.flattened_inputs) or any(
                    x is not y for x, y in zip(args, self.flattened_inputs)
                ):
                    raise ValueError(
                        "TracingAdapter does not contain valid inputs_schema."
                        " So it cannot generalize to other inputs and must be"
                        " traced with `.flattened_inputs`."
                    )
                inputs_orig_format = self.inputs

            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

            flattened_output_tensors = tuple(
                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
            )
            if len(flattened_output_tensors) < len(flattened_outputs):
                if self.allow_non_tensor:
                    flattened_outputs = flattened_output_tensors
                    self.outputs_schema = None
                else:
                    raise ValueError(
                        "Model cannot be traced because some model outputs "
                        "cannot flatten to tensors."
                    )
            else:
                if self.outputs_schema is None:
                    self.outputs_schema = schema
                else:
                    assert self.outputs_schema == schema, (
                        "Model should always return outputs with the same "
                        "structure so it can be traced!"
                    )
            return flattened_outputs
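
    # Note: ``outputs_schema`` is only populated after ``forward`` has run at
    # least once (e.g. during tracing); until then it remains None.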

    def _create_wrapper(self, traced_model):
        """
        Return a function that has an input/output interface the same as the
        original model, but it calls the given traced model under the hood.
        """

        def forward(*args):
            flattened_inputs, _ = flatten_to_tuple(args)
            flattened_outputs = traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

        return forward
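
# A minimal sketch of how the pieces fit together (``model``/``inputs`` as in
# the class docstring above):
#
#   adapter = TracingAdapter(model, inputs)
#   traced = torch.jit.trace(adapter, adapter.flattened_inputs)
#   wrapper = adapter._create_wrapper(traced)
#   outputs = wrapper(*adapter.inputs)  # rich outputs, rebuilt via outputs_schema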