// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "onnx/onnx_pb.h"
#include "core/graph/basic_types.h"
#include "core/common/status.h"
#include "core/common/logging/logging.h"
namespace onnxruntime {
// Node argument definition, for both input and output,
// including arg name, arg type (contains both type and shape).
//
// Design Question: in my opinion, shape should not be part of type.
// We may align the protobuf design with our operator registry interface,
// which specifies a type for each operator but no shape. Shape should
// instead be inferred by a separate shape inference function from the
// input shapes (or, sometimes, from input tensor data).
// With shape as part of type (current protobuf design),
// 1) we'll have to split the "TypeProto" into type and shape in this internal
// representation interface so that it could be easily used when doing type
// inference and matching with operator registry.
// 2) SetType should always be called before SetShape; otherwise SetShape()
// will fail, because the shape is stored inside a TypeProto.
// Thoughts?
//
/**
@class NodeArg
Class representing a data type that is input or output for a Node, including the shape if it is a Tensor.
*/
class NodeArg {
 public:
  /**
  Construct a new NodeArg.
  @param name The name to use.
  @param p_arg_type Optional TypeProto specifying type and shape information.
  */
  NodeArg(const std::string& name,
          const ONNX_NAMESPACE::TypeProto* p_arg_type);

  NodeArg(NodeArg&&) = default;
  NodeArg& operator=(NodeArg&& other) = default;

  /** Gets the name. */
  const std::string& Name() const noexcept;

  /** Gets the data type, as a pointer to its string representation. */
  const std::string* Type() const noexcept;

  /** Gets the TypeProto.
  @returns TypeProto if type is set. nullptr otherwise. */
  const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept;

  /** Gets the shape if NodeArg is for a Tensor.
  @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
  const ONNX_NAMESPACE::TensorShapeProto* Shape() const;

  /** Return an indicator.
  @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
  bool HasTensorOrScalarShape() const;

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  /** Sets the shape.
  @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
  as the shape information is stored as part of TypeProto. */
  void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape);

  /** Clears shape info.
  @remarks If there is a mismatch during shape inferencing that can't be resolved the shape info may be removed. */
  void ClearShape();
#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD)
  /** Override the current type with the type from input_type if override_types is true;
  otherwise return a failure status.
  @param input_type TypeProto supplying the new type.
  @param input_tensor_elem_type Tensor element type parsed from input_type.
  @param current_tensor_elem_type Tensor element type parsed from the existing type.
  @param override_types If true, resolve the two inputs or two outputs type when different.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type,
                                     int32_t input_tensor_elem_type,
                                     int32_t current_tensor_elem_type,
                                     bool override_types);

  /** Validate and merge type [and shape] info from input_type.
  @param input_type TypeProto whose type [and shape] info is merged into this NodeArg.
  @param strict If true, the shape update will fail if there are incompatible values.
  If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the two inputs or two outputs type when different.
  @param logger Logger made available to the update for diagnostic output.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);

  /** Validate and merge type [and shape] info from node_arg.
  @param node_arg NodeArg whose type [and shape] info is merged into this NodeArg.
  @param strict If true, the shape update will fail if there are incompatible values.
  If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the two inputs or two outputs type when different.
  @param logger Logger made available to the update for diagnostic output.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
#endif  // !defined(ORT_MINIMAL_BUILD)

  /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
  const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

  /** Gets a flag indicating whether this NodeArg exists or not.
  Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
  bool Exists() const noexcept;

  friend class Graph;

  /** Construct a NodeArg from a NodeArgInfo (ValueInfoProto) rvalue.
  NOTE(review): this constructor is public and non-explicit even though Graph is declared a friend;
  consider making it `explicit` (or private) once all call sites are confirmed to use direct-initialization. */
  NodeArg(NodeArgInfo&& node_arg_info);

 private:
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg);

  void SetType(const std::string* p_type);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  void SetType(const ONNX_NAMESPACE::TypeProto& type_proto);
#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

  // Node arg data type (see Type()). Default-initialized to nullptr so it is never
  // read as an indeterminate pointer if a constructor path does not set it.
  const std::string* type_ = nullptr;

  // Node arg name, type and shape.
  NodeArgInfo node_arg_info_;

  // Flag indicating whether <*this> node arg exists (see Exists()).
  // Default-initialized to false so it always has a well-defined value.
  bool exists_ = false;
};
} // namespace onnxruntime
|