xiaoanyu123 commited on
Commit
c3b656b
·
verified ·
1 Parent(s): 08157a5

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/checker.cpython-310.pyc +0 -0
  2. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/compose.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/external_data_helper.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/gen_proto.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/helper.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/onnx/__pycache__/hub.cpython-310.pyc +0 -0
  7. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/split_17_18.h +38 -0
  8. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/sum_8_7.h +41 -0
  9. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/topk_9_10.h +45 -0
  10. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/transformers.h +85 -0
  11. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/type_restriction.h +61 -0
  12. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_6_7.h +49 -0
  13. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_8_9.h +50 -0
  14. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_9_10.h +45 -0
  15. pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_9_8.h +79 -0
  16. pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/tools/torch_lib/__pycache__/generate_prims_signatures.cpython-310.pyc +0 -0
  17. pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/_constants.py +5 -0
  18. pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/_flags.py +53 -0
  19. pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/registration.py +137 -0
  20. pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/tensor_typing.py +74 -0
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/checker.cpython-310.pyc ADDED
Binary file (4.94 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/compose.cpython-310.pyc ADDED
Binary file (22 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/external_data_helper.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/gen_proto.cpython-310.pyc ADDED
Binary file (6.48 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/helper.cpython-310.pyc ADDED
Binary file (40.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/__pycache__/hub.cpython-310.pyc ADDED
Binary file (15.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/split_17_18.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Split in default domain from version 17 to 18
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+
13
+ #include "onnx/version_converter/adapters/adapter.h"
14
+ #include "onnx/version_converter/adapters/transformers.h"
15
+
16
+ namespace ONNX_NAMESPACE {
17
+ namespace version_conversion {
18
+
19
+ class Split_17_18 : public Adapter {
20
+ public:
21
+ explicit Split_17_18() : Adapter("Split", OpSetID(17), OpSetID(18)) {}
22
+
23
+ void adapt_split_17_18(const std::shared_ptr<Graph>&, Node* node) const {
24
+ const auto num_outputs = node->outputs().size();
25
+ node->i_(knum_outputs, num_outputs);
26
+ }
27
+
28
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
29
+ // if node does not have neither 'num_outputs' attribute nor 'split' input
30
+ if (!node->hasAttribute(knum_outputs) && node->inputs().size() != 2) {
31
+ adapt_split_17_18(graph, node);
32
+ }
33
+ return node;
34
+ }
35
+ };
36
+
37
+ } // namespace version_conversion
38
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/sum_8_7.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Sum in default domain from version 8 to 7
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <vector>
13
+
14
+ #include "onnx/version_converter/adapters/adapter.h"
15
+
16
+ namespace ONNX_NAMESPACE {
17
+ namespace version_conversion {
18
+
19
+ class Sum_8_7 final : public Adapter {
20
+ public:
21
+ explicit Sum_8_7() : Adapter("Sum", OpSetID(8), OpSetID(7)) {}
22
+
23
+ void adapt_sum_8_7(const std::shared_ptr<Graph>&, Node* node) const {
24
+ // Throw an exception if any broadcasting occurs
25
+ const ArrayRef<Value*>& inputs = node->inputs();
26
+ // Determine if inputs are of different sizes
27
+ for (int i = 1; i < (int)inputs.size(); i++) {
28
+ std::vector<Dimension> A_sizes = inputs[i - 1]->sizes();
29
+ std::vector<Dimension> B_sizes = inputs[i]->sizes();
30
+ assert_numpy_multibroadcastable(A_sizes, B_sizes);
31
+ }
32
+ }
33
+
34
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
35
+ adapt_sum_8_7(graph, node);
36
+ return node;
37
+ }
38
+ };
39
+
40
+ } // namespace version_conversion
41
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/topk_9_10.h ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for TopK in default domain from version 9 to 10
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <vector>
13
+
14
+ #include "onnx/version_converter/adapters/adapter.h"
15
+
16
+ namespace ONNX_NAMESPACE {
17
+ namespace version_conversion {
18
+
19
+ class TopK_9_10 final : public Adapter {
20
+ public:
21
+ explicit TopK_9_10() : Adapter("TopK", OpSetID(9), OpSetID(10)) {}
22
+
23
+ void adapt_topk_9_10(const std::shared_ptr<Graph>& graph, Node* node) const {
24
+ Tensor t;
25
+ t.elem_type() = TensorProto_DataType_INT64;
26
+ t.sizes() = std::vector<int64_t>{1};
27
+ auto& data = t.int64s();
28
+ data.emplace_back(node->i(kk));
29
+
30
+ Node* constant = graph->create(kConstant);
31
+ constant->insertBefore(node);
32
+ constant->t_(kvalue, t);
33
+ node->addInput(constant->output());
34
+
35
+ node->removeAttribute(kk);
36
+ }
37
+
38
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
39
+ adapt_topk_9_10(graph, node);
40
+ return node;
41
+ }
42
+ };
43
+
44
+ } // namespace version_conversion
45
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/transformers.h ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ #pragma once
8
+
9
+ #include <cinttypes>
10
+ #include <string>
11
+ #include <vector>
12
+
13
+ #include "onnx/common/interned_strings.h"
14
+ #include "onnx/version_converter/adapters/adapter.h"
15
+
16
+ // Node transformers commonly used in version-adapters:
17
+
18
+ // Capture context by copying values; the graph is unused by these transformers.
19
+
20
+ #define NODE_TRANSFORMER(node) [=](const std::shared_ptr<Graph>&, Node* node)
21
+
22
+ namespace ONNX_NAMESPACE {
23
+ namespace version_conversion {
24
+
25
+ inline NodeTransformerFunction RemoveAttribute(Symbol attr) {
26
+ return NODE_TRANSFORMER(node) {
27
+ if (node->hasAttribute(attr)) {
28
+ node->removeAttribute(attr);
29
+ }
30
+ return node;
31
+ };
32
+ }
33
+
34
+ inline NodeTransformerFunction RemoveAttribute(Symbol attr, int64_t value) {
35
+ return NODE_TRANSFORMER(node) {
36
+ if (node->hasAttribute(attr)) {
37
+ ONNX_ASSERTM(node->i(attr) == value, "Attribute %s must have value %" PRId64, attr.toString(), value)
38
+ node->removeAttribute(attr);
39
+ }
40
+ return node;
41
+ };
42
+ }
43
+
44
+ inline NodeTransformerFunction RemoveAttributeNotEq(Symbol attr, int64_t value) {
45
+ return NODE_TRANSFORMER(node) {
46
+ if (node->hasAttribute(attr)) {
47
+ ONNX_ASSERTM(node->i(attr) != value, "Attribute %s must not have value %" PRId64, attr.toString(), value)
48
+ node->removeAttribute(attr);
49
+ }
50
+ return node;
51
+ };
52
+ }
53
+
54
+ inline NodeTransformerFunction SetAttribute(Symbol attr, int64_t value) {
55
+ return NODE_TRANSFORMER(node) {
56
+ node->i_(attr, value);
57
+ return node;
58
+ };
59
+ }
60
+
61
+ inline NodeTransformerFunction SetAttribute(Symbol attr, const std::string& value) {
62
+ return NODE_TRANSFORMER(node) {
63
+ node->s_(attr, value);
64
+ return node;
65
+ };
66
+ }
67
+
68
+ inline NodeTransformerFunction SetAttribute(Symbol attr, const std::vector<int64_t>& value) {
69
+ return NODE_TRANSFORMER(node) {
70
+ node->is_(attr, std::vector<int64_t>(value));
71
+ return node;
72
+ };
73
+ }
74
+
75
+ inline NodeTransformerFunction SetAttributeIfAbsent(Symbol attr, int64_t value) {
76
+ return NODE_TRANSFORMER(node) {
77
+ if (!node->hasAttribute(attr)) {
78
+ node->i_(attr, value);
79
+ }
80
+ return node;
81
+ };
82
+ }
83
+
84
+ } // namespace version_conversion
85
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/type_restriction.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Add in default domain from version 6 to 5
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <string>
13
+ #include <vector>
14
+
15
+ #include "onnx/version_converter/adapters/adapter.h"
16
+
17
+ namespace ONNX_NAMESPACE {
18
+ namespace version_conversion {
19
+
20
+ class TypeRestriction : public Adapter {
21
+ public:
22
+ explicit TypeRestriction(
23
+ const std::string& op_name,
24
+ const OpSetID& initial,
25
+ const OpSetID& target,
26
+ const std::vector<TensorProto_DataType>& unallowed_types)
27
+ : Adapter(op_name, initial, target), unallowed_types_(unallowed_types) {}
28
+
29
+ void adapt_type_restriction(const std::shared_ptr<Graph>&, const Node* node) const {
30
+ // Since consumed_inputs is optional, no need to add it (as in batchnorm)
31
+ // Iterate over all inputs and outputs
32
+ for (const Value* input : node->inputs()) {
33
+ isUnallowed(input);
34
+ }
35
+ for (const Value* output : node->outputs()) {
36
+ isUnallowed(output);
37
+ }
38
+ }
39
+
40
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
41
+ adapt_type_restriction(graph, node);
42
+ return node;
43
+ }
44
+
45
+ private:
46
+ std::vector<TensorProto_DataType> unallowed_types_;
47
+
48
+ void isUnallowed(const Value* val) const {
49
+ ONNX_ASSERTM(
50
+ std::find(std::begin(unallowed_types_), std::end(unallowed_types_), val->elemType()) ==
51
+ std::end(unallowed_types_),
52
+ "DataType (%d) of Input or Output"
53
+ " of operator '%s' is unallowed for Opset Version %d.",
54
+ val->elemType(),
55
+ name().c_str(),
56
+ target_version().version())
57
+ }
58
+ };
59
+
60
+ } // namespace version_conversion
61
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_6_7.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Upsample in default domain from version 6 to 7
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <utility>
13
+ #include <vector>
14
+
15
+ #include "onnx/version_converter/adapters/adapter.h"
16
+
17
+ namespace ONNX_NAMESPACE {
18
+ namespace version_conversion {
19
+
20
+ struct Upsample_6_7 final : public Adapter {
21
+ explicit Upsample_6_7() : Adapter("Upsample", OpSetID(6), OpSetID(7)) {}
22
+
23
+ void adapt_upsample_6_7(const std::shared_ptr<Graph>&, Node* node) const {
24
+ Symbol width_scale_symbol = Symbol("width_scale");
25
+ Symbol height_scale_symbol = Symbol("height_scale");
26
+ ONNX_ASSERTM(
27
+ node->hasAttribute(width_scale_symbol) && node->hasAttribute(height_scale_symbol),
28
+ "Upsample in opset 1 needs to have width_scale and height_scale attributes")
29
+
30
+ auto width_scale = node->f(width_scale_symbol);
31
+ auto height_scale = node->f(height_scale_symbol);
32
+
33
+ auto input_shape = node->inputs()[0]->sizes();
34
+ ONNX_ASSERTM(input_shape.size() == 4, "Upsample in opset 1 supports only 4D input tensor")
35
+ std::vector<double> scales = {1.0, 1.0, height_scale, width_scale};
36
+
37
+ node->fs_(kscales, std::move(scales));
38
+ node->removeAttribute(width_scale_symbol);
39
+ node->removeAttribute(height_scale_symbol);
40
+ }
41
+
42
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
43
+ adapt_upsample_6_7(graph, node);
44
+ return node;
45
+ }
46
+ };
47
+
48
+ } // namespace version_conversion
49
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_8_9.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Upsample in default domain from version 8 to 9
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <vector>
13
+
14
+ #include "onnx/version_converter/adapters/adapter.h"
15
+
16
+ namespace ONNX_NAMESPACE {
17
+ namespace version_conversion {
18
+
19
+ struct Upsample_8_9 final : public Adapter {
20
+ explicit Upsample_8_9() : Adapter("Upsample", OpSetID(8), OpSetID(9)) {}
21
+
22
+ void adapt_upsample_8_9(const std::shared_ptr<Graph>& graph, Node* node) const {
23
+ Symbol input_dirs = Symbol("scales");
24
+ int dim = (int)(node->fs(kscales).size());
25
+ Tensor t;
26
+ t.elem_type() = TensorProto_DataType_FLOAT;
27
+ t.sizes() = std::vector<int64_t>{dim};
28
+ auto& data = t.floats();
29
+
30
+ if (node->hasAttribute(input_dirs)) {
31
+ for (double scale : node->fs(kscales)) {
32
+ data.emplace_back((float)scale);
33
+ }
34
+
35
+ Node* constant = graph->create(kConstant);
36
+ constant->insertBefore(node);
37
+ constant->t_(kvalue, t);
38
+ node->addInput(constant->output());
39
+ node->removeAttribute(kscales);
40
+ }
41
+ }
42
+
43
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
44
+ adapt_upsample_8_9(graph, node);
45
+ return node;
46
+ }
47
+ };
48
+
49
+ } // namespace version_conversion
50
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_9_10.h ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Upsample in default domain from version 9 to 10
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <string>
13
+
14
+ #include "onnx/common/ir.h"
15
+ #include "onnx/version_converter/adapters/adapter.h"
16
+ namespace ONNX_NAMESPACE {
17
+ namespace version_conversion {
18
+
19
+ class Upsample_9_10 final : public Adapter {
20
+ public:
21
+ explicit Upsample_9_10() : Adapter("Upsample", OpSetID(9), OpSetID(10)) {}
22
+
23
+ Node* adapt_upsample_9_10(const std::shared_ptr<Graph>& graph, Node* node) const {
24
+ std::string mode = node->hasAttribute(kmode) ? node->s(kmode) : "nearest";
25
+
26
+ // Replace the node with an equivalent Resize node
27
+ Node* resize = graph->create(kResize);
28
+ resize->s_(kmode, mode);
29
+ resize->addInput(node->inputs()[0]);
30
+ resize->addInput(node->inputs()[1]);
31
+ node->replaceAllUsesWith(resize);
32
+
33
+ resize->insertBefore(node);
34
+ node->destroy();
35
+
36
+ return resize;
37
+ }
38
+
39
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
40
+ return adapt_upsample_9_10(graph, node);
41
+ }
42
+ };
43
+
44
+ } // namespace version_conversion
45
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnx/version_converter/adapters/upsample_9_8.h ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ // Adapter for Upsample in default domain from version 9 to 8
8
+
9
+ #pragma once
10
+
11
+ #include <memory>
12
+ #include <string>
13
+ #include <vector>
14
+
15
+ #include "onnx/defs/tensor_proto_util.h"
16
+ #include "onnx/defs/tensor_util.h"
17
+ #include "onnx/version_converter/adapters/adapter.h"
18
+
19
+ namespace ONNX_NAMESPACE {
20
+ namespace version_conversion {
21
+
22
+ struct Upsample_9_8 final : public Adapter {
23
+ explicit Upsample_9_8() : Adapter("Upsample", OpSetID(9), OpSetID(8)) {}
24
+
25
+ void adapt_upsample_9_8(const std::shared_ptr<Graph>& graph, Node* node) const {
26
+ const ArrayRef<Value*>& inputs = node->inputs();
27
+ const std::vector<Tensor>& initializers = graph->initializers();
28
+
29
+ ONNX_ASSERTM(inputs.size() == 2, "Upsample in opset 9 needs to have 2 inputs.")
30
+ std::string scale_input_name = node->inputs()[1]->uniqueName();
31
+
32
+ for (const auto& initializer : initializers) {
33
+ if (initializer.name() == inputs[1]->uniqueName()) {
34
+ std::vector<float> value = ParseData<float>(&initializer);
35
+ std::vector<double> d_values;
36
+ d_values.reserve(value.size());
37
+ for (float j : value) {
38
+ d_values.push_back(static_cast<double>(j));
39
+ }
40
+ node->fs_(kscales, std::move(d_values));
41
+
42
+ node->removeInput(1);
43
+ graph->eraseInitializer(initializer.name());
44
+ for (size_t j = 0; j < graph->inputs().size(); j++) {
45
+ if (graph->inputs()[j]->uniqueName() == scale_input_name) {
46
+ graph->eraseInput(j);
47
+ break;
48
+ }
49
+ }
50
+ return;
51
+ }
52
+ }
53
+
54
+ for (Node* op : graph->nodes()) {
55
+ if (op->kind() == kConstant && op->outputs()[0]->uniqueName() == scale_input_name) {
56
+ std::vector<float> value = ParseData<float>(&op->t(kvalue));
57
+ std::vector<double> d_values;
58
+ d_values.reserve(value.size());
59
+ for (float j : value) {
60
+ d_values.push_back(static_cast<double>(j));
61
+ }
62
+ node->fs_(kscales, std::move(d_values));
63
+ node->removeInput(1);
64
+ op->destroy();
65
+ return;
66
+ }
67
+ }
68
+
69
+ ONNX_ASSERTM(false, "Unsuppported conversion due to unavailable input: scale")
70
+ }
71
+
72
+ Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
73
+ adapt_upsample_9_8(graph, node);
74
+ return node;
75
+ }
76
+ };
77
+
78
+ } // namespace version_conversion
79
+ } // namespace ONNX_NAMESPACE
pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/tools/torch_lib/__pycache__/generate_prims_signatures.cpython-310.pyc ADDED
Binary file (8.83 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/_constants.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Shared constants for the library."""

# ONNX function domain under which all torch_lib functions are registered.
DOMAIN = "pkg.onnxscript.torch_lib"
pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/_flags.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Experimental flags.
4
+
5
+ NOTE: These flags are experimental only. Any flag here can be removed at any
6
+ time without notice.
7
+ """
8
+
9
+ import logging
10
+ import os
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def _load_boolean_flag(
16
+ name: str,
17
+ *,
18
+ this_will: str,
19
+ deprecated: bool = False,
20
+ default: bool = False,
21
+ ) -> bool:
22
+ """Load a boolean flag from environment variable.
23
+
24
+ Args:
25
+ name: The name of the environment variable.
26
+ this_will: A string that describes what this flag will do.
27
+ deprecated: Whether this flag is deprecated.
28
+ default: The default value if envvar not defined.
29
+ """
30
+ undefined = os.getenv(name) is None
31
+ state = os.getenv(name) == "1"
32
+ if state:
33
+ if deprecated:
34
+ logger.error(
35
+ "Experimental flag %s is deprecated. Please remove it from your environment.",
36
+ name,
37
+ )
38
+ else:
39
+ logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will)
40
+ if undefined:
41
+ state = default
42
+ return state
43
+
44
+
45
+ EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag(
46
+ "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
47
+ this_will="make initializers as inputs to the model graph",
48
+ )
49
+ EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
50
+ "TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
51
+ this_will="trace all traceable functions to fold if branches and collapse constant expressions",
52
+ default=True,
53
+ )
pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/registration.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Registry for aten functions."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import re
8
+ from typing import Any, Callable, Generator, Optional
9
+
10
+ import onnxscript
11
+ from onnxscript.function_libs.torch_lib import _constants
12
+
13
+ # Regex that will match "<namespace>::<op_name>[.<overload>]"
14
+ _QUALIFIED_OPERATOR_NAME_REGEX = re.compile(
15
+ r"^(?P<namespace>[a-zA-Z0-9_]+)::(?P<name>[a-zA-Z0-9_]+)(?P<overload>\.[a-zA-Z0-9._]+)?$"
16
+ )
17
+
18
+
19
class OverloadedFunction:
    """All functions registered under a single qualified op name.

    Attributes:
        name: Name of the op. E.g. "aten::add".
        overloads: Overloads function.
        privates: Private functions not exposed to users.
        complex: Support complex functions.
    """

    def __init__(self, name: str):
        # Each bucket starts empty; Registry.register appends into exactly one
        # of them depending on the private/complex flags.
        self.name = name
        self.overloads: list[Any] = []
        self.privates: list[Any] = []
        self.complex: list[Any] = []
34
+
35
+
36
class Registry:
    """Registry for aten functions."""

    def __init__(self):
        self._registry: dict[str, OverloadedFunction] = {}

    def register(
        self, func: Any, name: str, *, private: bool = False, complex: bool = False
    ) -> None:
        """Register a function under `name`, bucketed by its visibility kind."""
        # Create the entry on first use, then append into the right bucket.
        entry = self._registry.setdefault(name, OverloadedFunction(name))
        if private:
            entry.privates.append(func)
        elif complex:
            entry.complex.append(func)
        else:
            entry.overloads.append(func)

    def __getitem__(self, name):
        return self._registry[name]

    def __contains__(self, name):
        return name in self._registry

    def __iter__(self):
        return iter(self._registry)

    def __repr__(self):
        return repr(self._registry)

    def items(self) -> Generator[tuple[str, OverloadedFunction], None, None]:
        yield from self._registry.items()

    def values(self) -> Generator[OverloadedFunction, None, None]:
        yield from self._registry.values()


# Process-wide default registry used when callers do not supply their own.
default_registry = Registry()
75
+
76
+
77
def _check_and_normalize_names(name: str | tuple[str, ...]) -> tuple[str, ...]:
    """Validate op name(s) and return them as a tuple.

    A single string becomes a one-element tuple; a tuple is validated as-is.
    Raises TypeError for other inputs and ValueError for malformed names.
    """
    names = (name,) if isinstance(name, str) else name
    if not isinstance(names, tuple):
        raise TypeError(f"Name must be a string or a tuple of strings, got {name}")
    for name_ in names:
        # ".default" overloads must be spelled without the suffix, and every
        # name must match the namespace::name[.overload] pattern.
        if name_.endswith(".default") or not _QUALIFIED_OPERATOR_NAME_REGEX.fullmatch(name_):
            raise ValueError(
                f"Invalid name '{name_}'. Must be in the form 'namespace::name' for default overloads "
                "or 'namespace::name.overload' for other overloads."
            )

    return names
94
+
95
+
96
def torch_op(
    name: str | tuple[str, ...],
    *,
    registry: Optional[Registry] = None,
    trace_only: bool = False,
    private: bool = False,
    complex: bool = False,
) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
    """Register a torch op.

    Args:
        name: Qualified ATen name of the function. E.g. "aten::relu", "aten::add.Tensor".
            Or a tuple of names e.g. ("aten::add.Scalar", "aten::add.Tensor").
            Default overloads should be specified by omitting the overload part,
            i.e. "aten::relu" instead of "aten::relu.default".
        registry: Registry to register the function to. If None, the default registry is used.
        trace_only: Whether the function should only be traced and not compiled.
        private: Whether the function is private (not directly exposed). It should
            be true for all functions with names starting with "_".
        complex: Whether the function expects complex-valued inputs.
    """
    target_registry = default_registry if registry is None else registry

    def wrapper(
        func: Callable,
    ) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
        # Build (or trace) the function under the torch_lib custom opset.
        custom_opset = onnxscript.values.Opset(domain=_constants.DOMAIN, version=1)

        processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction
        if trace_only:
            processed_func = onnxscript.values.TracedOnnxFunction(custom_opset, func)
        else:
            processed_func = onnxscript.script(opset=custom_opset)(func)

        # Register the processed function under every requested name.
        for name_ in _check_and_normalize_names(name):
            target_registry.register(processed_func, name_, private=private, complex=complex)
        return processed_func

    return wrapper
pythonProject/.venv/Lib/site-packages/onnxscript/function_libs/torch_lib/tensor_typing.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ """Typings for function definitions."""
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TypeVar, Union
11
+
12
+ from onnxscript import (
13
+ BFLOAT16,
14
+ BOOL,
15
+ COMPLEX64,
16
+ COMPLEX128,
17
+ DOUBLE,
18
+ FLOAT,
19
+ FLOAT16,
20
+ INT8,
21
+ INT16,
22
+ INT32,
23
+ INT64,
24
+ STRING,
25
+ UINT8,
26
+ )
27
+
28
+ # NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not us them.
29
+ # More detail can be found: https://pytorch.org/docs/stable/tensors.html
30
+
31
+ _TensorType = Union[
32
+ BFLOAT16,
33
+ BOOL,
34
+ COMPLEX64,
35
+ COMPLEX128,
36
+ DOUBLE,
37
+ FLOAT,
38
+ FLOAT16,
39
+ INT8,
40
+ INT16,
41
+ INT32,
42
+ INT64,
43
+ UINT8,
44
+ ]
45
+ _FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
46
+ IntType = Union[INT8, INT16, INT32, INT64]
47
+ RealType = Union[
48
+ BFLOAT16,
49
+ FLOAT16,
50
+ FLOAT,
51
+ DOUBLE,
52
+ INT8,
53
+ INT16,
54
+ INT32,
55
+ INT64,
56
+ ]
57
+
58
+ TTensor = TypeVar("TTensor", bound=_TensorType)
59
+ # Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
60
+ # but do not constrain the type to be the same as the other inputs/outputs
61
+ TTensor2 = TypeVar("TTensor2", bound=_TensorType)
62
+ TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
63
+ TFloat = TypeVar("TFloat", bound=_FloatType)
64
+ TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
65
+ TInt = TypeVar("TInt", bound=IntType)
66
+ TReal = TypeVar("TReal", bound=RealType)
67
+ TRealUnlessInt16OrInt8 = TypeVar(
68
+ "TRealUnlessInt16OrInt8", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64]
69
+ )
70
+ TRealUnlessFloat16OrInt8 = TypeVar(
71
+ "TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64]
72
+ )
73
+ TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8])
74
+ TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE])