File size: 5,771 Bytes
c206440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "onnx/onnx_pb.h"

#include "core/graph/basic_types.h"
#include "core/common/status.h"
#include "core/common/logging/logging.h"

namespace onnxruntime {

// Node argument definition, for both inputs and outputs,
// including the arg name and the arg type (which contains both type and shape).
//
// Design Question: in my opinion, shape should not be part of type.
// We may align the protobuf design with our operator registry interface,
// which has a type specified for each operator, but no shape. Instead, shape
// should be inferred by a separate shape inference function given the
// input shapes, or sometimes the input tensor data.
// With shape as part of type (the current protobuf design),
// 1) we'll have to split the "TypeProto" into type and shape in this internal
// representation interface so that it can be easily used when doing type
// inference and matching with the operator registry.
// 2) SetType should always be called before SetShape; otherwise SetShape()
// will fail, because the shape is stored inside a TypeProto.
// Thoughts?
//

/**
@class NodeArg
Class representing a data type that is input or output for a Node, including the shape if it is a Tensor.

The name, type and shape are held in a NodeArgInfo (AKA ValueInfoProto), which can be read via #ToProto.
Instances can be moved (move construction/assignment are defaulted). Copying is expected to be disabled by
ORT_DISALLOW_COPY_AND_ASSIGNMENT; the macro is defined outside this header, so confirm its exact effect there.
*/
class NodeArg {
 public:
  /**
  Construct a new NodeArg.
  @param name The name to use.
  @param p_arg_type Optional TypeProto specifying type and shape information.
                    NOTE(review): "optional" presumably means nullptr is accepted - confirm in the implementation.
  */
  NodeArg(const std::string& name,
          const ONNX_NAMESPACE::TypeProto* p_arg_type);

  /** Move construction and move assignment use the compiler-generated member-wise implementations. */
  NodeArg(NodeArg&&) = default;
  NodeArg& operator=(NodeArg&& other) = default;

  /** Gets the name. */
  const std::string& Name() const noexcept;

  /** Gets the data type.
  @returns Pointer to a string describing the type.
           NOTE(review): presumably nullptr when no type has been set - confirm in the implementation. */
  const std::string* Type() const noexcept;

  /** Gets the TypeProto
  @returns TypeProto if type is set. nullptr otherwise. */
  const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept;

  /** Gets the shape if NodeArg is for a Tensor.
  @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
  const ONNX_NAMESPACE::TensorShapeProto* Shape() const;

  /** Return an indicator.
  @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
  bool HasTensorOrScalarShape() const;

// The shape mutators below are compiled in full builds and in extended minimal builds,
// but are left out of a plain minimal build.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

  /** Sets the shape.
  @param shape The shape to store.
  @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
  as the shape information is stored as part of TypeProto. */
  void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape);

  /** Clears shape info.
  @remarks If there is a mismatch during shape inferencing that can't be resolved the shape info may be removed. */
  void ClearShape();

#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// The type/shape merging API below is only compiled when ORT_MINIMAL_BUILD is not defined.
#if !defined(ORT_MINIMAL_BUILD)

  /** Override the current type with the one from input_type if override_types is set to true; return a failure status otherwise.
  @param input_type Type info that the replacement type is taken from.
  @param input_tensor_elem_type Tensor element type parsed from input_type
  @param current_tensor_elem_type Tensor element type parsed from existing type
  @param override_types If true, resolve the two inputs or two outputs type when different
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type,
                                     int32_t input_tensor_elem_type,
                                     int32_t current_tensor_elem_type,
                                     bool override_types);

  /** Validate and merge type [and shape] info from input_type.
  @param input_type Type [and shape] info to validate against, and merge into, the existing info.
  @param strict If true, the shape update will fail if there are incompatible values.
                If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the two inputs or two outputs type when different
  @param logger Logger made available to the update.
                NOTE(review): presumably used to report issues found while merging - confirm in the implementation.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);

  /** Validate and merge type [and shape] info from node_arg.
  @param node_arg NodeArg whose type [and shape] info is validated against, and merged into, the existing info.
  @param strict If true, the shape update will fail if there are incompatible values.
                If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the two inputs or two outputs type when different
  @param logger Logger made available to the update.
                NOTE(review): presumably used to report issues found while merging - confirm in the implementation.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);

#endif  // !defined(ORT_MINIMAL_BUILD)

  /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto.
  @returns Reference to the NodeArgInfo stored inside this NodeArg (no copy is made), so the reference
           must not be used after this NodeArg has been destroyed or moved from. */
  const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

  /** Gets a flag indicating whether this NodeArg exists or not.
  Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
  bool Exists() const noexcept;

  // Graph is granted access to the private members declared below (e.g. the SetType overloads).
  friend class Graph;

  /** Construct a NodeArg from an existing NodeArgInfo.
  @param node_arg_info NodeArgInfo taken by rvalue reference.
                       NOTE(review): presumably moved into this NodeArg - confirm in the implementation.
  @remarks NOTE(review): this single-argument constructor is not marked explicit, so a NodeArgInfo rvalue
  can implicitly convert to a NodeArg - confirm that this is intended. */
  NodeArg(NodeArgInfo&& node_arg_info);

 private:
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg);

  /** Sets the type from a type string pointer.
  NOTE(review): presumably the same kind of pointer that #Type returns - confirm in the implementation. */
  void SetType(const std::string* p_type);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  /** Sets the type from a TypeProto. Not available in a plain minimal build. */
  void SetType(const ONNX_NAMESPACE::TypeProto& type_proto);
#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

  // Node arg PType.
  // Has no in-class initializer, so each constructor is responsible for setting it.
  const std::string* type_;

  // Node arg name, type and shape. This is the object exposed by ToProto().
  NodeArgInfo node_arg_info_;

  // Flag indicates whether <*this> node arg exists or not.
  // Has no in-class initializer, so each constructor is responsible for setting it.
  bool exists_;
};
}  // namespace onnxruntime