| // Copyright (c) Microsoft Corporation. All rights reserved. | |
| // Licensed under the MIT License. | |
| namespace onnxruntime { | |
| // Node argument definition, for both input and output, | |
| // including arg name, arg type (contains both type and shape). | |
| // | |
| // Design Question: in my opinion, shape should not be part of type. | |
| // We may align the protobuf design with our operator registry interface, | |
| // which has type specified for each operator, but no shape. Well, shape | |
| // should be inferred with a separate shape inference function given | |
| // input shapes, or input tensor data sometimes. | |
| // With shape as part of type (current protobuf design), | |
| // 1) we'll have to split the "TypeProto" into type and shape in this internal | |
| // representation interface so that it could be easily used when doing type | |
| // inference and matching with operator registry. | |
| // 2) SetType should always be called before SetShape; otherwise SetShape() | |
| // will fail, because the shape is stored inside the TypeProto. | |
| // Thoughts? | |
| // | |
| /** | |
| @class NodeArg | |
| Class representing a data type that is input or output for a Node, including the shape if it is a Tensor. | |
| */ | |
| class NodeArg { | |
| public: | |
| /** | |
| Construct a new NodeArg. | |
| @param name The name to use. | |
| @param p_arg_type Optional TypeProto specifying type and shape information. | |
| NOTE(review): passed by pointer; presumably nullptr means that no type info is provided - confirm against the definition. | |
| */ | |
| NodeArg(const std::string& name, | |
| const ONNX_NAMESPACE::TypeProto* p_arg_type); | |
| /** NodeArg is movable: the move constructor and move assignment operator are compiler-generated. */ NodeArg(NodeArg&&) = default; | |
| NodeArg& operator=(NodeArg&& other) = default; | |
| /** Gets the name of this NodeArg. */ | |
| const std::string& Name() const noexcept; | |
| /** Gets the data type as a pointer to a string. | |
| NOTE(review): presumably nullptr when no type has been set - confirm against the definition. */ | |
| const std::string* Type() const noexcept; | |
| /** Gets the TypeProto | |
| @returns TypeProto if type is set. nullptr otherwise. */ | |
| const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept; | |
| /** Gets the shape if NodeArg is for a Tensor. | |
| @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */ | |
| const ONNX_NAMESPACE::TensorShapeProto* Shape() const; | |
| /** Indicates whether this NodeArg carries tensor or scalar shape information. | |
| @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */ | |
| bool HasTensorOrScalarShape() const; | |
| /** Sets the shape. | |
| @param shape The TensorShapeProto to set as the shape. | |
| @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called, | |
| as the shape information is stored as part of TypeProto. | |
| Note that #SetType is declared in the private section of this class. */ | |
| void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape); | |
| /** Clears shape info. | |
| @remarks If there is a mismatch during shape inferencing that can't be resolved the shape info may be removed. */ | |
| void ClearShape(); | |
| /** Override current type from input_type if override_types is set to true, return failure status otherwise. | |
| @param input_type TypeProto that the current type is overridden from | |
| @param input_tensor_elem_type Tensor element type parsed from input_type | |
| @param current_tensor_elem_type Tensor element type parsed from existing type | |
| @param override_types If true, resolve the two inputs or two outputs type when different | |
| @returns Success unless there is existing type or shape info that can't be successfully updated. */ | |
| common::Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type, | |
| int32_t input_tensor_elem_type, | |
| int32_t current_tensor_elem_type, | |
| bool override_types); | |
| /** Validate and merge type [and shape] info from input_type. | |
| @param input_type TypeProto providing the type [and shape] info to validate and merge. | |
| @param strict If true, the shape update will fail if there are incompatible values. | |
| If false, will be lenient and merge only shape info that can be validly processed. | |
| @param override_types If true, resolve the two inputs or two outputs type when different | |
| @param logger Logger handed to the implementation; how it is used is not visible in this declaration. | |
| @returns Success unless there is existing type or shape info that can't be successfully updated. */ | |
| common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger); | |
| /** Validate and merge type [and shape] info from node_arg. | |
| @param node_arg NodeArg providing the type [and shape] info to validate and merge. | |
| @param strict If true, the shape update will fail if there are incompatible values. | |
| If false, will be lenient and merge only shape info that can be validly processed. | |
| @param override_types If true, resolve the two inputs or two outputs type when different | |
| @param logger Logger handed to the implementation; how it is used is not visible in this declaration. | |
| @returns Success unless there is existing type or shape info that can't be successfully updated. */ | |
| common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger); | |
| /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. | |
| @returns Const reference to the internally held node_arg_info_ member (no copy is made). */ | |
| const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } | |
| /** Gets a flag indicating whether this NodeArg exists or not. | |
| Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ | |
| bool Exists() const noexcept; | |
| /** Graph is a friend and may therefore access the private members of NodeArg. */ friend class Graph; | |
| /** Constructs a NodeArg from a NodeArgInfo taken by rvalue reference. NOTE(review): declared before the private: label, so it is public in this view, and it is not marked explicit - confirm both are intended. */ NodeArg(NodeArgInfo&& node_arg_info); | |
| private: | |
| /** NOTE(review): macro defined outside this view; judging by its name it disables copy construction and copy assignment - confirm. */ ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg); | |
| /** Sets the type from a pointer to a type string. NOTE(review): presumably updates type_ - confirm against the definition. */ void SetType(const std::string* p_type); | |
| /** Sets the type from a TypeProto. */ void SetType(const ONNX_NAMESPACE::TypeProto& type_proto); | |
| // Node arg PType: pointer to the data type string (NOTE(review): presumably the value returned by Type() - confirm). | |
| const std::string* type_; | |
| // Node arg name, type and shape, held as a NodeArgInfo and exposed read-only through ToProto(). | |
| NodeArgInfo node_arg_info_; | |
| // Flag indicating whether <*this> node arg exists or not (NOTE(review): presumably backs Exists() - confirm). | |
| bool exists_; | |
| }; | |
| } // namespace onnxruntime | |