xiaoanyu123 commited on
Commit
08157a5
·
verified ·
1 Parent(s): e914c7d

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/METADATA +136 -0
  2. pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/RECORD +0 -0
  3. pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/WHEEL +5 -0
  4. pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/licenses/LICENSE +202 -0
  5. pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/top_level.txt +1 -0
  6. pythonProject/.venv/Lib/site-packages/onnx/test/cpp/utf8_conversion_test.cc +27 -0
  7. pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_conversion_test_base.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_downgrade_test.cpython-310.pyc +0 -0
  9. pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_conversion_test_base.py +149 -0
  10. pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_downgrade_test.py +106 -0
  11. pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_upgrade_test.py +1964 -0
  12. pythonProject/.venv/Lib/site-packages/onnxscript/converter.py +1462 -0
  13. pythonProject/.venv/Lib/site-packages/onnxscript/evaluator.py +619 -0
  14. pythonProject/.venv/Lib/site-packages/onnxscript/irbuilder.py +561 -0
  15. pythonProject/.venv/Lib/site-packages/onnxscript/main.py +176 -0
  16. pythonProject/.venv/Lib/site-packages/onnxscript/onnx_types.py +237 -0
  17. pythonProject/.venv/Lib/site-packages/onnxscript/py.typed +1 -0
  18. pythonProject/.venv/Lib/site-packages/onnxscript/sourceinfo.py +59 -0
  19. pythonProject/.venv/Lib/site-packages/onnxscript/tensor.py +225 -0
  20. pythonProject/.venv/Lib/site-packages/onnxscript/type_annotation.py +288 -0
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/METADATA ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: onnx
3
+ Version: 1.19.0
4
+ Summary: Open Neural Network Exchange
5
+ Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
+ License: Apache License v2.0
7
+ Project-URL: Homepage, https://onnx.ai/
8
+ Project-URL: Repository, https://github.com/onnx/onnx
9
+ Classifier: Programming Language :: Python :: 3
10
+ Requires-Python: >=3.9
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: numpy>=1.22
14
+ Requires-Dist: protobuf>=4.25.1
15
+ Requires-Dist: typing_extensions>=4.7.1
16
+ Requires-Dist: ml_dtypes
17
+ Provides-Extra: reference
18
+ Requires-Dist: Pillow; extra == "reference"
19
+ Dynamic: license-file
20
+
21
+ <!--
22
+ Copyright (c) ONNX Project Contributors
23
+
24
+ SPDX-License-Identifier: Apache-2.0
25
+ -->
26
+
27
+ <p align="center"><img width="40%" src="https://github.com/onnx/onnx/raw/main/docs/onnx-horizontal-color.png" /></p>
28
+
29
+ [![PyPI - Version](https://img.shields.io/pypi/v/onnx.svg)](https://pypi.org/project/onnx)
30
+ [![CI](https://github.com/onnx/onnx/actions/workflows/main.yml/badge.svg)](https://github.com/onnx/onnx/actions/workflows/main.yml)
31
+ [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/3313/badge)](https://bestpractices.coreinfrastructure.org/projects/3313)
32
+ [![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/onnx/onnx/badge)](https://api.securityscorecards.dev/projects/github.com/onnx/onnx)
33
+ [![REUSE compliant](https://api.reuse.software/badge/github.com/onnx/onnx)](https://api.reuse.software/info/github.com/onnx/onnx)
34
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
35
+ [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
36
+
37
+ [Open Neural Network Exchange (ONNX)](https://onnx.ai) is an open ecosystem that empowers AI developers
38
+ to choose the right tools as their project evolves. ONNX provides an open source format for AI models, both deep learning and traditional ML. It defines an extensible computation graph model, as well as definitions of built-in operators and standard
39
+ data types. Currently we focus on the capabilities needed for inferencing (scoring).
40
+
41
+ ONNX is [widely supported](http://onnx.ai/supported-tools) and can be found in many frameworks, tools, and hardware. Enabling interoperability between different frameworks and streamlining the path from research to production helps increase the speed of innovation in the AI community. We invite the community to join us and further evolve ONNX.
42
+
43
+ # Use ONNX
44
+
45
+ * [Documentation of ONNX Python Package](https://onnx.ai/onnx/)
46
+ * [Tutorials for creating ONNX models](https://github.com/onnx/tutorials)
47
+ * [Pre-trained ONNX models](https://github.com/onnx/models)
48
+
49
+ # Learn about the ONNX spec
50
+
51
+ * [Overview](https://github.com/onnx/onnx/blob/main/docs/Overview.md)
52
+ * [ONNX intermediate representation spec](https://github.com/onnx/onnx/blob/main/docs/IR.md)
53
+ * [Versioning principles of the spec](https://github.com/onnx/onnx/blob/main/docs/Versioning.md)
54
+ * [Operators documentation](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
55
+ * [Operators documentation](https://onnx.ai/onnx/operators/index.html) (latest release)
56
+ * [Python API Overview](https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md)
57
+
58
+ # Programming utilities for working with ONNX Graphs
59
+
60
+ * [Shape and Type Inference](https://github.com/onnx/onnx/blob/main/docs/ShapeInference.md)
61
+ * [Graph Optimization](https://github.com/onnx/optimizer)
62
+ * [Opset Version Conversion](https://github.com/onnx/onnx/blob/main/docs/docsgen/source/api/version_converter.md)
63
+
64
+ # Contribute
65
+
66
+ ONNX is a community project and the open governance model is described [here](https://github.com/onnx/onnx/blob/main/community/readme.md). We encourage you to join the effort and contribute feedback, ideas, and code. You can participate in the [Special Interest Groups](https://github.com/onnx/onnx/blob/main/community/sigs.md) and [Working Groups](https://github.com/onnx/onnx/blob/main/community/working-groups.md) to shape the future of ONNX.
67
+
68
+ Check out our [contribution guide](https://github.com/onnx/onnx/blob/main/CONTRIBUTING.md) to get started.
69
+
70
+ If you think some operator should be added to ONNX specification, please read
71
+ [this document](https://github.com/onnx/onnx/blob/main/docs/AddNewOp.md).
72
+
73
+ # Community meetings
74
+
75
+ The schedules of the regular meetings of the Steering Committee, the working groups and the SIGs can be found [here](https://onnx.ai/calendar)
76
+
77
+ Community Meetups are held at least once a year. Content from previous community meetups are at:
78
+
79
+ * 2020.04.09 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14091402/LF+AI+Day+-ONNX+Community+Virtual+Meetup+-+Silicon+Valley+-+2020+April+9>
80
+ * 2020.10.14 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14092138/LF+AI+Day+-+ONNX+Community+Workshop+-+2020+October+14>
81
+ * 2021.03.24 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14092424/Instructions+for+Event+Hosts+-+LF+AI+Data+Day+-+ONNX+Virtual+Community+Meetup+-+March+2021>
82
+ * 2021.10.21 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14093194/LF+AI+Data+Day+ONNX+Community+Virtual+Meetup+-+October+2021>
83
+ * 2022.06.24 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14093969/ONNX+Community+Day+-+2022+June+24>
84
+ * 2023.06.28 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14094507/ONNX+Community+Day+2023+-+June+28>
85
+
86
+
87
+
88
+ # Discuss
89
+
90
+ We encourage you to open [Issues](https://github.com/onnx/onnx/issues), or use [Slack](https://lfaifoundation.slack.com/) (If you have not joined yet, please use this [link](https://join.slack.com/t/lfaifoundation/shared_invite/zt-o65errpw-gMTbwNr7FnNbVXNVFkmyNA) to join the group) for more real-time discussion.
91
+
92
+ # Follow Us
93
+
94
+ Stay up to date with the latest ONNX news. [[Facebook](https://www.facebook.com/onnxai/)] [[Twitter](https://twitter.com/onnxai)]
95
+
96
+ # Roadmap
97
+
98
+ A roadmap process takes place every year. More details can be found [here](https://github.com/onnx/steering-committee/tree/main/roadmap)
99
+
100
+ # Installation
101
+
102
+ ONNX released packages are published in PyPi.
103
+
104
+ ```sh
105
+ pip install onnx # or pip install onnx[reference] for optional reference implementation dependencies
106
+ ```
107
+
108
+ [ONNX weekly packages](https://pypi.org/project/onnx-weekly/) are published in PyPI to enable experimentation and early testing.
109
+
110
+ Detailed install instructions, including Common Build Options and Common Errors can be found [here](https://github.com/onnx/onnx/blob/main/INSTALL.md)
111
+
112
+ # Testing
113
+
114
+ ONNX uses [pytest](https://docs.pytest.org) as test driver. In order to run tests, you will first need to install `pytest`:
115
+
116
+ ```sh
117
+ pip install pytest
118
+ ```
119
+
120
+ After installing pytest, use the following command to run tests.
121
+
122
+ ```sh
123
+ pytest
124
+ ```
125
+
126
+ # Development
127
+
128
+ Check out the [contributor guide](https://github.com/onnx/onnx/blob/main/CONTRIBUTING.md) for instructions.
129
+
130
+ # License
131
+
132
+ [Apache License v2.0](LICENSE)
133
+
134
+ # Code of Conduct
135
+
136
+ [ONNX Open Source Code of Conduct](https://onnx.ai/codeofconduct.html)
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/RECORD ADDED
The diff for this file is too large to render. See raw diff
 
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (1.19.0)
3
+ Root-Is-Purelib: false
4
+ Tag: cp310-cp310-win_amd64
5
+
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ onnx
pythonProject/.venv/Lib/site-packages/onnx/test/cpp/utf8_conversion_test.cc ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) ONNX Project Contributors
2
+
3
+ /*
4
+ * SPDX-License-Identifier: Apache-2.0
5
+ */
6
+
7
+ #ifdef _WIN32
8
+ #include <string>
9
+
10
+ #include "gtest/gtest.h"
11
+ #include "onnx/common/path.h"
12
+ namespace ONNX_NAMESPACE::Test {
13
+
14
+ TEST(UTF8Test, WideStringConvertion) {
15
+ std::string utf8_str(u8"世界,你好!");
16
+ EXPECT_EQ(ONNX_NAMESPACE::wstring_to_utf8str(ONNX_NAMESPACE::utf8str_to_wstring(utf8_str)), utf8_str);
17
+ }
18
+
19
+ TEST(UTF8Test, TryConvertUTF8) {
20
+ std::string utf8_str(u8"世界,你好!");
21
+ auto wstr = ONNX_NAMESPACE::utf8str_to_wstring(utf8_str);
22
+ auto wstr2 = ONNX_NAMESPACE::utf8str_to_wstring(
23
+ std::string(reinterpret_cast<const char*>(wstr.c_str()), sizeof(std::wstring::value_type) * wstr.size()), true);
24
+ EXPECT_EQ(wstr, wstr2);
25
+ }
26
+ } // namespace ONNX_NAMESPACE::Test
27
+ #endif
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_conversion_test_base.cpython-310.pyc ADDED
Binary file (6.24 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_downgrade_test.cpython-310.pyc ADDED
Binary file (3.28 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_conversion_test_base.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ONNX Project Contributors
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
6
+ import string
7
+ import unittest
8
+ from typing import TYPE_CHECKING, Any, cast
9
+
10
+ import onnx
11
+ from onnx import TensorProto, ValueInfoProto, helper, shape_inference, version_converter
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Sequence
15
+
16
+ LATEST_OPSET = onnx.defs.onnx_opset_version()
17
+
18
+
19
+ class TestAutomaticConversion(unittest.TestCase):
20
+ def _test_model_conversion(
21
+ self, to_opset: int, model: str | onnx.ModelProto
22
+ ) -> None:
23
+ if isinstance(model, str):
24
+ model = onnx.parser.parse_model(model)
25
+ onnx.checker.check_model(model)
26
+ shape_inference.infer_shapes(model, strict_mode=True)
27
+
28
+ converted = version_converter.convert_version(model, to_opset)
29
+ onnx.checker.check_model(converted)
30
+ shape_inference.infer_shapes(converted, strict_mode=True)
31
+
32
+ def _test_model_conversion_fails(
33
+ self, to_opset: int, model: str | onnx.ModelProto
34
+ ) -> None:
35
+ if isinstance(model, str):
36
+ model = onnx.parser.parse_model(model)
37
+ onnx.checker.check_model(model)
38
+ shape_inference.infer_shapes(model, strict_mode=True)
39
+
40
+ with self.assertRaises(RuntimeError):
41
+ version_converter.convert_version(model, to_opset)
42
+
43
+ def _test_op_conversion(
44
+ self,
45
+ op: str,
46
+ from_opset: int,
47
+ input_shapes: Sequence[Sequence[int | None] | str] = ((3, 4, 5),),
48
+ output_shapes: Sequence[Sequence[int | None]] = ((3, 4, 5),),
49
+ input_types: Sequence[Any] | None = None,
50
+ output_types: Sequence[Any] | None = None,
51
+ initializer: Sequence[Any] = (),
52
+ attrs: dict[str, Any] | None = None,
53
+ seq_inputs: Sequence[int] = (),
54
+ seq_outputs: Sequence[int] = (),
55
+ optional_inputs: Sequence[int] = (),
56
+ optional_outputs: Sequence[int] = (),
57
+ is_upgrade: bool = True,
58
+ ) -> None:
59
+ """Test conversion.
60
+
61
+ Args:
62
+ op: A string representing the name of the operator to test.
63
+ from_opset: An integer representing the lowest opset version to convert.
64
+ input_shapes: A sequence of tuples or strings representing the shapes of the input tensors.
65
+ The default value is ((3, 4, 5),).
66
+ output_shapes: A sequence of tuples representing the shapes of the output tensors.
67
+ The default value is ((3, 4, 5),).
68
+ input_types: An optional sequence of types representing the data types of the input tensors.
69
+ output_types: An optional sequence of types representing the data types of the output tensors.
70
+ initializer: A sequence of values representing the initial values of the input tensors.
71
+ attrs: An optional dictionary of attributes for the operator.
72
+ seq_inputs: A sequence of integers representing the indices of the input tensors that are sequences.
73
+ seq_outputs: A sequence of integers representing the indices of the output tensors that are sequences.
74
+ optional_inputs: A sequence of integers representing the indices of the input tensors that are optional.
75
+ optional_outputs: A sequence of integers representing the indices of the output tensors that are optional.
76
+ is_upgrade: A boolean value indicating whether to run the version converter from from_opset to
77
+ the most recent opset version (True) or from the most recent opset version to from_opset (False).
78
+ The default value is True. In both cases, runs checker and shape inference on the final model.
79
+ """
80
+ if attrs is None:
81
+ attrs = {}
82
+
83
+ n_inputs = len(input_shapes)
84
+ letters = list(string.ascii_lowercase)[:n_inputs]
85
+ input_names = [
86
+ letter if shape != "" else ""
87
+ for (letter, shape) in zip(letters, input_shapes)
88
+ ]
89
+ if input_types is None:
90
+ input_types = [TensorProto.FLOAT] * n_inputs
91
+ is_sequence = [0 if id not in seq_inputs else 1 for id in range(n_inputs)]
92
+ is_optional = [0 if id not in optional_inputs else 1 for id in range(n_inputs)]
93
+ # turn empty strings into [0] to ease type analysis, even though those entries
94
+ # will be ignored
95
+ input_shapes_cast = cast(
96
+ "list[list[int]]",
97
+ [[0] if isinstance(shape, str) else shape for shape in input_shapes],
98
+ )
99
+ inputs: list[ValueInfoProto] = []
100
+ for name, ttype, shape, is_seq, is_opt in zip(
101
+ input_names, input_types, input_shapes_cast, is_sequence, is_optional
102
+ ):
103
+ if name != "":
104
+ if is_seq:
105
+ inputs += [
106
+ helper.make_tensor_sequence_value_info(name, ttype, shape)
107
+ ]
108
+ elif is_opt:
109
+ type_proto = helper.make_tensor_type_proto(ttype, shape)
110
+ optional_type_proto = helper.make_optional_type_proto(type_proto)
111
+ inputs += [helper.make_value_info(name, optional_type_proto)]
112
+ else:
113
+ inputs += [helper.make_tensor_value_info(name, ttype, shape)]
114
+
115
+ n_outputs = len(output_shapes)
116
+ output_names = list(string.ascii_lowercase)[n_inputs : n_inputs + n_outputs]
117
+ if output_types is None:
118
+ output_types = [TensorProto.FLOAT] * n_outputs
119
+ is_sequence = [0 if id not in seq_outputs else 1 for id in range(n_outputs)]
120
+ is_optional = [
121
+ 0 if id not in optional_outputs else 1 for id in range(n_outputs)
122
+ ]
123
+ output_shapes_cast = cast(
124
+ "list[list[int]]",
125
+ [[0] if isinstance(shape, str) else shape for shape in output_shapes],
126
+ )
127
+ outputs: list[ValueInfoProto] = []
128
+ for name, ttype, shape, is_seq, is_opt in zip(
129
+ output_names, output_types, output_shapes_cast, is_sequence, is_optional
130
+ ):
131
+ if is_seq:
132
+ outputs += [helper.make_tensor_sequence_value_info(name, ttype, shape)]
133
+ elif is_opt:
134
+ type_proto = helper.make_tensor_type_proto(ttype, shape)
135
+ optional_type_proto = helper.make_optional_type_proto(type_proto)
136
+ outputs += [helper.make_value_info(name, optional_type_proto)]
137
+ else:
138
+ outputs += [helper.make_tensor_value_info(name, ttype, shape)]
139
+
140
+ node = helper.make_node(op, input_names, output_names, **attrs)
141
+ graph = helper.make_graph([node], op, inputs, outputs, initializer)
142
+ start_opset = from_opset if is_upgrade else LATEST_OPSET
143
+ end_opset = LATEST_OPSET if is_upgrade else from_opset
144
+ original = helper.make_model(
145
+ graph,
146
+ producer_name="test",
147
+ opset_imports=[helper.make_opsetid("", start_opset)],
148
+ )
149
+ self._test_model_conversion(end_opset, original)
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_downgrade_test.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ONNX Project Contributors
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
6
+ import unittest
7
+
8
+ import automatic_conversion_test_base
9
+ import numpy as np
10
+ import parameterized
11
+
12
+ import onnx
13
+ from onnx import helper
14
+
15
+ #####################################################################################
16
+ # Every test calls _test_op_conversion to downgrade a model from the most recent opset version
17
+ # to a early version and runs checker + shape inference on the downgraded model.
18
+ ####################################################################################
19
+
20
+
21
+ class TestAutomaticDowngrade(automatic_conversion_test_base.TestAutomaticConversion):
22
+ def _test_op_downgrade(self, op: str, *args, **kwargs):
23
+ self._test_op_conversion(op, *args, **kwargs, is_upgrade=False)
24
+
25
+ @parameterized.parameterized.expand(
26
+ [
27
+ "ReduceL1",
28
+ "ReduceL2",
29
+ "ReduceLogSum",
30
+ "ReduceLogSumExp",
31
+ "ReduceMean",
32
+ "ReduceMax",
33
+ "ReduceMin",
34
+ "ReduceProd",
35
+ "ReduceSum",
36
+ "ReduceSumSquare",
37
+ ]
38
+ )
39
+ def test_reduce_ops(self, op) -> None:
40
+ # TODO: need to add test cases for missing axes input which depends on this pr:
41
+ # https://github.com/onnx/onnx/pull/5613
42
+ axes = helper.make_tensor(
43
+ "b", onnx.TensorProto.INT64, dims=[3], vals=np.array([0, 1, 2])
44
+ )
45
+ self._test_op_downgrade(
46
+ op,
47
+ from_opset=13,
48
+ input_shapes=[[3, 4, 5], [3]],
49
+ output_shapes=[[1, 1, 1]],
50
+ input_types=[onnx.TensorProto.FLOAT, onnx.TensorProto.INT64],
51
+ initializer=[axes],
52
+ )
53
+
54
+ def test_dft20_no_axis(self) -> None:
55
+ self._test_model_conversion(
56
+ to_opset=19,
57
+ model="""
58
+ <ir_version: 9, opset_import: [ "" : 20]>
59
+ dft_no_axis (float[N, M, 1] x) => (float[N, M, 2] y)
60
+ {
61
+ y = DFT (x)
62
+ }
63
+ """,
64
+ )
65
+
66
+ def test_dft20_initializer_axis(self) -> None:
67
+ self._test_model_conversion(
68
+ to_opset=19,
69
+ model="""
70
+ <ir_version: 9, opset_import: [ "" : 20]>
71
+ dft_no_axis (float[N, M, 1] x, int64 dft_length) => (float[N, K, 2] y)
72
+ <int64 axis = {1}>
73
+ {
74
+ y = DFT (x, dft_length, axis)
75
+ }
76
+ """,
77
+ )
78
+
79
+ def test_dft20_constant_axis(self) -> None:
80
+ self._test_model_conversion(
81
+ to_opset=19,
82
+ model="""
83
+ <ir_version: 9, opset_import: [ "" : 20]>
84
+ dft_no_axis (float[N, M, 1] x, int64 dft_length) => (float[N, K, 2] y)
85
+ {
86
+ axis = Constant <value = int64{1}>()
87
+ y = DFT (x, dft_length, axis)
88
+ }
89
+ """,
90
+ )
91
+
92
+ def test_dft20_unknown_axis(self) -> None:
93
+ self._test_model_conversion_fails(
94
+ to_opset=19,
95
+ model="""
96
+ <ir_version: 9, opset_import: [ "" : 20]>
97
+ dft_no_axis (float[N, M, 1] x, int64 dft_length, int64 axis) => (float[P, K, 2] y)
98
+ {
99
+ y = DFT (x, dft_length, axis)
100
+ }
101
+ """,
102
+ )
103
+
104
+
105
+ if __name__ == "__main__":
106
+ unittest.main()
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_upgrade_test.py ADDED
@@ -0,0 +1,1964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ONNX Project Contributors
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
6
+ import unittest
7
+
8
+ import automatic_conversion_test_base
9
+ import numpy as np
10
+
11
+ import onnx
12
+ from onnx import TensorProto, helper
13
+
14
+ #####################################################################################
15
+ # Every test calls _test_op_conversion to upgrade a model from an initial opset version
16
+ # to the most recent version and runs checker and shape inference on the final upgraded model.
17
+ ####################################################################################
18
+
19
+
20
+ class TestAutomaticUpgrade(automatic_conversion_test_base.TestAutomaticConversion):
21
+ @classmethod
22
+ def setUpClass(cls):
23
+ cls.tested_ops = []
24
+
25
+ def _test_op_upgrade(self, op, *args, **kwargs):
26
+ self.tested_ops.append(op)
27
+ self._test_op_conversion(op, *args, **kwargs, is_upgrade=True)
28
+
29
+ def test_Abs(self) -> None:
30
+ self._test_op_upgrade("Abs", 1, attrs={"consumed_inputs": [0]})
31
+
32
+ def test_Acosh(self) -> None:
33
+ self._test_op_upgrade("Acosh", 9)
34
+
35
+ def test_Acos(self) -> None:
36
+ self._test_op_upgrade("Acos", 7)
37
+
38
+ def test_And(self) -> None:
39
+ # 6->7 adapter is missing
40
+ self._test_op_upgrade(
41
+ "And",
42
+ 7,
43
+ [[2, 3], [2, 3]],
44
+ [[2, 3]],
45
+ [TensorProto.BOOL, TensorProto.BOOL],
46
+ [TensorProto.BOOL],
47
+ )
48
+
49
+ def test_Asinh(self) -> None:
50
+ self._test_op_upgrade("Asinh", 9)
51
+
52
+ def test_Atanh(self) -> None:
53
+ self._test_op_upgrade("Atanh", 9)
54
+
55
+ def test_Add_1(self) -> None:
56
+ self._test_op_upgrade(
57
+ "Add", 1, [[3, 4, 5], [3, 4, 5]], attrs={"consumed_inputs": [0]}
58
+ )
59
+
60
+ def test_Add_2(self) -> None:
61
+ self._test_op_upgrade(
62
+ "Add", 1, [[3, 4, 5], [5]], attrs={"consumed_inputs": [0], "broadcast": 1}
63
+ )
64
+
65
+ def test_Add_3(self) -> None:
66
+ self._test_op_upgrade(
67
+ "Add",
68
+ 1,
69
+ [[3, 4, 5], [3]],
70
+ attrs={"consumed_inputs": [0], "broadcast": 1, "axis": 0},
71
+ )
72
+
73
+ def test_AffineGrid_2D(self) -> None:
74
+ N, _, H, W = 2, 3, 5, 6
75
+ self._test_op_upgrade("AffineGrid", 20, [[N, 2, 3], [4]], [[N, H, W, 2]])
76
+
77
+ def test_AffineGrid_3D(self) -> None:
78
+ N, _, D, H, W = 2, 3, 4, 5, 6
79
+ self._test_op_upgrade("AffineGrid", 20, [[N, 3, 4], [5]], [[N, D, H, W, 3]])
80
+
81
+ def test_ArgMax_1(self) -> None:
82
+ self._test_op_upgrade(
83
+ "ArgMax", 7, [[2, 3, 4]], [[1, 3, 4]], output_types=[TensorProto.INT64]
84
+ )
85
+
86
+ def test_ArgMax_2(self) -> None:
87
+ self._test_op_upgrade(
88
+ "ArgMax",
89
+ 7,
90
+ [[2, 3, 4]],
91
+ [[2, 1, 4]],
92
+ output_types=[TensorProto.INT64],
93
+ attrs={"axis": 1},
94
+ )
95
+
96
+ def test_ArgMin_1(self) -> None:
97
+ self._test_op_upgrade(
98
+ "ArgMin", 7, [[2, 3, 4]], [[1, 3, 4]], output_types=[TensorProto.INT64]
99
+ )
100
+
101
+ def test_ArgMin_2(self) -> None:
102
+ self._test_op_upgrade(
103
+ "ArgMin",
104
+ 7,
105
+ [[2, 3, 4]],
106
+ [[2, 1, 4]],
107
+ output_types=[TensorProto.INT64],
108
+ attrs={"axis": 1},
109
+ )
110
+
111
+ def test_Asin(self) -> None:
112
+ self._test_op_upgrade("Asin", 7)
113
+
114
+ def test_Atan(self) -> None:
115
+ self._test_op_upgrade("Atan", 7)
116
+
117
+ def test_Attention_1(self) -> None:
118
+ self._test_op_upgrade(
119
+ "Attention",
120
+ 23,
121
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
122
+ [[2, 3, 4, 8]],
123
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
124
+ [TensorProto.FLOAT],
125
+ )
126
+
127
+ def test_Attention_2(self) -> None:
128
+ self._test_op_upgrade(
129
+ "Attention",
130
+ 23,
131
+ [[2, 9, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
132
+ [[2, 9, 4, 8]],
133
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
134
+ [TensorProto.FLOAT],
135
+ )
136
+
137
+ def test_Attention_3(self) -> None:
138
+ self._test_op_upgrade(
139
+ "Attention",
140
+ 23,
141
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 10]],
142
+ [[2, 3, 4, 10]],
143
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
144
+ [TensorProto.FLOAT],
145
+ )
146
+
147
+ def test_Attention_4(self) -> None:
148
+ self._test_op_upgrade(
149
+ "Attention",
150
+ 23,
151
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
152
+ [[2, 3, 4, 8]],
153
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
154
+ [TensorProto.FLOAT],
155
+ attrs={"scale": 2.0},
156
+ )
157
+
158
+ def test_Attention_5(self) -> None:
159
+ self._test_op_upgrade(
160
+ "Attention",
161
+ 23,
162
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
163
+ [[2, 3, 4, 8]],
164
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
165
+ [TensorProto.FLOAT],
166
+ attrs={"is_causal": 1},
167
+ )
168
+
169
+ def test_Attention_6(self) -> None:
170
+ self._test_op_upgrade(
171
+ "Attention",
172
+ 23,
173
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8], [4, 6]],
174
+ [[2, 3, 4, 8]],
175
+ [
176
+ TensorProto.FLOAT,
177
+ TensorProto.FLOAT,
178
+ TensorProto.FLOAT,
179
+ TensorProto.FLOAT,
180
+ ],
181
+ [TensorProto.FLOAT],
182
+ )
183
+
184
+ def test_Attention_7(self) -> None:
185
+ self._test_op_upgrade(
186
+ "Attention",
187
+ 23,
188
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8], [4, 6]],
189
+ [[2, 3, 4, 8]],
190
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.BOOL],
191
+ [TensorProto.FLOAT],
192
+ )
193
+
194
+ def test_Attention_8(self) -> None:
195
+ self._test_op_upgrade(
196
+ "Attention",
197
+ 23,
198
+ [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
199
+ [[2, 3, 4, 8]],
200
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
201
+ [TensorProto.FLOAT],
202
+ attrs={"softcap": 2.0},
203
+ )
204
+
205
+ def test_AveragePool(self) -> None:
206
+ self._test_op_upgrade(
207
+ "AveragePool",
208
+ 1,
209
+ [[1, 1, 5, 5]],
210
+ [[1, 1, 4, 4]],
211
+ attrs={"kernel_shape": [2, 2]},
212
+ )
213
+
214
+ def test_Bernoulli(self) -> None:
215
+ self._test_op_upgrade("Bernoulli", 15)
216
+
217
+ def test_BitShift(self) -> None:
218
+ self._test_op_upgrade(
219
+ "BitShift",
220
+ 11,
221
+ [[2, 3], [2, 3]],
222
+ [[2, 3]],
223
+ [TensorProto.UINT8, TensorProto.UINT8],
224
+ [TensorProto.UINT8],
225
+ attrs={"direction": "RIGHT"},
226
+ )
227
+
228
+ def test_BatchNormalization_1(self) -> None:
229
+ self._test_op_upgrade(
230
+ "BatchNormalization",
231
+ 1,
232
+ [[1, 3], [3], [3], [3], [3]],
233
+ [[1, 3]],
234
+ attrs={"consumed_inputs": [1, 1], "is_test": 1, "spatial": 1},
235
+ )
236
+
237
+ def test_BatchNormalization_2(self) -> None:
238
+ self._test_op_upgrade(
239
+ "BatchNormalization",
240
+ 14,
241
+ [[1, 3], [3], [3], [3], [3]],
242
+ [[1, 3], [3], [3]],
243
+ attrs={"training_mode": 1},
244
+ )
245
+
246
+ def test_Cast(self) -> None:
247
+ # 5->6 adapter is missing
248
+ self._test_op_upgrade(
249
+ "Cast", 6, [[2, 3]], [[2, 3]], [TensorProto.INT64], attrs={"to": 1}
250
+ )
251
+
252
+ def test_Ceil(self) -> None:
253
+ self._test_op_upgrade("Ceil", 1, attrs={"consumed_inputs": [0]})
254
+
255
+ def test_Celu(self) -> None:
256
+ self._test_op_upgrade("Celu", 12)
257
+
258
+ def test_Clip_1(self) -> None:
259
+ self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0]})
260
+
261
+ def test_Clip_2(self) -> None:
262
+ self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0], "min": -1.4})
263
+
264
+ def test_Clip_3(self) -> None:
265
+ self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0], "max": 2.6})
266
+
267
+ def test_Clip_4(self) -> None:
268
+ self._test_op_upgrade(
269
+ "Clip", 1, attrs={"consumed_inputs": [0], "min": -1.4, "max": 2.6}
270
+ )
271
+
272
+ def test_Col2Im_4D(self) -> None:
273
+ self._test_op_upgrade("Col2Im", 18, [[1, 5, 5], [2], [2]], [[1, 1, 5, 5]])
274
+
275
+ def test_Col2Im_5D(self) -> None:
276
+ self._test_op_upgrade("Col2Im", 18, [[1, 10, 12], [3], [3]], [[1, 2, 3, 4, 5]])
277
+
278
+ def test_Compress(self) -> None:
279
+ self._test_op_upgrade(
280
+ "Compress",
281
+ 9,
282
+ [[6, 7], [3]],
283
+ [[3]],
284
+ [TensorProto.FLOAT, TensorProto.BOOL],
285
+ [TensorProto.FLOAT],
286
+ )
287
+
288
+ def test_Concat(self) -> None:
289
+ self._test_op_upgrade("Concat", 1, [[2, 3], [2, 4]], [[2, 7]])
290
+
291
+ def test_constant(self) -> None:
292
+ value = helper.make_tensor(
293
+ "Value",
294
+ TensorProto.FLOAT,
295
+ dims=[3, 4, 5],
296
+ vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
297
+ raw=True,
298
+ )
299
+ self._test_op_upgrade("Constant", 1, [], attrs={"value": value})
300
+
301
+ def test_ConstantOfShape(self) -> None:
302
+ self._test_op_upgrade("ConstantOfShape", 9, [[3]])
303
+
304
+ def test_Conv_1(self) -> None:
305
+ self._test_op_upgrade(
306
+ "Conv", 1, [[1, 3, 5, 5], [4, 3, 2, 2], [4]], [[1, 4, 4, 4]]
307
+ )
308
+
309
+ def test_Conv_2(self) -> None:
310
+ self._test_op_upgrade(
311
+ "Conv", 1, [[1, 3, 5, 5], [4, 3, 2, 2], [4]], [[1, 4, 4, 4]]
312
+ )
313
+
314
+ def test_Conv_3(self) -> None:
315
+ self._test_op_upgrade(
316
+ "Conv",
317
+ 1,
318
+ [[1, 3, 5, 5], [4, 1, 2, 2], [4]],
319
+ [[1, 4, 3, 7]],
320
+ attrs={
321
+ "dilations": [1, 2],
322
+ "group": 3,
323
+ "pads": [0, 1, 2, 3],
324
+ "strides": [2, 1],
325
+ },
326
+ )
327
+
328
+ def test_Convinteger(self) -> None:
329
+ self._test_op_upgrade(
330
+ "ConvInteger",
331
+ 10,
332
+ [[1, 3, 5, 5], [4, 3, 2, 2], [4]],
333
+ [[1, 4, 4, 4]],
334
+ [TensorProto.UINT8, TensorProto.UINT8, TensorProto.UINT8],
335
+ [TensorProto.INT32],
336
+ )
337
+
338
+ def test_ConvTranspose(self) -> None:
339
+ self._test_op_upgrade(
340
+ "ConvTranspose", 1, [[1, 1, 5, 5], [1, 1, 3, 3]], [[1, 1, 7, 7]]
341
+ )
342
+
343
+ def test_DeformConv(self) -> None:
344
+ self._test_op_upgrade(
345
+ "DeformConv",
346
+ 19,
347
+ [[1, 1, 3, 3], [1, 1, 2, 2], [1, 8, 2, 2]],
348
+ [[1, 1, 2, 2]],
349
+ )
350
+
351
+ def test_Cosh(self) -> None:
352
+ self._test_op_upgrade("Cosh", 9)
353
+
354
+ def test_Cos(self) -> None:
355
+ self._test_op_upgrade("Cos", 7)
356
+
357
+ def test_Cumsum(self) -> None:
358
+ self._test_op_upgrade(
359
+ "CumSum",
360
+ 11,
361
+ [[3, 4, 5], []],
362
+ [[3, 4, 5]],
363
+ [TensorProto.FLOAT, TensorProto.INT64],
364
+ )
365
+
366
+ def test_DepthToSpace(self) -> None:
367
+ self._test_op_upgrade(
368
+ "DepthToSpace", 1, [[1, 8, 3, 3]], [[1, 2, 6, 6]], attrs={"blocksize": 2}
369
+ )
370
+
371
+ def test_DequantizeLinear(self) -> None:
372
+ self._test_op_upgrade(
373
+ "DequantizeLinear",
374
+ 10,
375
+ [[2, 3], [], []],
376
+ [[2, 3]],
377
+ [TensorProto.INT8, TensorProto.FLOAT, TensorProto.INT8],
378
+ )
379
+
380
+ def test_Det_1(self) -> None:
381
+ self._test_op_upgrade("Det", 11, [[3, 5, 5]], [[3]])
382
+
383
+ def test_Det_2(self) -> None:
384
+ self._test_op_upgrade("Det", 11, [[5, 5]], [[]])
385
+
386
+ def test_DynamicQuantizeLinear(self) -> None:
387
+ self._test_op_upgrade(
388
+ "DynamicQuantizeLinear",
389
+ 11,
390
+ [[3, 4, 5]],
391
+ [[3, 4, 5], [], []],
392
+ output_types=[TensorProto.UINT8, TensorProto.FLOAT, TensorProto.UINT8],
393
+ )
394
+
395
+ def test_Div(self) -> None:
396
+ self._test_op_upgrade(
397
+ "Div", 1, [[3, 4, 5], [3, 1, 5]], attrs={"consumed_inputs": [0]}
398
+ )
399
+
400
+ def test_Dropout(self) -> None:
401
+ self._test_op_upgrade(
402
+ "Dropout", 1, attrs={"consumed_inputs": [0], "is_test": 1}
403
+ )
404
+
405
+ def test_Einsum_1(self) -> None:
406
+ self._test_op_upgrade(
407
+ "Einsum",
408
+ 12,
409
+ [[3, 4, 5], [3, 5, 6]],
410
+ [[3, 4, 6]],
411
+ attrs={"equation": "bij, bjk -> bik"},
412
+ )
413
+
414
+ def test_Einsum_2(self) -> None:
415
+ self._test_op_upgrade(
416
+ "Einsum", 12, [[4, 5]], [[5, 4]], attrs={"equation": "ij->ji"}
417
+ )
418
+
419
+ def test_Elu(self) -> None:
420
+ self._test_op_upgrade("Elu", 1, attrs={"consumed_inputs": [0]})
421
+
422
+ def test_Equal(self) -> None:
423
+ # 6->7 adapter is missing
424
+ self._test_op_upgrade(
425
+ "Equal", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
426
+ )
427
+
428
+ def test_Erf(self) -> None:
429
+ self._test_op_upgrade("Erf", 9)
430
+
431
+ def test_Exp(self) -> None:
432
+ self._test_op_upgrade("Exp", 1, attrs={"consumed_inputs": [0]})
433
+
434
+ def test_Expand(self) -> None:
435
+ shape = helper.make_tensor(
436
+ "b", TensorProto.INT64, dims=[4], vals=np.array([5, 2, 6, 4])
437
+ )
438
+ self._test_op_upgrade(
439
+ "Expand",
440
+ 8,
441
+ [[2, 1, 4], [4]],
442
+ [[5, 2, 6, 4]],
443
+ [TensorProto.FLOAT, TensorProto.INT64],
444
+ initializer=[shape],
445
+ )
446
+
447
+ def test_EyeLike(self) -> None:
448
+ self._test_op_upgrade("EyeLike", 9, [[4, 5]], [[4, 5]])
449
+
450
+ def test_Flatten(self) -> None:
451
+ self._test_op_upgrade("Flatten", 1, [[3, 4, 5]], [[3, 20]], attrs={"axis": 1})
452
+
453
+ def test_Floor(self) -> None:
454
+ self._test_op_upgrade("Floor", 1, attrs={"consumed_inputs": [0]})
455
+
456
+ def test_Gather(self) -> None:
457
+ self._test_op_upgrade(
458
+ "Gather",
459
+ 1,
460
+ [[3, 4, 5], [6, 7]],
461
+ [[6, 7, 4, 5]],
462
+ [TensorProto.FLOAT, TensorProto.INT64],
463
+ )
464
+
465
+ def test_GatherElements(self) -> None:
466
+ self._test_op_upgrade(
467
+ "GatherElements",
468
+ 11,
469
+ [[3, 4, 5], [6, 7]],
470
+ [[6, 7]],
471
+ [TensorProto.FLOAT, TensorProto.INT64],
472
+ )
473
+
474
+ def test_GatherND(self) -> None:
475
+ self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]])
476
+
477
+ def test_Gelu_approximate_tanh(self) -> None:
478
+ self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"})
479
+
480
+ def test_Gelu(self) -> None:
481
+ self._test_op_upgrade("Gelu", 20)
482
+
483
+ def test_Gemm(self) -> None:
484
+ self._test_op_upgrade("Gemm", 1, [[5, 4], [4, 3], [3]], [[5, 3]])
485
+
486
+ def test_GlobalAveragePool(self) -> None:
487
+ self._test_op_upgrade("GlobalAveragePool", 1, [[1, 3, 10, 10]], [[1, 3, 1, 1]])
488
+
489
+ def test_GlobalMaxPool(self) -> None:
490
+ self._test_op_upgrade("GlobalMaxPool", 1, [[1, 3, 10, 10]], [[1, 3, 1, 1]])
491
+
492
+ def test_GlobalLpPool(self) -> None:
493
+ # 1->2 adapter is missing
494
+ self._test_op_upgrade("GlobalLpPool", 2, [[1, 3, 10, 10]], [[1, 3, 1, 1]])
495
+
496
+ def test_Greater(self) -> None:
497
+ # 6->7 adapter is missing
498
+ self._test_op_upgrade(
499
+ "Greater", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
500
+ )
501
+
502
+ def test_GreaterOrEqual(self) -> None:
503
+ self._test_op_upgrade(
504
+ "GreaterOrEqual",
505
+ 12,
506
+ [[2, 3], [2, 3]],
507
+ [[2, 3]],
508
+ output_types=[TensorProto.BOOL],
509
+ )
510
+
511
+ def test_GridSample(self) -> None:
512
+ self._test_op_upgrade(
513
+ "GridSample",
514
+ 16,
515
+ [[1, 1, 3, 3], [1, 3, 3, 2]],
516
+ [[1, 1, 3, 3]],
517
+ input_types=[TensorProto.FLOAT, TensorProto.FLOAT],
518
+ output_types=[TensorProto.FLOAT],
519
+ attrs={"mode": "nearest", "padding_mode": "border", "align_corners": 1},
520
+ )
521
+
522
+ def test_GRU_1(self) -> None:
523
+ # 2->3, 6->7 adapters are missing
524
+ self._test_op_upgrade(
525
+ "GRU",
526
+ 7,
527
+ [[5, 3, 4], [1, 18, 4], [1, 18, 4]],
528
+ [[5, 1, 3, 6], [1, 3, 6]],
529
+ attrs={"hidden_size": 6},
530
+ )
531
+
532
+ def test_GRU_2(self) -> None:
533
+ # 2->3, 6->7 adapters are missing
534
+ self._test_op_upgrade(
535
+ "GRU",
536
+ 7,
537
+ [[5, 3, 4], [2, 18, 4], [2, 18, 4]],
538
+ [[5, 2, 3, 6], [2, 3, 6]],
539
+ attrs={"hidden_size": 6, "direction": "bidirectional"},
540
+ )
541
+
542
+ def test_GRU_3(self) -> None:
543
+ # 2->3, 6->7 adapters are missing
544
+ self._test_op_upgrade(
545
+ "GRU",
546
+ 7,
547
+ [[5, 3, 4], [1, 18, 4], [1, 18, 4], [1, 24], [5], [1, 5, 6]],
548
+ [[5, 1, 3, 6], [1, 3, 6]],
549
+ [
550
+ TensorProto.FLOAT,
551
+ TensorProto.FLOAT,
552
+ TensorProto.FLOAT,
553
+ TensorProto.FLOAT,
554
+ TensorProto.INT64,
555
+ TensorProto.FLOAT,
556
+ ],
557
+ attrs={"hidden_size": 6},
558
+ )
559
+
560
+ def test_HardSigmoid(self) -> None:
561
+ self._test_op_upgrade("HardSigmoid", 1, attrs={"consumed_inputs": [0]})
562
+
563
+ def test_HardSwish(self) -> None:
564
+ self._test_op_upgrade("HardSwish", 14)
565
+
566
+ def test_Hardmax(self) -> None:
567
+ self._test_op_upgrade("Hardmax", 1)
568
+
569
+ def test_Identity(self) -> None:
570
+ self._test_op_upgrade("Identity", 1)
571
+
572
+ def test_If(self) -> None:
573
+ sub_output = [
574
+ helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 4, 5])
575
+ ]
576
+ then_tensor = helper.make_tensor(
577
+ "Value",
578
+ TensorProto.FLOAT,
579
+ dims=[3, 4, 5],
580
+ vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
581
+ raw=True,
582
+ )
583
+ then_node = helper.make_node("Constant", [], ["out"], value=then_tensor)
584
+ then_graph = helper.make_graph([then_node], "then_graph", [], sub_output, [])
585
+ else_tensor = helper.make_tensor(
586
+ "Value",
587
+ TensorProto.FLOAT,
588
+ dims=[3, 4, 5],
589
+ vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
590
+ raw=True,
591
+ )
592
+ else_node = helper.make_node("Constant", [], ["out"], value=else_tensor)
593
+ else_graph = helper.make_graph([else_node], "else_graph", [], sub_output, [])
594
+ self._test_op_upgrade(
595
+ "If",
596
+ 1,
597
+ [[0]],
598
+ [[3, 4, 5]],
599
+ [TensorProto.BOOL],
600
+ attrs={"then_branch": then_graph, "else_branch": else_graph},
601
+ )
602
+
603
+ def test_ImageDecoder(self) -> None:
604
+ self._test_op_upgrade(
605
+ "ImageDecoder",
606
+ 20,
607
+ [[None]],
608
+ [[None, None, 3]],
609
+ input_types=[TensorProto.UINT8],
610
+ output_types=[TensorProto.UINT8],
611
+ )
612
+
613
+ def test_InstanceNormalization(self) -> None:
614
+ self._test_op_upgrade(
615
+ "InstanceNormalization",
616
+ 1,
617
+ [[1, 3], [3], [3]],
618
+ [[1, 3]],
619
+ attrs={"consumed_inputs": [0]},
620
+ )
621
+
622
+ def test_IsInf(self) -> None:
623
+ self._test_op_upgrade(
624
+ "IsInf", 10, [[2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
625
+ )
626
+
627
+ def test_IsNaN(self) -> None:
628
+ self._test_op_upgrade(
629
+ "IsNaN", 9, [[2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
630
+ )
631
+
632
+ def test_LeakyRelu(self) -> None:
633
+ self._test_op_upgrade("LeakyRelu", 1, attrs={"consumed_inputs": [0]})
634
+
635
+ def test_Less(self) -> None:
636
+ # 6->7 adapter is missing
637
+ self._test_op_upgrade(
638
+ "Less", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
639
+ )
640
+
641
+ def test_LessOrEqual(self) -> None:
642
+ self._test_op_upgrade(
643
+ "LessOrEqual",
644
+ 12,
645
+ [[2, 3], [2, 3]],
646
+ [[2, 3]],
647
+ output_types=[TensorProto.BOOL],
648
+ )
649
+
650
+ def test_Log(self) -> None:
651
+ self._test_op_upgrade("Log", 1, attrs={"consumed_inputs": [0]})
652
+
653
+ def test_LogSoftmax(self) -> None:
654
+ self._test_op_upgrade("LogSoftmax", 1)
655
+
656
+ def test_Loop_1(self) -> None:
657
+ iter_count = onnx.helper.make_tensor_value_info(
658
+ "iter_count", onnx.TensorProto.INT64, []
659
+ )
660
+ cond_in = onnx.helper.make_tensor_value_info(
661
+ "cond_in", onnx.TensorProto.BOOL, []
662
+ )
663
+ x_in = onnx.helper.make_tensor_value_info("x_in", onnx.TensorProto.FLOAT, [1])
664
+ cond_out = onnx.helper.make_tensor_value_info(
665
+ "cond_out", onnx.TensorProto.BOOL, []
666
+ )
667
+ x_out = onnx.helper.make_tensor_value_info("x_out", onnx.TensorProto.FLOAT, [1])
668
+ x_scan = onnx.helper.make_tensor_value_info(
669
+ "x_scan", onnx.TensorProto.FLOAT, [1]
670
+ )
671
+ const = onnx.helper.make_node(
672
+ "Constant",
673
+ inputs=[],
674
+ outputs=["one"],
675
+ value=onnx.helper.make_tensor(
676
+ name="value",
677
+ data_type=onnx.TensorProto.FLOAT,
678
+ dims=[1],
679
+ vals=np.array([1]).astype(np.float32).astype(float),
680
+ ),
681
+ )
682
+ add = onnx.helper.make_node("Add", inputs=["x_in", "one"], outputs=["x_out"])
683
+ id_1 = onnx.helper.make_node("Identity", inputs=["x_out"], outputs=["x_scan"])
684
+ id_2 = onnx.helper.make_node(
685
+ "Identity", inputs=["cond_in"], outputs=["cond_out"]
686
+ )
687
+ loop_body = onnx.helper.make_graph(
688
+ [const, add, id_1, id_2],
689
+ "loop_body",
690
+ [iter_count, cond_in, x_in],
691
+ [cond_out, x_out, x_scan],
692
+ )
693
+ self._test_op_upgrade(
694
+ "Loop",
695
+ 1,
696
+ [[], "", [1]],
697
+ [[1], [5, 1]],
698
+ [TensorProto.INT64, TensorProto.BOOL, TensorProto.FLOAT],
699
+ attrs={"body": loop_body},
700
+ )
701
+
702
+ def test_Loop_2(self) -> None:
703
+ iter_count = onnx.helper.make_tensor_value_info(
704
+ "iter_count", onnx.TensorProto.INT64, []
705
+ )
706
+ cond_in = onnx.helper.make_tensor_value_info(
707
+ "cond_in", onnx.TensorProto.BOOL, []
708
+ )
709
+ x_in = onnx.helper.make_tensor_value_info(
710
+ "x_in", onnx.TensorProto.FLOAT, [2, 1]
711
+ )
712
+ cond_out = onnx.helper.make_tensor_value_info(
713
+ "cond_out", onnx.TensorProto.BOOL, []
714
+ )
715
+ x_out = onnx.helper.make_tensor_value_info(
716
+ "x_out", onnx.TensorProto.FLOAT, [2, 1]
717
+ )
718
+ squeeze = onnx.helper.make_node(
719
+ "Squeeze", inputs=["x_in"], outputs=["squeeze_out"], axes=[1]
720
+ )
721
+ unsqueeze = onnx.helper.make_node(
722
+ "Unsqueeze", inputs=["squeeze_out"], outputs=["x_out"], axes=[1]
723
+ )
724
+ identity = onnx.helper.make_node(
725
+ "Identity", inputs=["cond_in"], outputs=["cond_out"]
726
+ )
727
+ loop_body = onnx.helper.make_graph(
728
+ [squeeze, unsqueeze, identity],
729
+ "loop_body",
730
+ [iter_count, cond_in, x_in],
731
+ [cond_out, x_out],
732
+ )
733
+ self._test_op_upgrade(
734
+ "Loop",
735
+ 12,
736
+ [[], "", [2, 1]],
737
+ [[2, 1]],
738
+ [TensorProto.INT64, TensorProto.BOOL, TensorProto.FLOAT],
739
+ attrs={"body": loop_body},
740
+ )
741
+
742
+ def test_LpNormalization(self) -> None:
743
+ self._test_op_upgrade("LpNormalization", 1)
744
+
745
+ def test_LpPool(self) -> None:
746
+ # 1->2 adapter is missing
747
+ self._test_op_upgrade(
748
+ "LpPool", 2, [[1, 1, 5, 5]], [[1, 1, 4, 4]], attrs={"kernel_shape": [2, 2]}
749
+ )
750
+
751
+ def test_LRN_1(self) -> None:
752
+ self._test_op_upgrade("LRN", 1, attrs={"size": 3})
753
+
754
+ def test_LRN_2(self) -> None:
755
+ self._test_op_upgrade(
756
+ "LRN", 1, [[2, 3, 4, 5]], [[2, 3, 4, 5]], attrs={"size": 3}
757
+ )
758
+
759
+ def test_LSTM_1(self) -> None:
760
+ # 6->7 adapter is missing
761
+ self._test_op_upgrade(
762
+ "LSTM",
763
+ 7,
764
+ [[5, 3, 4], [1, 24, 4], [1, 24, 4]],
765
+ [[5, 1, 3, 6], [1, 3, 6], [1, 3, 6]],
766
+ attrs={"hidden_size": 6},
767
+ )
768
+
769
+ def test_LSTM_2(self) -> None:
770
+ # 6->7 adapter is missing
771
+ self._test_op_upgrade(
772
+ "LSTM",
773
+ 7,
774
+ [[5, 3, 4], [2, 24, 4], [2, 24, 4]],
775
+ [[5, 2, 3, 6], [2, 3, 6], [2, 3, 6]],
776
+ attrs={"hidden_size": 6, "direction": "bidirectional"},
777
+ )
778
+
779
+ def test_LSTM_3(self) -> None:
780
+ # 6->7 adapter is missing
781
+ self._test_op_upgrade(
782
+ "LSTM",
783
+ 7,
784
+ [
785
+ [5, 3, 4],
786
+ [1, 24, 4],
787
+ [1, 24, 4],
788
+ [1, 48],
789
+ [5],
790
+ [1, 5, 6],
791
+ [1, 5, 6],
792
+ [1, 18],
793
+ ],
794
+ [[5, 1, 3, 6], [1, 3, 6], [1, 3, 6]],
795
+ [
796
+ TensorProto.FLOAT,
797
+ TensorProto.FLOAT,
798
+ TensorProto.FLOAT,
799
+ TensorProto.FLOAT,
800
+ TensorProto.INT64,
801
+ TensorProto.FLOAT,
802
+ TensorProto.FLOAT,
803
+ TensorProto.FLOAT,
804
+ ],
805
+ attrs={"hidden_size": 6},
806
+ )
807
+
808
+ def test_MatMul_1(self) -> None:
809
+ self._test_op_upgrade("MatMul", 1, [[2, 3], [3, 4]], [[2, 4]])
810
+
811
+ def test_MatMul_2(self) -> None:
812
+ self._test_op_upgrade("MatMul", 1, [[5, 2, 3], [5, 3, 4]], [[5, 2, 4]])
813
+
814
+ def test_MatMulInteger_1(self) -> None:
815
+ self._test_op_upgrade(
816
+ "MatMulInteger",
817
+ 10,
818
+ [[2, 3], [3, 4]],
819
+ [[2, 4]],
820
+ [TensorProto.INT8, TensorProto.INT8],
821
+ [TensorProto.INT32],
822
+ )
823
+
824
+ def test_MatMulInteger_2(self) -> None:
825
+ self._test_op_upgrade(
826
+ "MatMulInteger",
827
+ 10,
828
+ [[2, 3], [3, 4], [], []],
829
+ [[2, 4]],
830
+ [TensorProto.INT8, TensorProto.INT8, TensorProto.INT8, TensorProto.INT8],
831
+ [TensorProto.INT32],
832
+ )
833
+
834
+ def test_MatMulInteger_3(self) -> None:
835
+ self._test_op_upgrade(
836
+ "MatMulInteger",
837
+ 10,
838
+ [[2, 3], [3, 4], [2], [4]],
839
+ [[2, 4]],
840
+ [TensorProto.INT8, TensorProto.INT8, TensorProto.INT8, TensorProto.INT8],
841
+ [TensorProto.INT32],
842
+ )
843
+
844
+ def test_Max(self) -> None:
845
+ self._test_op_upgrade(
846
+ "Max",
847
+ 1,
848
+ [[2, 3, 4], [2, 3, 4]],
849
+ [[2, 3, 4]],
850
+ attrs={"consumed_inputs": [0]},
851
+ )
852
+
853
+ def test_MaxPool_1(self) -> None:
854
+ self._test_op_upgrade(
855
+ "MaxPool", 1, [[1, 1, 5, 5]], [[1, 1, 4, 4]], attrs={"kernel_shape": [2, 2]}
856
+ )
857
+
858
+ def test_MaxPool_2(self) -> None:
859
+ self._test_op_upgrade(
860
+ "MaxPool",
861
+ 8,
862
+ [[1, 1, 5, 5]],
863
+ [[1, 1, 4, 4], [1, 1, 4, 4]],
864
+ output_types=[TensorProto.FLOAT, TensorProto.INT64],
865
+ attrs={"kernel_shape": [2, 2]},
866
+ )
867
+
868
+ def test_MaxRoiPool(self) -> None:
869
+ self._test_op_upgrade(
870
+ "MaxRoiPool",
871
+ 1,
872
+ [[2, 3, 20, 20], [4, 5]],
873
+ [[4, 3, 3, 3]],
874
+ attrs={"pooled_shape": [3, 3]},
875
+ )
876
+
877
+ def test_MaxUnpool(self) -> None:
878
+ self._test_op_upgrade(
879
+ "MaxUnpool",
880
+ 9,
881
+ [[1, 1, 5, 5], [1, 1, 5, 5]],
882
+ [[1, 1, 6, 6]],
883
+ [TensorProto.FLOAT, TensorProto.INT64],
884
+ attrs={"kernel_shape": [2, 2]},
885
+ )
886
+
887
+ def test_Mean(self) -> None:
888
+ self._test_op_upgrade(
889
+ "Mean",
890
+ 1,
891
+ [[2, 3, 4], [2, 3, 4]],
892
+ [[2, 3, 4]],
893
+ attrs={"consumed_inputs": [0]},
894
+ )
895
+
896
+ def test_MeanVarianceNormalization(self) -> None:
897
+ self._test_op_upgrade("MeanVarianceNormalization", 9, attrs={"axes": [1, 2]})
898
+
899
+ def test_Min(self) -> None:
900
+ self._test_op_upgrade(
901
+ "Min",
902
+ 1,
903
+ [[2, 3, 4], [2, 3, 4]],
904
+ [[2, 3, 4]],
905
+ attrs={"consumed_inputs": [0]},
906
+ )
907
+
908
+ def test_Mish(self) -> None:
909
+ self._test_op_upgrade("Mish", 18)
910
+
911
+ def test_Mod_1(self) -> None:
912
+ self._test_op_upgrade("Mod", 10, [[2, 3], [2, 3]], [[2, 3]])
913
+
914
+ def test_Mod_2(self) -> None:
915
+ self._test_op_upgrade("Mod", 10, [[2, 3], [2, 3]], [[2, 3]], attrs={"fmod": 1})
916
+
917
+ def test_Mul(self) -> None:
918
+ self._test_op_upgrade(
919
+ "Mul",
920
+ 1,
921
+ [[2, 3, 4], [2, 1, 4]],
922
+ [[2, 3, 4]],
923
+ attrs={"consumed_inputs": [0]},
924
+ )
925
+
926
+ def test_Multinomial(self) -> None:
927
+ self._test_op_upgrade(
928
+ "Multinomial",
929
+ 7,
930
+ [[3, 5]],
931
+ [[3, 7]],
932
+ output_types=[TensorProto.INT32],
933
+ attrs={"sample_size": 7},
934
+ )
935
+
936
+ def test_Neg(self) -> None:
937
+ self._test_op_upgrade("Neg", 1, attrs={"consumed_inputs": [0]})
938
+
939
+ def test_NegativeLogLikelihoodLoss_1(self) -> None:
940
+ self._test_op_upgrade(
941
+ "NegativeLogLikelihoodLoss",
942
+ 12,
943
+ [[3, 4, 5], [3, 5]],
944
+ [[]],
945
+ [TensorProto.FLOAT, TensorProto.INT64],
946
+ )
947
+
948
+ def test_NegativeLogLikelihoodLoss_2(self) -> None:
949
+ self._test_op_upgrade(
950
+ "NegativeLogLikelihoodLoss",
951
+ 12,
952
+ [[3, 4, 5], [3, 5], [4]],
953
+ [[]],
954
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
955
+ )
956
+
957
+ def test_NonMaxSuppression(self) -> None:
958
+ self._test_op_upgrade(
959
+ "NonMaxSuppression",
960
+ 10,
961
+ [[2, 3, 4], [3, 5, 6]],
962
+ [[2, 3]],
963
+ output_types=[TensorProto.INT64],
964
+ )
965
+
966
+ def test_NonZero(self) -> None:
967
+ self._test_op_upgrade(
968
+ "NonZero", 9, [[3, 3]], [[2, 4]], output_types=[TensorProto.INT64]
969
+ )
970
+
971
+ def test_Not(self) -> None:
972
+ self._test_op_upgrade(
973
+ "Not", 1, [[2, 3]], [[2, 3]], [TensorProto.BOOL], [TensorProto.BOOL]
974
+ )
975
+
976
+ def test_OneHot(self) -> None:
977
+ self._test_op_upgrade("OneHot", 9, [[3, 4, 5], [], [2]], [[3, 4, 5, 6]])
978
+
979
+ def test_Or(self) -> None:
980
+ # 6->7 adapter is missing
981
+ self._test_op_upgrade(
982
+ "Or",
983
+ 7,
984
+ [[2, 3], [2, 3]],
985
+ [[2, 3]],
986
+ [TensorProto.BOOL, TensorProto.BOOL],
987
+ [TensorProto.BOOL],
988
+ )
989
+
990
+ def test_Pad(self) -> None:
991
+ # 1->2 adapter is missing
992
+ self._test_op_upgrade(
993
+ "Pad", 2, [[3, 4]], [[5, 8]], attrs={"pads": [1, 2, 1, 2], "value": 1.5}
994
+ )
995
+
996
+ def test_Pow(self) -> None:
997
+ self._test_op_upgrade("Pow", 1, [[2, 3, 4], [2, 3, 4]], [[2, 3, 4]])
998
+
999
+ def test_PRelu(self) -> None:
1000
+ self._test_op_upgrade(
1001
+ "PRelu",
1002
+ 1,
1003
+ [[2, 3, 4], [2, 3, 4]],
1004
+ [[2, 3, 4]],
1005
+ attrs={"consumed_inputs": [0]},
1006
+ )
1007
+
1008
+ def test_QLinearConv(self) -> None:
1009
+ self._test_op_upgrade(
1010
+ "QLinearConv",
1011
+ 10,
1012
+ [[1, 3, 5, 5], [], [], [4, 3, 2, 2], [], [], [], []],
1013
+ [[1, 4, 4, 4]],
1014
+ )
1015
+
1016
+ def test_QLinearMatMul(self) -> None:
1017
+ self._test_op_upgrade(
1018
+ "QLinearMatMul", 10, [[2, 3], [], [], [3, 4], [], [], [], []], [[2, 4]]
1019
+ )
1020
+
1021
+ def test_QuantizeLinear(self) -> None:
1022
+ self._test_op_upgrade(
1023
+ "QuantizeLinear",
1024
+ 10,
1025
+ [[3, 4, 5], [], []],
1026
+ [[3, 4, 5]],
1027
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.UINT8],
1028
+ [TensorProto.UINT8],
1029
+ )
1030
+
1031
+ def test_RandomNormal(self) -> None:
1032
+ self._test_op_upgrade(
1033
+ "RandomNormal", 1, [], [[3, 4, 5]], attrs={"shape": [3, 4, 5]}
1034
+ )
1035
+
1036
+ def test_RandomNormalLike(self) -> None:
1037
+ like = helper.make_tensor(
1038
+ "a",
1039
+ TensorProto.FLOAT,
1040
+ dims=[3, 4, 5],
1041
+ vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
1042
+ raw=True,
1043
+ )
1044
+ self._test_op_upgrade(
1045
+ "RandomNormalLike", 1, [[3, 4, 5]], [[3, 4, 5]], initializer=[like]
1046
+ )
1047
+
1048
+ def test_RandomUniform(self) -> None:
1049
+ self._test_op_upgrade(
1050
+ "RandomUniform", 1, [], [[3, 4, 5]], attrs={"shape": [3, 4, 5]}
1051
+ )
1052
+
1053
+ def test_RandomUniformLike(self) -> None:
1054
+ like = helper.make_tensor(
1055
+ "a",
1056
+ TensorProto.FLOAT,
1057
+ dims=[3, 4, 5],
1058
+ vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
1059
+ raw=True,
1060
+ )
1061
+ self._test_op_upgrade(
1062
+ "RandomUniformLike", 1, [[3, 4, 5]], [[3, 4, 5]], initializer=[like]
1063
+ )
1064
+
1065
+ def test_Range(self) -> None:
1066
+ start = helper.make_tensor("a", TensorProto.FLOAT, dims=[], vals=np.array([0]))
1067
+ end = helper.make_tensor("b", TensorProto.FLOAT, dims=[], vals=np.array([12]))
1068
+ step = helper.make_tensor("c", TensorProto.FLOAT, dims=[], vals=np.array([2]))
1069
+ self._test_op_upgrade(
1070
+ "Range", 11, [[], [], []], [[6]], initializer=[start, end, step]
1071
+ )
1072
+
1073
+ def test_Reciprocal(self) -> None:
1074
+ self._test_op_upgrade("Reciprocal", 1, attrs={"consumed_inputs": [0]})
1075
+
1076
+ def test_ReduceL1(self) -> None:
1077
+ self._test_op_upgrade("ReduceL1", 1, [[3, 4, 5]], [[1, 1, 1]])
1078
+
1079
+ def test_ReduceL2(self) -> None:
1080
+ self._test_op_upgrade("ReduceL2", 1, [[3, 4, 5]], [[1, 1, 1]])
1081
+
1082
+ def test_ReduceLogSum(self) -> None:
1083
+ self._test_op_upgrade("ReduceLogSum", 1, [[3, 4, 5]], [[1, 1, 1]])
1084
+
1085
+ def test_ReduceLogSumExp(self) -> None:
1086
+ self._test_op_upgrade("ReduceLogSumExp", 1, [[3, 4, 5]], [[1, 1, 1]])
1087
+
1088
+ def test_ReduceMean(self) -> None:
1089
+ self._test_op_upgrade("ReduceMean", 1, [[3, 4, 5]], [[1, 1, 1]])
1090
+
1091
+ def test_ReduceMax(self) -> None:
1092
+ self._test_op_upgrade("ReduceMax", 1, [[3, 4, 5]], [[1, 1, 1]])
1093
+
1094
+ def test_ReduceMin(self) -> None:
1095
+ self._test_op_upgrade("ReduceMin", 1, [[3, 4, 5]], [[1, 1, 1]])
1096
+
1097
+ def test_ReduceProd(self) -> None:
1098
+ self._test_op_upgrade("ReduceProd", 1, [[3, 4, 5]], [[1, 1, 1]])
1099
+
1100
+ def test_ReduceSum(self) -> None:
1101
+ self._test_op_upgrade("ReduceSum", 1, [[3, 4, 5]], [[1, 1, 1]])
1102
+
1103
+ def test_ReduceSumSquare(self) -> None:
1104
+ self._test_op_upgrade("ReduceSumSquare", 1, [[3, 4, 5]], [[1, 1, 1]])
1105
+
1106
+ def test_Relu(self) -> None:
1107
+ self._test_op_upgrade("Relu", 1, attrs={"consumed_inputs": [0]})
1108
+
1109
+ def test_Reshape(self) -> None:
1110
+ self._test_op_upgrade(
1111
+ "Reshape",
1112
+ 1,
1113
+ [[3, 4, 5]],
1114
+ [[3, 10, 2]],
1115
+ attrs={"consumed_inputs": [0], "shape": [3, 10, 2]},
1116
+ )
1117
+
1118
+ def test_Resize(self) -> None:
1119
+ self._test_op_upgrade("Resize", 10, [[3, 4, 5], [3]], [[3, 8, 15]])
1120
+
1121
+ def test_ReverseSequence(self) -> None:
1122
+ self._test_op_upgrade(
1123
+ "ReverseSequence",
1124
+ 10,
1125
+ [[3, 4, 5], [4]],
1126
+ [[3, 4, 5]],
1127
+ [TensorProto.FLOAT, TensorProto.INT64],
1128
+ )
1129
+
1130
+ def test_RNN_1(self) -> None:
1131
+ # 6->7 adapter is missing
1132
+ self._test_op_upgrade(
1133
+ "RNN",
1134
+ 7,
1135
+ [[5, 3, 4], [1, 6, 4], [1, 6, 4]],
1136
+ [[5, 1, 3, 6], [1, 3, 6]],
1137
+ attrs={"hidden_size": 6},
1138
+ )
1139
+
1140
+ def test_RNN_2(self) -> None:
1141
+ # 6->7 adapter is missing
1142
+ self._test_op_upgrade(
1143
+ "RNN",
1144
+ 7,
1145
+ [[5, 3, 4], [2, 6, 4], [2, 6, 4]],
1146
+ [[5, 2, 3, 6], [2, 3, 6]],
1147
+ attrs={"hidden_size": 6, "direction": "bidirectional"},
1148
+ )
1149
+
1150
+ def test_RNN_3(self) -> None:
1151
+ # 6->7 adapter is missing
1152
+ self._test_op_upgrade(
1153
+ "RNN",
1154
+ 7,
1155
+ [[5, 3, 4], [1, 6, 4], [1, 6, 4], [1, 12], [5], [1, 5, 6]],
1156
+ [[5, 1, 3, 6], [1, 3, 6]],
1157
+ [
1158
+ TensorProto.FLOAT,
1159
+ TensorProto.FLOAT,
1160
+ TensorProto.FLOAT,
1161
+ TensorProto.FLOAT,
1162
+ TensorProto.INT64,
1163
+ TensorProto.FLOAT,
1164
+ ],
1165
+ attrs={"hidden_size": 6},
1166
+ )
1167
+
1168
+ def test_RoiAlign_1(self) -> None:
1169
+ self._test_op_upgrade(
1170
+ "RoiAlign",
1171
+ 10,
1172
+ [[2, 3, 20, 20], [10, 4], [10]],
1173
+ [[10, 3, 1, 1]],
1174
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
1175
+ )
1176
+
1177
+ def test_RoiAlign_2(self) -> None:
1178
+ self._test_op_upgrade(
1179
+ "RoiAlign",
1180
+ 16,
1181
+ [[2, 3, 20, 20], [10, 4], [10]],
1182
+ [[10, 3, 1, 1]],
1183
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
1184
+ attrs={"coordinate_transformation_mode": "half_pixel"},
1185
+ )
1186
+
1187
+ def test_RotaryEmbedding_1(self) -> None:
1188
+ self._test_op_upgrade(
1189
+ "RotaryEmbedding",
1190
+ 23,
1191
+ [[2, 4, 3, 8], [2, 3, 4], [2, 3, 4]],
1192
+ [[2, 4, 3, 8]],
1193
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
1194
+ [TensorProto.FLOAT],
1195
+ )
1196
+
1197
+ def test_RotaryEmbedding_2(self) -> None:
1198
+ self._test_op_upgrade(
1199
+ "RotaryEmbedding",
1200
+ 23,
1201
+ [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
1202
+ [[2, 4, 3, 8]],
1203
+ [
1204
+ TensorProto.FLOAT,
1205
+ TensorProto.FLOAT,
1206
+ TensorProto.FLOAT,
1207
+ TensorProto.INT64,
1208
+ ],
1209
+ [TensorProto.FLOAT],
1210
+ )
1211
+
1212
+ def test_RotaryEmbedding_3(self) -> None:
1213
+ self._test_op_upgrade(
1214
+ "RotaryEmbedding",
1215
+ 23,
1216
+ [[2, 3, 32], [50, 4], [50, 4], [2, 3]],
1217
+ [[2, 3, 32]],
1218
+ [
1219
+ TensorProto.FLOAT,
1220
+ TensorProto.FLOAT,
1221
+ TensorProto.FLOAT,
1222
+ TensorProto.INT64,
1223
+ ],
1224
+ [TensorProto.FLOAT],
1225
+ attrs={"num_heads": 4},
1226
+ )
1227
+
1228
+ def test_RotaryEmbedding_4(self) -> None:
1229
+ self._test_op_upgrade(
1230
+ "RotaryEmbedding",
1231
+ 23,
1232
+ [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
1233
+ [[2, 4, 3, 8]],
1234
+ [
1235
+ TensorProto.FLOAT,
1236
+ TensorProto.FLOAT,
1237
+ TensorProto.FLOAT,
1238
+ TensorProto.INT64,
1239
+ ],
1240
+ [TensorProto.FLOAT],
1241
+ attrs={"interleaved": 1},
1242
+ )
1243
+
1244
+ def test_RotaryEmbedding_5(self) -> None:
1245
+ self._test_op_upgrade(
1246
+ "RotaryEmbedding",
1247
+ 23,
1248
+ [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
1249
+ [[2, 4, 3, 8]],
1250
+ [
1251
+ TensorProto.FLOAT,
1252
+ TensorProto.FLOAT,
1253
+ TensorProto.FLOAT,
1254
+ TensorProto.INT64,
1255
+ ],
1256
+ [TensorProto.FLOAT],
1257
+ attrs={"rotary_embedding_dim": 4},
1258
+ )
1259
+
1260
+ def test_Round(self) -> None:
1261
+ self._test_op_upgrade("Round", 11)
1262
+
1263
+ def test_RMSNormalization(self) -> None:
1264
+ self._test_op_upgrade(
1265
+ "RMSNormalization",
1266
+ 23,
1267
+ [[2, 3, 4, 5], [4, 5]],
1268
+ [[2, 3, 4, 5]],
1269
+ input_types=[TensorProto.FLOAT, TensorProto.FLOAT],
1270
+ output_types=[TensorProto.FLOAT],
1271
+ attrs={"axis": 2},
1272
+ )
1273
+
1274
+ def test_Scatter(self) -> None:
1275
+ self._test_op_upgrade(
1276
+ "Scatter",
1277
+ 9,
1278
+ [[2, 3], [1, 2], [1, 2]],
1279
+ [[2, 3]],
1280
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
1281
+ [TensorProto.FLOAT],
1282
+ )
1283
+
1284
+ def test_ScatterElements_1(self) -> None:
1285
+ self._test_op_upgrade(
1286
+ "ScatterElements",
1287
+ 11,
1288
+ [[2, 3], [1, 2], [1, 2]],
1289
+ [[2, 3]],
1290
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
1291
+ [TensorProto.FLOAT],
1292
+ )
1293
+
1294
+ def test_ScatterElements_2(self) -> None:
1295
+ self._test_op_upgrade(
1296
+ "ScatterElements",
1297
+ 16,
1298
+ [[2, 3], [1, 2], [1, 2]],
1299
+ [[2, 3]],
1300
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
1301
+ [TensorProto.FLOAT],
1302
+ attrs={"reduction": "add"},
1303
+ )
1304
+
1305
+ def test_ScatterND_1(self) -> None:
1306
+ self._test_op_upgrade(
1307
+ "ScatterND",
1308
+ 11,
1309
+ [[2, 3], [1, 2], [1, 2]],
1310
+ [[2, 3]],
1311
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
1312
+ [TensorProto.FLOAT],
1313
+ )
1314
+
1315
+ def test_ScatterND_2(self) -> None:
1316
+ self._test_op_upgrade(
1317
+ "ScatterND",
1318
+ 16,
1319
+ [[2, 3], [1, 2], [1, 2]],
1320
+ [[2, 3]],
1321
+ [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
1322
+ [TensorProto.FLOAT],
1323
+ attrs={"reduction": "mul"},
1324
+ )
1325
+
1326
+ def test_Scan(self) -> None:
1327
+ sum_in = onnx.helper.make_tensor_value_info(
1328
+ "sum_in", onnx.TensorProto.FLOAT, [2]
1329
+ )
1330
+ next_in = onnx.helper.make_tensor_value_info(
1331
+ "next_in", onnx.TensorProto.FLOAT, [2]
1332
+ )
1333
+ sum_out = onnx.helper.make_tensor_value_info(
1334
+ "sum_out", onnx.TensorProto.FLOAT, [2]
1335
+ )
1336
+ scan_out = onnx.helper.make_tensor_value_info(
1337
+ "scan_out", onnx.TensorProto.FLOAT, [2]
1338
+ )
1339
+ add_node = onnx.helper.make_node(
1340
+ "Add", inputs=["sum_in", "next_in"], outputs=["sum_out"]
1341
+ )
1342
+ id_node = onnx.helper.make_node(
1343
+ "Identity", inputs=["sum_out"], outputs=["scan_out"]
1344
+ )
1345
+ body = onnx.helper.make_graph(
1346
+ [add_node, id_node], "scan_body", [sum_in, next_in], [sum_out, scan_out]
1347
+ )
1348
+ self._test_op_upgrade(
1349
+ "Scan",
1350
+ 8,
1351
+ ["", [1, 2], [1, 3, 2]],
1352
+ [[1, 2], [1, 3, 2]],
1353
+ attrs={"body": body, "num_scan_inputs": 1},
1354
+ )
1355
+
1356
+ def test_Selu(self) -> None:
1357
+ self._test_op_upgrade("Selu", 1, attrs={"consumed_inputs": [0]})
1358
+
1359
+ def test_Shape(self) -> None:
1360
+ self._test_op_upgrade(
1361
+ "Shape", 1, [[3, 4, 5]], [[3]], output_types=[TensorProto.INT64]
1362
+ )
1363
+
1364
+ def test_Shrink(self) -> None:
1365
+ self._test_op_upgrade("Shrink", 9)
1366
+
1367
+ def test_Sigmoid(self) -> None:
1368
+ self._test_op_upgrade("Sigmoid", 1, attrs={"consumed_inputs": [0]})
1369
+
1370
+ def test_Sign(self) -> None:
1371
+ self._test_op_upgrade("Sign", 9)
1372
+
1373
+ def test_Sinh(self) -> None:
1374
+ self._test_op_upgrade("Sinh", 9)
1375
+
1376
+ def test_Sin(self) -> None:
1377
+ self._test_op_upgrade("Sin", 7)
1378
+
1379
+ def test_Size(self) -> None:
1380
+ self._test_op_upgrade(
1381
+ "Size", 1, [[3, 4, 5]], [[]], output_types=[TensorProto.INT64]
1382
+ )
1383
+
1384
+ def test_Slice(self) -> None:
1385
+ self._test_op_upgrade(
1386
+ "Slice",
1387
+ 1,
1388
+ [[3, 4, 5]],
1389
+ [[3, 2, 2]],
1390
+ attrs={"axes": [1, 2], "starts": [0, 1], "ends": [2, 3]},
1391
+ )
1392
+
1393
+ def test_Softmax_0(self) -> None:
1394
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": 0})
1395
+
1396
+ def test_Softmax_1(self) -> None:
1397
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": 1})
1398
+
1399
+ def test_Softmax_2(self) -> None:
1400
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": 2})
1401
+
1402
+ def test_Softmax_3(self) -> None:
1403
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": -1})
1404
+
1405
+ def test_Softmax_4(self) -> None:
1406
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": -2})
1407
+
1408
+ def test_Softmax_5(self) -> None:
1409
+ self._test_op_upgrade("Softmax", 1, attrs={"axis": -3})
1410
+
1411
+ def test_Softplus(self) -> None:
1412
+ self._test_op_upgrade("Softplus", 1)
1413
+
1414
+ def test_Softsign(self) -> None:
1415
+ self._test_op_upgrade("Softsign", 1)
1416
+
1417
+ def test_SoftmaxCrossEntropyLoss(self) -> None:
1418
+ self._test_op_upgrade(
1419
+ "SoftmaxCrossEntropyLoss",
1420
+ 12,
1421
+ [[3, 4, 5, 6], [3, 6]],
1422
+ [[]],
1423
+ [TensorProto.FLOAT, TensorProto.INT64],
1424
+ )
1425
+
1426
+ def test_SpaceToDepth(self) -> None:
1427
+ self._test_op_upgrade(
1428
+ "SpaceToDepth", 1, [[1, 3, 8, 8]], [[1, 12, 4, 4]], attrs={"blocksize": 2}
1429
+ )
1430
+
1431
+ def test_Split(self) -> None:
1432
+ # 1->2 adapter is missing
1433
+ self._test_op_upgrade(
1434
+ "Split",
1435
+ 2,
1436
+ [[3, 4, 7]],
1437
+ [[3, 4, 2], [3, 4, 1], [3, 4, 4]],
1438
+ attrs={"axis": 2, "split": [2, 1, 4]},
1439
+ )
1440
+
1441
+ def test_Sqrt(self) -> None:
1442
+ self._test_op_upgrade("Sqrt", 1, attrs={"consumed_inputs": [0]})
1443
+
1444
+ def test_Squeeze(self) -> None:
1445
+ self._test_op_upgrade("Squeeze", 1, [[2, 1, 3, 4, 1]], [[2, 3, 4]])
1446
+
1447
+ def test_StringNormalizer(self) -> None:
1448
+ self._test_op_upgrade(
1449
+ "StringNormalizer",
1450
+ 10,
1451
+ [[1, 3]],
1452
+ [[1, 3]],
1453
+ [TensorProto.STRING],
1454
+ [TensorProto.STRING],
1455
+ attrs={"case_change_action": "LOWER"},
1456
+ )
1457
+
1458
+ def test_Sub(self) -> None:
1459
+ self._test_op_upgrade(
1460
+ "Sub",
1461
+ 1,
1462
+ [[2, 3, 4], [2, 3, 4]],
1463
+ [[2, 3, 4]],
1464
+ attrs={"consumed_inputs": [0]},
1465
+ )
1466
+
1467
+ def test_Sum(self) -> None:
1468
+ self._test_op_upgrade(
1469
+ "Sum",
1470
+ 1,
1471
+ [[2, 3, 4], [2, 3, 4]],
1472
+ [[2, 3, 4]],
1473
+ attrs={"consumed_inputs": [0]},
1474
+ )
1475
+
1476
+ def test_Swish(self) -> None:
1477
+ self._test_op_upgrade("Swish", 24, attrs={"alpha": 0.2})
1478
+
1479
+ def test_Tanh(self) -> None:
1480
+ self._test_op_upgrade("Tanh", 1, attrs={"consumed_inputs": [0]})
1481
+
1482
+ def test_Tan(self) -> None:
1483
+ self._test_op_upgrade("Tan", 7)
1484
+
1485
+ def test_TfIdfVectorizer(self) -> None:
1486
+ self._test_op_upgrade(
1487
+ "TfIdfVectorizer",
1488
+ 9,
1489
+ [[3]],
1490
+ [[5]],
1491
+ attrs={
1492
+ "max_gram_length": 3,
1493
+ "max_skip_count": 1,
1494
+ "min_gram_length": 2,
1495
+ "mode": "TFIDF",
1496
+ "ngram_counts": [0, 20],
1497
+ "ngram_indexes": [3, 4],
1498
+ },
1499
+ )
1500
+
1501
+ def test_ThresholdedRelu(self) -> None:
1502
+ self._test_op_upgrade("ThresholdedRelu", 10)
1503
+
1504
+ def test_Tile(self) -> None:
1505
+ # 5->6 adapter is missing
1506
+ repeats = helper.make_tensor(
1507
+ "b", TensorProto.INT64, dims=[3], vals=np.array([1, 2, 3])
1508
+ )
1509
+ self._test_op_upgrade(
1510
+ "Tile",
1511
+ 6,
1512
+ [[3, 4, 5], [3]],
1513
+ [[3, 8, 15]],
1514
+ [TensorProto.FLOAT, TensorProto.INT64],
1515
+ initializer=[repeats],
1516
+ )
1517
+
1518
+ def test_TopK(self) -> None:
1519
+ self._test_op_upgrade(
1520
+ "TopK",
1521
+ 1,
1522
+ [[3, 4, 5]],
1523
+ [[3, 4, 2], [3, 4, 2]],
1524
+ output_types=[TensorProto.FLOAT, TensorProto.INT64],
1525
+ attrs={"k": 2},
1526
+ )
1527
+
1528
+ def test_Transpose(self) -> None:
1529
+ self._test_op_upgrade(
1530
+ "Transpose",
1531
+ 1,
1532
+ [[1, 2, 5, 3, 7]],
1533
+ [[1, 7, 5, 2, 3]],
1534
+ attrs={"perm": [0, 4, 2, 1, 3]},
1535
+ )
1536
+
1537
+ def test_Trilu(self) -> None:
1538
+ self._test_op_upgrade("Trilu", 14)
1539
+
1540
+ def test_Unique_1(self) -> None:
1541
+ self._test_op_upgrade("Unique", 11, [[3, 4, 5]], [[None]])
1542
+
1543
+ def test_Unique_2(self) -> None:
1544
+ self._test_op_upgrade(
1545
+ "Unique", 11, [[3, 4, 5]], [[3, None, 5]], attrs={"axis": 1}
1546
+ )
1547
+
1548
+ def test_Unsqueeze(self) -> None:
1549
+ self._test_op_upgrade(
1550
+ "Unsqueeze", 1, [[3, 4, 5]], [[3, 4, 1, 5]], attrs={"axes": [2]}
1551
+ )
1552
+
1553
+ def test_Upsample(self) -> None:
1554
+ self._test_op_upgrade(
1555
+ "Upsample",
1556
+ 1,
1557
+ [[1, 3, 4, 5]],
1558
+ [[1, 3, 6, 10]],
1559
+ attrs={"width_scale": 2.0, "height_scale": 1.5},
1560
+ )
1561
+
1562
+ def test_Where(self) -> None:
1563
+ self._test_op_upgrade(
1564
+ "Where",
1565
+ 9,
1566
+ [[2, 3], [2, 3], [2, 3]],
1567
+ [[2, 3]],
1568
+ [TensorProto.BOOL, TensorProto.FLOAT, TensorProto.FLOAT],
1569
+ )
1570
+
1571
+ def test_Xor(self) -> None:
1572
+ # 6->7 adapter is missing
1573
+ self._test_op_upgrade(
1574
+ "Xor",
1575
+ 7,
1576
+ [[2, 3], [2, 3]],
1577
+ [[2, 3]],
1578
+ [TensorProto.BOOL, TensorProto.BOOL],
1579
+ [TensorProto.BOOL],
1580
+ )
1581
+
1582
+ def test_CastLike(self) -> None:
1583
+ self._test_op_upgrade(
1584
+ "CastLike",
1585
+ 15,
1586
+ [[2, 3, 4], [2, 1, 4]],
1587
+ [[2, 3, 4]],
1588
+ input_types=[TensorProto.FLOAT, TensorProto.FLOAT16],
1589
+ output_types=[TensorProto.FLOAT16],
1590
+ )
1591
+
1592
+ def test_LayerNormalization(self) -> None:
1593
+ self._test_op_upgrade(
1594
+ "LayerNormalization",
1595
+ 17,
1596
+ [[2, 3, 4, 5], [4, 5], [4, 5]],
1597
+ [[2, 3, 4, 5]],
1598
+ input_types=[TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
1599
+ output_types=[TensorProto.FLOAT],
1600
+ attrs={"axis": 2},
1601
+ )
1602
+
1603
+ def _test_window_function(self, window_function_name: str) -> None:
1604
+ size = helper.make_tensor("a", TensorProto.INT64, dims=[], vals=np.array([10]))
1605
+ self._test_op_upgrade(
1606
+ window_function_name,
1607
+ 17,
1608
+ [[]],
1609
+ [[10]],
1610
+ [TensorProto.INT64],
1611
+ initializer=[size],
1612
+ )
1613
+
1614
+ def test_BlackmanWindow(self) -> None:
1615
+ self._test_window_function("BlackmanWindow")
1616
+
1617
+ def test_HannWindow(self) -> None:
1618
+ self._test_window_function("HannWindow")
1619
+
1620
+ def test_HammingWindow(self) -> None:
1621
+ self._test_window_function("HammingWindow")
1622
+
1623
+ def test_DFT(self) -> None:
1624
+ self._test_op_upgrade("DFT", 17, [[2, 16, 1], []], [[2, 16, 2]])
1625
+ self._test_op_upgrade("DFT", 17, [[2, 16, 2], []], [[2, 16, 2]])
1626
+ self._test_op_upgrade(
1627
+ "DFT", 17, [[2, 16, 1], []], [[2, 9, 2]], attrs={"onesided": 1}
1628
+ )
1629
+ self._test_op_upgrade(
1630
+ "DFT", 17, [[2, 16, 2], []], [[2, 9, 2]], attrs={"onesided": 1}
1631
+ )
1632
+ self._test_op_upgrade(
1633
+ "DFT", 17, [[2, 16, 1], []], [[2, 16, 2]], attrs={"inverse": 1}
1634
+ )
1635
+ self._test_op_upgrade(
1636
+ "DFT", 17, [[2, 16, 2], []], [[2, 16, 2]], attrs={"inverse": 1}
1637
+ )
1638
+ self._test_op_upgrade(
1639
+ "DFT", 17, [[2, 16, 2], []], [[2, 16, 2]], attrs={"inverse": 1, "axis": 0}
1640
+ )
1641
+
1642
+ def _test_short_time_fourier_transform(self, operator_name: str) -> None:
1643
+ # Real
1644
+ signal = helper.make_tensor(
1645
+ "a",
1646
+ TensorProto.FLOAT,
1647
+ dims=[2, 64],
1648
+ vals=np.random.rand(2, 64).astype(np.float32),
1649
+ )
1650
+ frame_step = helper.make_tensor(
1651
+ "b", TensorProto.INT64, dims=[1], vals=np.array([8])
1652
+ )
1653
+ window = helper.make_tensor(
1654
+ "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
1655
+ )
1656
+ self._test_op_upgrade(
1657
+ operator_name,
1658
+ 17,
1659
+ [[2, 64], [1], [16]],
1660
+ [[2, 7, 16, 2]],
1661
+ [
1662
+ TensorProto.FLOAT,
1663
+ TensorProto.INT64,
1664
+ TensorProto.FLOAT,
1665
+ TensorProto.INT64,
1666
+ ],
1667
+ initializer=[signal, frame_step, window],
1668
+ )
1669
+
1670
+ # Real Onesided
1671
+ signal = helper.make_tensor(
1672
+ "a",
1673
+ TensorProto.FLOAT,
1674
+ dims=[2, 64],
1675
+ vals=np.random.rand(2, 64).astype(np.float32),
1676
+ )
1677
+ frame_step = helper.make_tensor(
1678
+ "b", TensorProto.INT64, dims=[1], vals=np.array([8])
1679
+ )
1680
+ window = helper.make_tensor(
1681
+ "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
1682
+ )
1683
+ self._test_op_upgrade(
1684
+ operator_name,
1685
+ 17,
1686
+ [[2, 64], [1], [16]],
1687
+ [[2, 7, 9, 2]],
1688
+ [
1689
+ TensorProto.FLOAT,
1690
+ TensorProto.INT64,
1691
+ TensorProto.FLOAT,
1692
+ TensorProto.INT64,
1693
+ ],
1694
+ attrs={"onesided": 1},
1695
+ initializer=[signal, frame_step, window],
1696
+ )
1697
+
1698
+ # Complex
1699
+ signal = helper.make_tensor(
1700
+ "a",
1701
+ TensorProto.FLOAT,
1702
+ dims=[2, 64, 2],
1703
+ vals=np.random.rand(2, 64, 2).astype(np.float32),
1704
+ )
1705
+ frame_step = helper.make_tensor(
1706
+ "b", TensorProto.INT64, dims=[1], vals=np.array([8])
1707
+ )
1708
+ window = helper.make_tensor(
1709
+ "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
1710
+ )
1711
+ self._test_op_upgrade(
1712
+ operator_name,
1713
+ 17,
1714
+ [[2, 64, 2], [1], [16]],
1715
+ [[2, 7, 16, 2]],
1716
+ [
1717
+ TensorProto.FLOAT,
1718
+ TensorProto.INT64,
1719
+ TensorProto.FLOAT,
1720
+ TensorProto.INT64,
1721
+ ],
1722
+ initializer=[signal, frame_step, window],
1723
+ )
1724
+
1725
+ # Complex Onesided
1726
+ signal = helper.make_tensor(
1727
+ "a",
1728
+ TensorProto.FLOAT,
1729
+ dims=[2, 64, 2],
1730
+ vals=np.random.rand(2, 64, 2).astype(np.float32),
1731
+ )
1732
+ frame_step = helper.make_tensor(
1733
+ "b", TensorProto.INT64, dims=[1], vals=np.array([8])
1734
+ )
1735
+ window = helper.make_tensor(
1736
+ "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
1737
+ )
1738
+ frame_length = helper.make_tensor(
1739
+ "e", TensorProto.INT64, dims=[1], vals=np.array([16])
1740
+ )
1741
+ self._test_op_upgrade(
1742
+ operator_name,
1743
+ 17,
1744
+ [[2, 64, 2], [1], [16]],
1745
+ [[2, 7, 9, 2]],
1746
+ [
1747
+ TensorProto.FLOAT,
1748
+ TensorProto.INT64,
1749
+ TensorProto.FLOAT,
1750
+ TensorProto.INT64,
1751
+ ],
1752
+ attrs={"onesided": 1},
1753
+ initializer=[signal, frame_step, window, frame_length],
1754
+ )
1755
+
1756
+ def test_STFT(self) -> None:
1757
+ self._test_short_time_fourier_transform("STFT")
1758
+
1759
+ def test_MelWeightMatrix(self) -> None:
1760
+ num_mel_bins = helper.make_tensor(
1761
+ "a", TensorProto.INT64, dims=[], vals=np.array([10])
1762
+ )
1763
+ dft_length = helper.make_tensor(
1764
+ "b", TensorProto.INT64, dims=[], vals=np.array([64])
1765
+ )
1766
+ sample_rate = helper.make_tensor(
1767
+ "c", TensorProto.INT64, dims=[], vals=np.array([0])
1768
+ )
1769
+ lower_edge_hertz = helper.make_tensor(
1770
+ "d", TensorProto.FLOAT, dims=[], vals=np.array([0])
1771
+ )
1772
+ upper_edge_hertz = helper.make_tensor(
1773
+ "e", TensorProto.FLOAT, dims=[], vals=np.array([1])
1774
+ )
1775
+
1776
+ self._test_op_upgrade(
1777
+ "MelWeightMatrix",
1778
+ 17,
1779
+ [[], [], [], [], []],
1780
+ [[33, 10]],
1781
+ [
1782
+ TensorProto.INT64,
1783
+ TensorProto.INT64,
1784
+ TensorProto.INT64,
1785
+ TensorProto.FLOAT,
1786
+ TensorProto.FLOAT,
1787
+ ],
1788
+ initializer=[
1789
+ num_mel_bins,
1790
+ dft_length,
1791
+ sample_rate,
1792
+ lower_edge_hertz,
1793
+ upper_edge_hertz,
1794
+ ],
1795
+ )
1796
+
1797
+ num_mel_bins = helper.make_tensor(
1798
+ "a", TensorProto.INT64, dims=[], vals=np.array([20])
1799
+ )
1800
+ dft_length = helper.make_tensor(
1801
+ "b", TensorProto.INT64, dims=[], vals=np.array([31])
1802
+ )
1803
+ sample_rate = helper.make_tensor(
1804
+ "c", TensorProto.INT64, dims=[], vals=np.array([0])
1805
+ )
1806
+ lower_edge_hertz = helper.make_tensor(
1807
+ "d", TensorProto.FLOAT, dims=[], vals=np.array([0])
1808
+ )
1809
+ upper_edge_hertz = helper.make_tensor(
1810
+ "e", TensorProto.FLOAT, dims=[], vals=np.array([1])
1811
+ )
1812
+
1813
+ self._test_op_upgrade(
1814
+ "MelWeightMatrix",
1815
+ 17,
1816
+ [[], [], [], [], []],
1817
+ [[16, 20]],
1818
+ [
1819
+ TensorProto.INT64,
1820
+ TensorProto.INT64,
1821
+ TensorProto.INT64,
1822
+ TensorProto.FLOAT,
1823
+ TensorProto.FLOAT,
1824
+ ],
1825
+ initializer=[
1826
+ num_mel_bins,
1827
+ dft_length,
1828
+ sample_rate,
1829
+ lower_edge_hertz,
1830
+ upper_edge_hertz,
1831
+ ],
1832
+ )
1833
+
1834
+ def test_CenterCropPad(self) -> None:
1835
+ input_ = helper.make_tensor(
1836
+ "input",
1837
+ TensorProto.FLOAT,
1838
+ dims=[2, 4],
1839
+ vals=np.array([1, 2, 3, 4, 5, 6, 7, 8]),
1840
+ )
1841
+ shape = helper.make_tensor(
1842
+ "shape", TensorProto.INT64, dims=[2], vals=np.array([3, 3])
1843
+ )
1844
+ self._test_op_upgrade(
1845
+ "CenterCropPad",
1846
+ 18,
1847
+ [[], []],
1848
+ [[3, 3]],
1849
+ [TensorProto.FLOAT, TensorProto.INT64],
1850
+ initializer=[input_, shape],
1851
+ )
1852
+
1853
+ def test_BitwiseNot(self) -> None:
1854
+ self._test_op_upgrade(
1855
+ "BitwiseNot",
1856
+ 18,
1857
+ [[2, 3]],
1858
+ [[2, 3]],
1859
+ [TensorProto.INT32],
1860
+ [TensorProto.INT32],
1861
+ )
1862
+
1863
+ def test_BitwiseAnd(self) -> None:
1864
+ self._test_op_upgrade(
1865
+ "BitwiseAnd",
1866
+ 18,
1867
+ [[2, 3], [2, 3]],
1868
+ [[2, 3]],
1869
+ [TensorProto.INT16, TensorProto.INT16],
1870
+ [TensorProto.INT16],
1871
+ )
1872
+
1873
+ def test_BitwiseOr(self) -> None:
1874
+ self._test_op_upgrade(
1875
+ "BitwiseOr",
1876
+ 18,
1877
+ [[2, 3], [2, 3]],
1878
+ [[2, 3]],
1879
+ [TensorProto.INT16, TensorProto.INT16],
1880
+ [TensorProto.INT16],
1881
+ )
1882
+
1883
+ def test_BitwiseXor(self) -> None:
1884
+ self._test_op_upgrade(
1885
+ "BitwiseXor",
1886
+ 18,
1887
+ [[2, 3], [2, 3]],
1888
+ [[2, 3]],
1889
+ [TensorProto.INT16, TensorProto.INT16],
1890
+ [TensorProto.INT16],
1891
+ )
1892
+
1893
+ def test_GroupNormalization(self) -> None:
1894
+ self._test_op_upgrade(
1895
+ "GroupNormalization",
1896
+ 21,
1897
+ [[3, 4, 2, 2], [4], [4]],
1898
+ [[3, 4, 2, 2]],
1899
+ attrs={"epsilon": 1e-5, "num_groups": 2},
1900
+ )
1901
+
1902
+ def test_StringConcat(self) -> None:
1903
+ self._test_op_upgrade(
1904
+ "StringConcat",
1905
+ 20,
1906
+ [[2, 3], [2, 3]],
1907
+ [[2, 3]],
1908
+ )
1909
+
1910
+ def test_RegexFullMatch(self) -> None:
1911
+ self._test_op_upgrade(
1912
+ "RegexFullMatch",
1913
+ 20,
1914
+ [[2, 3]],
1915
+ [[2, 3]],
1916
+ [TensorProto.STRING],
1917
+ [TensorProto.BOOL],
1918
+ )
1919
+
1920
+ def test_TensorScatter(self) -> None:
1921
+ self._test_op_upgrade(
1922
+ "TensorScatter",
1923
+ 24,
1924
+ [
1925
+ [2, 3, 4, 5],
1926
+ [2, 3, 2, 5],
1927
+ [
1928
+ 2,
1929
+ ],
1930
+ ],
1931
+ [[2, 3, 4, 5]],
1932
+ [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
1933
+ [TensorProto.FLOAT],
1934
+ )
1935
+
1936
+ def test_ops_tested(self) -> None:
1937
+ # NOTE: This test is order dependent and needs to run last in this class
1938
+ all_schemas = onnx.defs.get_all_schemas()
1939
+ all_op_names = {schema.name for schema in all_schemas if schema.domain == ""}
1940
+ excluded_ops = {
1941
+ # Sequence-based and Optional-based ops disabled because
1942
+ # the version converter doesn't play nicely with sequences
1943
+ "ConcatFromSequence",
1944
+ "SequenceAt",
1945
+ "SequenceConstruct",
1946
+ "SequenceEmpty",
1947
+ "SequenceErase",
1948
+ "SequenceInsert",
1949
+ "SequenceLength",
1950
+ "SequenceMap",
1951
+ "SplitToSequence",
1952
+ "Optional",
1953
+ "OptionalGetElement",
1954
+ "OptionalHasElement",
1955
+ "StringSplit",
1956
+ }
1957
+ expected_tested_ops = all_op_names - excluded_ops
1958
+
1959
+ untested_ops = expected_tested_ops - set(self.tested_ops)
1960
+ self.assertEqual(untested_ops, set())
1961
+
1962
+
1963
+ if __name__ == "__main__":
1964
+ unittest.main()
pythonProject/.venv/Lib/site-packages/onnxscript/converter.py ADDED
@@ -0,0 +1,1462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import logging
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Dict,
11
+ List,
12
+ NoReturn,
13
+ Optional,
14
+ Sequence,
15
+ Tuple,
16
+ Union,
17
+ )
18
+
19
+ import onnx
20
+
21
+ import onnxscript
22
+ from onnxscript import irbuilder, onnx_types, sourceinfo, values
23
+ from onnxscript import type_annotation as ta
24
+ from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation
25
+
26
+ logger = logging.getLogger("onnxscript")
27
+
28
+
29
+ # Python-to-IR converter:
30
+
31
+
32
+ def not_allowed(construct):
33
+ return f"{construct}not supported."
34
+
35
+
36
+ class TranslationError(Exception):
37
+ def __init__(self, *args: object) -> None:
38
+ super().__init__(*args)
39
+
40
+
41
+ def warn(msg):
42
+ logger.warning(msg)
43
+
44
+
45
+ def fail(msg) -> NoReturn:
46
+ raise TranslationError(msg)
47
+
48
+
49
+ def fail_if(cond, msg):
50
+ if cond:
51
+ raise TranslationError(msg)
52
+
53
+
54
+ def ignore(cond, msg):
55
+ if cond:
56
+ warn(msg)
57
+
58
+
59
+ # map from python operators to ONNX ops
60
+ primop_map = {
61
+ ast.Add: "Add",
62
+ ast.And: "And",
63
+ ast.BitAnd: "And",
64
+ ast.BitOr: "Or",
65
+ ast.Div: "Div",
66
+ ast.Eq: "Equal",
67
+ ast.Gt: "Greater",
68
+ ast.GtE: "GreaterOrEqual",
69
+ ast.Lt: "Less",
70
+ ast.LtE: "LessOrEqual",
71
+ ast.MatMult: "MatMul",
72
+ ast.Mod: "Mod",
73
+ ast.Mult: "Mul",
74
+ ast.Not: "Not",
75
+ ast.NotEq: "NotEqual",
76
+ ast.Or: "Or",
77
+ ast.Pow: "Pow",
78
+ ast.Sub: "Sub",
79
+ ast.USub: "Neg",
80
+ }
81
+
82
+
83
+ class Variable:
84
+ """Represents an ONNX variable.
85
+
86
+ TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this
87
+ converter.
88
+ """
89
+
90
+ def __init__(self, name: str, castable: bool = False):
91
+ """Initialize the instance.
92
+
93
+ Args:
94
+ name: Name of the ONNX variable
95
+ castable: Whether this variable is castable to a desired target type.
96
+ Used for ONNX variables representing constants created from python values
97
+ like 0 or 1 or 0.5 which are treated as polymorphic values castable to other
98
+ types as needed.
99
+ """
100
+ self.name = name
101
+ self.is_castable = castable
102
+
103
+ def __str__(self) -> str:
104
+ return self.name
105
+
106
+
107
+ if TYPE_CHECKING:
108
+ # The type-alias LocalSymValue represents the types of values that local names in a
109
+ # script-function may be bound to during translation, (ONNX IR values).
110
+ # TODO(rama): Rationalize this and values.SymbolValue
111
+
112
+ LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction]
113
+
114
+ # The type-alias PyValue is used to represent the types of python values that may be used
115
+ # in an ONNX Script function.
116
+ # TODO(rama): Flesh out the set of valid types here. These include values such as
117
+ # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for
118
+ # use as value-parameters or attribute-parameters in an ONNX call (Node).
119
+
120
+ PyValue = Any
121
+
122
+ # The type-alias SymValue denotes values that an identifier may be bound to during
123
+ # translation. A local name will be bound to a LocalSymValue, while a global name
124
+ # will be bound to a PyValue.
125
+
126
+ SymValue = Union[LocalSymValue, PyValue]
127
+
128
+ # PreferredName is a type-alias used to represent the preferred name used in the generated
129
+ # ONNX for a value returned by an expression. There is no guarantee that the specified
130
+ # name will be used exactly. The converter will modify the name (with a suffix),
131
+ # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement).
132
+
133
+ PreferredName = str
134
+
135
+ # The type-alias OnnxVar indicates variable names used in the generated ONNX.
136
+ OnnxVarName = str
137
+
138
+
139
+ class Converter:
140
+ """Main class to translate python code into ONNX operators.
141
+
142
+ Args:
143
+ ir_builder: convert AST node into ONNX structures, if None,
144
+ class :class:`onnxscript.irbuilder.IRBuilder` is used
145
+
146
+ The class uses logger `onnxscript`. Logging can be enabled with the following code:
147
+
148
+ ::
149
+
150
+ import logging
151
+ logging.basicConfig(level=logging.DEBUG)
152
+
153
+ Or if you need to enable only the logger used by this module:
154
+
155
+ ::
156
+
157
+ import logging
158
+ logger = logging.getLogger('onnxscript')
159
+ logger.setLevel(logging.DEBUG)
160
+ console = logging.StreamHandler()
161
+ logger.addHandler(console)
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ ir_builder: Optional[irbuilder.IRBuilder] = None,
167
+ opset: Optional[values.Opset] = None,
168
+ global_names: Optional[dict[str, Any]] = None,
169
+ source: Optional[str] = None,
170
+ default_opset: Optional[values.Opset] = None,
171
+ ):
172
+ self.ir_builder = ir_builder or irbuilder.IRBuilder()
173
+ self.source = source
174
+ if global_names is not None:
175
+ # We make a copy in case function eval modifies it.
176
+ self.globals = global_names.copy()
177
+ self.this_module = opset
178
+ self.default_opset_ = default_opset
179
+
180
+ # States initialized by `_init_function_translation`
181
+ self._outer: List[irbuilder.IRFunction] = []
182
+ self._current_fn: irbuilder.IRFunction = None
183
+ self._nextvar: int = 0
184
+ self._used_vars: set[str] = set()
185
+ self._locals: List[Dict[str, LocalSymValue]] = [{}]
186
+
187
+ @property
188
+ def default_opset(self) -> values.Opset:
189
+ if self.default_opset_ is None:
190
+ raise RuntimeError(
191
+ "default_opset must be specified in script for functions "
192
+ "that do not contain any use of an ONNX opset."
193
+ )
194
+ return self.default_opset_
195
+
196
+ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
197
+ if opset.domain != "":
198
+ return
199
+ if self.default_opset_ is not None:
200
+ if (
201
+ opset.domain != self.default_opset_.domain
202
+ or opset.version != self.default_opset_.version
203
+ ):
204
+ self.fail(
205
+ node, f"Two distincts opset were used ({opset} != {self.default_opset_})."
206
+ )
207
+ else:
208
+ self.default_opset_ = opset
209
+
210
+ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
211
+ """Find the (first) ONNX opset used in the function, if any."""
212
+ # Search for a Call expression of form "op.OpName(...)"
213
+ if isinstance(node, ast.Call):
214
+ if isinstance(node.func, ast.Attribute):
215
+ opset_expr = node.func.value
216
+ if isinstance(opset_expr, ast.Name):
217
+ if opset_expr.id in self.globals:
218
+ opset = self.globals[opset_expr.id]
219
+ if isinstance(opset, values.Opset) and opset.domain == "":
220
+ return opset
221
+ for child in ast.iter_child_nodes(node):
222
+ res = self._find_onnx_opset(child)
223
+ if res is not None:
224
+ return res
225
+ return None
226
+
227
+ def _init_function_translation(self) -> None:
228
+ """Initialize self for translating a new (top-level) function."""
229
+ self._outer = []
230
+ self._current_fn: Optional[irbuilder.IRFunction] = None
231
+ self._nextvar = 0
232
+ self._used_vars = set()
233
+ self._locals: List[Dict[str, LocalSymValue]] = [{}]
234
+
235
+ def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
236
+ return sourceinfo.SourceInfo(node, self.source, self._current_fn.name)
237
+
238
+ def _message(self, node: ast.AST, error_msg: str) -> str:
239
+ """Constructs an error _message containing source information about an ast node."""
240
+ return self._source_of(node).msg(error_msg)
241
+
242
+ def warn(self, node: ast.AST, error_msg: str) -> None:
243
+ warn(self._message(node, error_msg))
244
+
245
+ def fail(self, node: ast.AST, error_msg: str) -> NoReturn:
246
+ fail(self._message(node, error_msg))
247
+
248
+ # Name resolution and namescopes: This component handles the following aspects:
249
+ # * Name-scopes are different in Python and the generated ONNX:
250
+ # - Control-flow blocks (a loop body or the then-or-else block of an if-stmt)
251
+ # form part of the same name-scope in Python, but will be mapped to a nested
252
+ # name-scope (as a sub-graph) in ONNX.
253
+ # * Script-time name-value tracking: Name _lookup during script-time returns
254
+ # statically-known information about the value the name will have at runtime.
255
+ def _enter_scope(self, name: str, parent_node: ast.AST):
256
+ """Enter a control-flow block (a loop body or if-then-else branch).
257
+ The block is translated into a nested-scope in ONNX.
258
+ """
259
+ self._outer.insert(0, self._current_fn)
260
+ self._current_fn = self.ir_builder.new_function(name)
261
+ self._locals.insert(0, {})
262
+ logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node))
263
+
264
+ def _exit_scope(self) -> irbuilder.IRFunction:
265
+ """Exit from a control-flow block (a loop body or if-then-else branch)."""
266
+ logger.debug("Converter:_exit_scope:%d", len(self._locals))
267
+ graph = self._current_fn
268
+ self._current_fn = self._outer.pop(0)
269
+ self._locals.pop(0)
270
+ return graph
271
+
272
+ def _current_scope(self) -> Dict[str, LocalSymValue]:
273
+ return self._locals[0]
274
+
275
+ def _bind(self, name: str, val: LocalSymValue) -> None:
276
+ logger.debug("Converter:_bind:%s", name)
277
+ self._locals[0][name] = val
278
+
279
+ def _lookup(
280
+ self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True
281
+ ) -> SymValue:
282
+ for scope in self._locals:
283
+ if name in scope:
284
+ return scope[name]
285
+ if name in self.globals:
286
+ return self.globals[name]
287
+ if raise_exception:
288
+ raise ValueError(info.msg(f"Unbound name: {name}."))
289
+ return None
290
+
291
+ def generate_unique_name(self, candidate: str = "tmp") -> str:
292
+ # TODO(justinchuby): Can we reduce the O complexity of this function?
293
+ r = candidate
294
+ while r in self._used_vars:
295
+ r = f"{candidate}_{self._nextvar}"
296
+ self._nextvar = self._nextvar + 1
297
+ self._used_vars.add(r)
298
+ return r
299
+
300
+ def _make_onnx_attr(
301
+ self, attrname: str, attrval: Any, attrtype: int | None = None
302
+ ) -> irbuilder.IRAttributeValue:
303
+ def tensor_name_generator() -> str:
304
+ """Return name to be used for tensor, if we need to create one."""
305
+ return self.generate_unique_name(f"attr_{attrname}")
306
+
307
+ proto = autocast.pyvalue_to_onnx_attribute(
308
+ attrname, attrval, tensor_name_generator, attrtype
309
+ )
310
+ return self.ir_builder.make_attr(proto)
311
+
312
+ def _to_onnx_attr_ref(
313
+ self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
314
+ ) -> irbuilder.IRAttributeValue:
315
+ pytype = val.typeinfo
316
+ attrtype = ta.pytype_to_attrtype(pytype)
317
+ attrname = None
318
+ if attrtype is onnx.AttributeProto.FLOAT:
319
+ attrname = "value_float"
320
+ elif attrtype is onnx.AttributeProto.INT:
321
+ attrname = "value_int"
322
+ elif attrtype is onnx.AttributeProto.STRING:
323
+ attrname = "value_string"
324
+ elif attrtype is onnx.AttributeProto.INTS:
325
+ attrname = "value_ints"
326
+ else:
327
+ msg = f"Unsupported attribute type {pytype!r}."
328
+ fail(info.msg(msg) if info else msg)
329
+ return self.ir_builder.make_attr_ref(attrname, val.value, pytype)
330
+
331
+ def _to_onnx_var(
332
+ self,
333
+ val: values.SymbolValue | PyValue,
334
+ target: Optional[PreferredName] = None,
335
+ info: Optional[sourceinfo.SourceInfo] = None,
336
+ ) -> Variable:
337
+ if isinstance(val, values.AttrRef):
338
+ # promote attribute to value
339
+ result = self.generate_unique_name(target or "tmp")
340
+ attr = self._to_onnx_attr_ref(val, info)
341
+ self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr])
342
+ if ta.base_type_is_bool(val.typeinfo):
343
+ # ONNX attributes use an int-encoding for bools, but ONNX tensor types
344
+ # distinguish between int and bool. So we cast the int tensor to a bool tensor,
345
+ # to promote a (python) bool attribute to a ONNX bool tensor.
346
+ result_as_bool = self.generate_unique_name(result + "_as_bool")
347
+ cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype)
348
+ self.emit(
349
+ [result_as_bool],
350
+ values.Op(self.default_opset, "Cast"),
351
+ [result],
352
+ [cast_attr],
353
+ )
354
+ return Variable(result_as_bool, True)
355
+ return Variable(result, True)
356
+ if isinstance(val, values.Dynamic):
357
+ return Variable(val.value)
358
+ # Assume value is a python-value convertible to a tensor
359
+ # TODO: check if value is convertible to a TensorProto, so that we can
360
+ # produce a better error _message otherwise
361
+ return self._emit_const(val, target or "tmp", info)
362
+
363
+ def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable:
364
+ return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info)
365
+
366
+ def emit(
367
+ self,
368
+ outputs: Sequence[str],
369
+ callee: values.Op | str,
370
+ inputs: Sequence[Optional[str]],
371
+ attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None,
372
+ sub_functions: Optional[dict[str, onnx.FunctionProto]] = None,
373
+ ):
374
+ if not isinstance(callee, values.Op):
375
+ callee = values.Op(self.default_opset, callee)
376
+ if attrs is None:
377
+ attrs = []
378
+ if sub_functions is None:
379
+ sub_functions = {}
380
+ self.ir_builder.add_stmt(
381
+ self._current_fn,
382
+ outputs,
383
+ callee,
384
+ inputs,
385
+ attrs,
386
+ sub_functions,
387
+ )
388
+
389
+ def _emit_const(
390
+ self,
391
+ pyvalue: PyValue,
392
+ suggested_name: Optional[PreferredName],
393
+ info: sourceinfo.SourceInfo,
394
+ ) -> Variable:
395
+ if suggested_name is None:
396
+ if isinstance(pyvalue, int):
397
+ if pyvalue >= 0:
398
+ suggested_name = f"int64_{pyvalue}"
399
+ else:
400
+ suggested_name = f"int64_m{abs(pyvalue)}"
401
+ elif (
402
+ isinstance(pyvalue, list) and len(pyvalue) == 1 and isinstance(pyvalue[0], int)
403
+ ):
404
+ if pyvalue[0] >= 0:
405
+ suggested_name = f"int64_{pyvalue[0]}_1d"
406
+ else:
407
+ suggested_name = f"int64_m{abs(pyvalue[0])}_1d"
408
+ else:
409
+ suggested_name = "const"
410
+ ovar = self.generate_unique_name(suggested_name)
411
+ try:
412
+ tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue)
413
+ except ValueError as e:
414
+ fail(info.msg(str(e)))
415
+ attr = self._make_onnx_attr("value", tensor)
416
+ self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr])
417
+ return Variable(ovar, True)
418
+
419
+ def _emit_copy(self, original_var: str, suggested_name: str) -> str:
420
+ """Emits a copy statement, using the ONNX Identity operator."""
421
+ new_var = self.generate_unique_name(suggested_name)
422
+ self.emit([new_var], "Identity", [original_var])
423
+ return new_var
424
+
425
+ def _is_constant_expr(self, node: ast.AST) -> None:
426
+ if isinstance(node, ast.UnaryOp):
427
+ return self._is_constant_expr(node.operand)
428
+ if isinstance(
429
+ node,
430
+ (
431
+ ast.Call,
432
+ ast.BinOp,
433
+ ast.UnaryOp,
434
+ ast.Compare,
435
+ ast.Attribute,
436
+ ast.List,
437
+ ast.Load,
438
+ ast.Constant,
439
+ ),
440
+ ):
441
+ return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node))
442
+ return False
443
+
444
+ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
445
+ """Evaluates a sub-expression that is assumed to represent a constant value.
446
+ The expression can refer only to global names (inherited from the scope
447
+ where the script is evaluated) and cannot refer to local names defined
448
+ within the script.) Further, these expressions are assumed to be constants.
449
+ Thus, any subsequent mutation of any state/variables (used in computing
450
+ this constant value) will potentially lead to unexpected behavior (such
451
+ as divergence between eager-mode execution and evaluation of the ONNX
452
+ function.)
453
+ """
454
+ # TODO: assert (self._is_constant_expr(expr))
455
+ # TODO: Refine types
456
+ locals: dict[Any, Any] = {}
457
+ expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)
458
+ cpl = compile(expr, filename="<ast>", mode="eval")
459
+ try:
460
+ return eval(cpl, self.globals, locals) # pylint: disable=eval-used
461
+ except NameError as e:
462
+ raise NameError(
463
+ self._message(
464
+ expr,
465
+ f"Missing names, globals contains {list(self.globals)!r}, "
466
+ f"locals {list(locals)!r}.",
467
+ )
468
+ ) from e
469
+
470
+ def _translate_attr(
471
+ self,
472
+ attr_name: str,
473
+ expr: ast.AST,
474
+ attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None,
475
+ ) -> Optional[irbuilder.IRAttributeValue]:
476
+ """Translate an attribute-value specification of the form `attr_name=<expr>`
477
+ in a call to an op. expr is an AST. The following cases are supported:
478
+ * Expr evaluates to a script-time constant (a python-value) that can be mapped
479
+ into an ONNX attribute value, or
480
+ * Expr evaluates to None, in which case None is returned, or
481
+ * Expr must be an attribute-reference, that is a name representing an
482
+ attribute-parameter of a containing function.
483
+ """
484
+
485
+ if isinstance(expr, ast.Name):
486
+ val = self._lookup(expr.id, self._source_of(expr))
487
+ if isinstance(val, values.AttrRef):
488
+ attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo)
489
+ if attr_meta is not None and (attr_ref.type != attr_meta.type):
490
+ self.fail(
491
+ expr,
492
+ f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
493
+ )
494
+ return attr_ref
495
+ if isinstance(val, irbuilder.IRFunction):
496
+ # Check that outer-scope variables referenced by function have same value
497
+ # at function-definition site and use-as-attribute site, to avoid errors.
498
+ for pyvar, previous in val.outer_scope_variables:
499
+ current = self._lookup(pyvar, self._source_of(expr))
500
+ if current.value != previous.value:
501
+ self.fail(
502
+ expr,
503
+ f"Outer scope variable '{pyvar}' referenced by function "
504
+ f"'{expr.id!r}' modified.",
505
+ )
506
+
507
+ # Create GraphProto attribute
508
+ val = val.to_graph_proto()
509
+ else:
510
+ val = self._eval_constant_expr(expr)
511
+
512
+ # In ONNX, there is no way to explicitly specify a None value for an attribute.
513
+ # Instead, the attribute must be omitted from the attribute list.
514
+ # Hence, we do not create an attribute-proto if the value is None.
515
+ # The caller is responsible for omitting such attribute-values from the list of attributes
516
+ # in a NodeProto.
517
+ if val is None:
518
+ if attr_meta and attr_meta.required:
519
+ self.fail(expr, f"Attribute '{attr_name}' is required.")
520
+ return None
521
+ attr_type = int(attr_meta.type) if attr_meta else None
522
+ attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type)
523
+ if attr_meta and (attr.type != attr_meta.type):
524
+ self.fail(
525
+ expr,
526
+ f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'",
527
+ )
528
+ return attr
529
+
530
+ def _translate_docstring(self, node: ast.Expr) -> None:
531
+ if hasattr(node.value, "value"):
532
+ # python 3.8+
533
+ return self.ir_builder.add_docstring(self._current_fn, node.value.value)
534
+ raise TypeError(
535
+ f"Unexpected type {type(node)!r} for node. Unsupoorted version of python."
536
+ )
537
+
538
+ def _translate_expr(
539
+ self, node: ast.AST, target: Optional[PreferredName] = None
540
+ ) -> Variable:
541
+ """Expression-translation generates "IR statements/nodes" that compute the value of
542
+ the expression into a target-variable, and returns the variable that is
543
+ assigned this value.
544
+ """
545
+ if isinstance(node, ast.Call):
546
+ r = self._translate_call_expr(node)
547
+ elif isinstance(node, (ast.BinOp, ast.BitAnd, ast.BitOr)):
548
+ r = self._translate_binary_op_expr(node)
549
+ elif isinstance(node, ast.UnaryOp):
550
+ r = self._translate_unary_op_expr(node)
551
+ elif isinstance(node, ast.Compare):
552
+ r = self._translate_compare_expr(node)
553
+ elif isinstance(node, ast.Name):
554
+ r = self._translate_name_expr(node)
555
+ elif isinstance(node, ast.Subscript):
556
+ r = self._translate_subscript_expr(node, target)
557
+ elif self._is_constant_expr(node):
558
+ r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node))
559
+ else:
560
+ raise ValueError(
561
+ self._message(node, f"Unsupported expression type {type(node)!r}.")
562
+ )
563
+ if isinstance(r, Variable):
564
+ return r
565
+ callee, args, attrs = r
566
+ target = "tmp" if target is None else target
567
+ assert isinstance(target, str)
568
+ result = self.generate_unique_name(target)
569
+ self.emit([result], callee, args, attrs)
570
+ return Variable(result)
571
+
572
+ def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
573
+ """Translation of an expression where "None" is permitted (eg., for an optional argument).
574
+ None is represented as a Constant in Python 3.9+.
575
+ """
576
+ if isinstance(node, ast.Constant) and (node.value is None):
577
+ return None
578
+ return self._translate_expr(node)
579
+
580
+ def _translate_subscript_expr(
581
+ self, node: ast.Subscript, target: Optional[PreferredName]
582
+ ) -> Variable:
583
+ """List of supported syntaxes is below.
584
+ `A` is a tensor or an expression equivalent to a tensor.
585
+
586
+ ::
587
+
588
+ A[:, 1]
589
+ A[:2, 0]
590
+ A[:2, :1]
591
+ A[2:0:-1]
592
+ A[1:]
593
+ A[:2]
594
+ A[1:-1]
595
+ A[1:2]
596
+ A[-1]
597
+ A[0]
598
+ A[:0:-1]
599
+
600
+ *i* is a tensor holding one integer.
601
+
602
+ ::
603
+
604
+ A[i]
605
+ A[i+1:i+2]
606
+
607
+ Fully supported for python 3.9+.
608
+
609
+ ::
610
+
611
+ A[i:i+j, k]
612
+
613
+ Not supported:
614
+
615
+ ::
616
+
617
+ A[::-1]
618
+ """
619
+ var = self._translate_expr(node.value)
620
+ var_name = var.name
621
+ if target is None:
622
+ target = f"{var_name}_subscripted"
623
+ target = self.generate_unique_name(target)
624
+ indices = ast_utils.normalize_subscript_expr(node)
625
+ info = self._source_of(node.slice)
626
+
627
+ # Create cached int constants:
628
+ # TODO: Do this at a graph-scope level.
629
+ cached_int_consts = {}
630
+
631
+ def const_1d(value, name: Optional[str] = None):
632
+ nonlocal cached_int_consts
633
+ if value not in cached_int_consts:
634
+ cached_int_consts[value] = self._emit_const([value], name, info)
635
+ return cached_int_consts[value]
636
+
637
+ def one_1d():
638
+ return const_1d(1)
639
+
640
+ # Max/min 64-bit int values are used to represent default values for start/stop in Slice.
641
+ maxint = (1 << 63) - 1
642
+ minint = -(1 << 63)
643
+
644
+ def translate_slice_component(
645
+ node_arg, default_value: Optional[int] = None
646
+ ) -> tuple[str, Optional[int]]:
647
+ """Translate optional start/stop/step component of a Slice expression."""
648
+ if node_arg is None:
649
+ if default_value is None:
650
+ # TODO: Emit "Where(step > 0, pos_default, neg_default)"
651
+ raise RuntimeError(
652
+ "Default start/stop not supported when step direction is unknown."
653
+ )
654
+ return const_1d(default_value), default_value
655
+
656
+ if self._is_constant_expr(node_arg):
657
+ cst = self._eval_constant_expr(node_arg)
658
+ if isinstance(cst, int):
659
+ return const_1d(cst), cst
660
+ else:
661
+ raise RuntimeError(f"Slice component type must be int, not {type(cst)}")
662
+ else:
663
+ name = self._translate_expr(node_arg).name
664
+ reshaped = self.generate_unique_name(f"{name}_reshaped")
665
+ self.emit(
666
+ [reshaped],
667
+ values.Op(self.default_opset, "Reshape"),
668
+ [name, one_1d().name],
669
+ [],
670
+ )
671
+ return reshaped, None
672
+
673
+ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
674
+ """Translate slice-expression of the form from:to:step."""
675
+ step_name, step = translate_slice_component(slice_expr.step, 1)
676
+ if step is None:
677
+ # Step direction unknown.
678
+ # TODO: Handle default-values using runtime check on sign of step.
679
+ lower_name, _ = translate_slice_component(slice_expr.lower, None)
680
+ upper_name, _ = translate_slice_component(slice_expr.upper, None)
681
+ elif step > 0:
682
+ lower_name, _ = translate_slice_component(slice_expr.lower, 0)
683
+ upper_name, _ = translate_slice_component(slice_expr.upper, maxint)
684
+ else:
685
+ lower_name, _ = translate_slice_component(slice_expr.lower, maxint)
686
+ upper_name, _ = translate_slice_component(slice_expr.upper, minint)
687
+ return (lower_name, upper_name, step_name)
688
+
689
+ # An input like X[2] is translated into a Gather op.
690
+ # An input like X[1:5:2] is translated into a Slice op.
691
+ # An input like X[2, 3] is translated into a Slice + Squeeze (instead of two Gathers),
692
+ # as an optimization.
693
+ # An input like X[I, J] is translated into two Gathers (which is correct whatever the
694
+ # rank of I and J)
695
+ # To replace multiple Gathers by the Slice we need to know that the index-values
696
+ # are scalars.
697
+
698
+ # As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2),
699
+ # known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":")
700
+ sliced_indices: List[Tuple[int, ast.expr]] = []
701
+ scalar_indices: List[Tuple[int, ast.expr]] = []
702
+ non_scalar_indices: List[Tuple[int, ast.expr]] = []
703
+ for axis, elt in enumerate(indices):
704
+ if isinstance(elt, ast.Slice):
705
+ # Add to sliced_indices, unless it is "::", which is a no-op.
706
+ if not (elt.lower is None and elt.upper is None and elt.step is None):
707
+ sliced_indices.append((axis, elt))
708
+ elif self._is_constant_expr(elt) and isinstance(
709
+ self._eval_constant_expr(elt), int
710
+ ):
711
+ scalar_indices.append((axis, elt))
712
+ else:
713
+ non_scalar_indices.append((axis, elt))
714
+ if not (sliced_indices or scalar_indices or non_scalar_indices):
715
+ # Edge case: no index specified. Eg. A[:, :]
716
+ self.emit([target], "Identity", [var_name])
717
+ return Variable(target)
718
+ if sliced_indices or len(scalar_indices) > 1:
719
+ # We emit a Slice operation if we have any indices like 1:5:2 or if the number of
720
+ # scalar indices (like 2) is more than 1.
721
+ starts = []
722
+ ends = []
723
+ axes = []
724
+ steps = []
725
+ squeezed_axes = []
726
+ for axis, expr in scalar_indices:
727
+ # Treat a scalar index i as slice "i:i+1:1", but squeeze the axis finally.
728
+ # TODO: handle negative i
729
+ index = self._eval_constant_expr(expr)
730
+ squeezed_axes.append(axis)
731
+ kwargs = dict(
732
+ lineno=getattr(expr, "lineno", node.lineno),
733
+ col_offset=getattr(expr, "col_offset", node.col_offset),
734
+ )
735
+ element = ast.Slice(
736
+ ast.Constant(index, **kwargs),
737
+ ast.Constant(index + 1, **kwargs),
738
+ ast.Constant(1, **kwargs),
739
+ )
740
+ sliced_indices.append((axis, element))
741
+ scalar_indices = []
742
+ for axis, element in sliced_indices:
743
+ axis_var = const_1d(axis)
744
+ inputs = translate_slice(element)
745
+ starts.append(inputs[0])
746
+ ends.append(inputs[1])
747
+ axes.append(axis_var.name)
748
+ steps.append(inputs[2])
749
+
750
+ if len(starts) > 1:
751
+ axis_0_attr = self._make_onnx_attr("axis", 0)
752
+ start_name = self.generate_unique_name(f"{var_name}_start")
753
+ self.emit([start_name], "Concat", starts, [axis_0_attr])
754
+
755
+ end_name = self.generate_unique_name(f"{var_name}_end")
756
+ self.emit([end_name], "Concat", ends, [axis_0_attr])
757
+
758
+ axes_name = self.generate_unique_name(f"{var_name}_axis")
759
+ self.emit([axes_name], "Concat", axes, [axis_0_attr])
760
+
761
+ steps_name = self.generate_unique_name(f"{var_name}_step")
762
+ self.emit([steps_name], "Concat", steps, [axis_0_attr])
763
+ else:
764
+ start_name = starts[0]
765
+ end_name = ends[0]
766
+ axes_name = axes[0]
767
+ steps_name = steps[0]
768
+
769
+ if squeezed_axes:
770
+ sliced_name = self.generate_unique_name(f"{var_name}_sliced")
771
+ self.emit(
772
+ [sliced_name],
773
+ "Slice",
774
+ [var_name, start_name, end_name, axes_name, steps_name],
775
+ )
776
+ squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info)
777
+
778
+ if non_scalar_indices: # use temporary to store result of squeeze
779
+ result = self.generate_unique_name(f"{var_name}_squeezed")
780
+ else: # store squeezed result in final target
781
+ result = target
782
+
783
+ self.emit([result], "Squeeze", [sliced_name, squeezed_axes])
784
+ else:
785
+ if non_scalar_indices: # use temporary to store result of Slice
786
+ result = self.generate_unique_name(f"{var_name}_sliced")
787
+ else: # store result of Slice in final target
788
+ result = target
789
+ slice_inputs = [var_name, start_name, end_name, axes_name, steps_name]
790
+ self.emit([result], "Slice", slice_inputs)
791
+ else:
792
+ result = var_name
793
+ non_scalar_indices.extend(scalar_indices)
794
+ if non_scalar_indices:
795
+ last_axis, _ = non_scalar_indices[-1]
796
+ else:
797
+ # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is False
798
+ last_axis = None
799
+ for axis, index_expr in non_scalar_indices:
800
+ index_value = self._translate_expr(index_expr)
801
+ axis_attr = self._make_onnx_attr("axis", axis)
802
+ # use Gather to perform indexing
803
+ # Assign gathered value to either temporary or final target
804
+ if axis != last_axis: # use temporary to store result of Gather
805
+ gathered = self.generate_unique_name(f"{var_name}_axis_{axis}")
806
+ else: # store result of Gather in final target
807
+ gathered = target
808
+ self.emit([gathered], "Gather", [str(result), index_value], [axis_attr])
809
+ result = gathered
810
+
811
+ return Variable(result)
812
+
813
+ def _translate_call_expr(self, node: ast.Call):
814
+ """Translates a call-expression."""
815
+ callee = self._translate_callee_expr(node.func)
816
+ param_schemas = callee.param_schemas()
817
+ # If the callee's schema is available, we use it to determine the inputs and attributes.
818
+ # Otherwise, we map named arguments to attributes and positional arguments to inputs.
819
+ if param_schemas:
820
+ kwargs = {x.arg: x.value for x in node.keywords}
821
+ args, attrs = param_manipulation.separate_input_attributes_from_arguments(
822
+ param_schemas, node.args, kwargs, fill_defaults=False
823
+ )
824
+ args = [self._translate_opt_expr(x) for x in args]
825
+ attrs = [
826
+ self._translate_attr(x, y, callee.op_schema.attributes[x])
827
+ for x, y in attrs.items()
828
+ ]
829
+ else:
830
+ args = [self._translate_opt_expr(x) for x in node.args]
831
+ attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords]
832
+ args = autocast.static_cast_inputs(self, callee.op_schema, args)
833
+
834
+ # In ONNX, there is no way to explicitly specify a None value for an attribute.
835
+ # Instead, the attribute must be omitted from the attribute list.
836
+ # Hence, we do not create an attribute-proto if the value is None.
837
+ attrs = [attr for attr in attrs if attr is not None]
838
+ return callee, args, attrs
839
+
840
+ def _cast_like_binary_expression(self, op, left, right):
841
+ schema = op.op_schema
842
+ return autocast.static_cast_inputs(self, schema, (left, right))
843
+
844
+ def _translate_binary_op_expr(self, node: ast.BinOp):
845
+ op = type(node.op)
846
+ if op not in primop_map:
847
+ raise ValueError(self._message(node, f"Unsupported operator {op!r}."))
848
+
849
+ attr = []
850
+ if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right):
851
+ # specific case X % f where f is a float.
852
+ # attribute fmod=1 is added in that case.
853
+ cst = self._eval_constant_expr(node.right)
854
+ if isinstance(cst, float):
855
+ attr = [self._make_onnx_attr("fmod", 1)]
856
+
857
+ op = values.Op(self.default_opset, primop_map[op])
858
+ left, right = self._cast_like_binary_expression(
859
+ op, self._translate_expr(node.left), self._translate_expr(node.right)
860
+ )
861
+ return op, [left, right], attr
862
+
863
+ def _translate_unary_op_expr(self, node):
864
+ op = type(node.op)
865
+ if op not in primop_map:
866
+ raise ValueError(self._message(node, self).msg(f"Unsupported operator {op!r}."))
867
+ if self._is_constant_expr(node.operand):
868
+ # This function changed the constant node.operand
869
+ # and returns it. The function calling this one
870
+ # should intercept this call and replace node
871
+ # by node.operand.
872
+ # This mechanism does not handle somthing like `(-(-5))`.
873
+ if hasattr(node.operand, "value"):
874
+ # python 3.8+
875
+ val = node.operand.value
876
+ else:
877
+ raise TypeError(
878
+ f"Unable to guess constant value from type {type(node.operand)!r} "
879
+ f"and attributes {dir(node.operand)!r}."
880
+ )
881
+ if op == ast.USub:
882
+ cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset)
883
+ return self._translate_expr(cst)
884
+ if op == ast.UAdd:
885
+ return self._translate_expr(node.operand)
886
+ opname = primop_map[op]
887
+ operand = self._translate_expr(node.operand)
888
+ return values.Op(self.default_opset, opname), [operand], []
889
+
890
+ def _translate_compare_expr(self, node):
891
+ # TODO: handle multiple comparisons in one expression
892
+ assert len(node.ops) == 1
893
+ assert len(node.comparators) == 1
894
+ op = type(node.ops[0])
895
+ if op not in primop_map:
896
+ raise ValueError(self._message(node, f"Unsupported operator {op!r}."))
897
+ opname = primop_map[op]
898
+ left = self._translate_expr(node.left)
899
+ right = self._translate_expr(node.comparators[0])
900
+
901
+ # NotEqual is not a standard ONNX op, and needs to be translated into
902
+ # an Equal op/node followed by a Not op/node.
903
+ op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal")
904
+ left, right = self._cast_like_binary_expression(op, left, right)
905
+ if opname == "NotEqual":
906
+ tmp = self.generate_unique_name()
907
+ self.emit([tmp], op, [left, right])
908
+ not_op = values.Op(self.default_opset, "Not")
909
+ return not_op, [tmp], []
910
+
911
+ return op, [left, right], []
912
+
913
+ def _translate_name_expr(self, node: ast.Name) -> Variable:
914
+ return self._py_var_to_onnx_var(node.id, self._source_of(node))
915
+
916
+ # pylint: disable=inconsistent-return-statements
917
+ def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset:
918
+ """Return an Opset"""
919
+ if isinstance(node, ast.Name):
920
+ val = self._lookup(node.id, self._source_of(node), raise_exception=False)
921
+ if isinstance(val, values.Opset):
922
+ return val
923
+ self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.")
924
+ elif isinstance(node, ast.Attribute):
925
+ self.fail(node, "Nested module unimplemented.") # TODO
926
+ else:
927
+ self.fail(node, "Invalid opset expression.")
928
+
929
+ # pylint: enable=inconsistent-return-statements
930
+ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable=R1710
931
+ """Return an Op"""
932
+ if isinstance(node, ast.Attribute):
933
+ module = self._translate_opset_expr(node.value)
934
+ self._set_default_opset(module, node)
935
+ opname = node.attr
936
+ if opname in module:
937
+ return values.Op(module, node.attr)
938
+ return values.Op(module, node.attr)
939
+ if isinstance(node, ast.Name):
940
+ function_name = node.id
941
+ found = self._lookup(function_name, self._source_of(node), raise_exception=False)
942
+ if isinstance(found, onnxscript.OnnxFunction):
943
+ self._current_fn.add_called_function(found)
944
+ return found
945
+ if isinstance(found, values.Op):
946
+ return found
947
+ if not found:
948
+ if function_name not in self.default_opset:
949
+ warn(
950
+ f"Unknown function name {function_name!r}. "
951
+ f"The ONNX graph may not work."
952
+ )
953
+ return values.Op(self.default_opset, function_name)
954
+ self.fail(node, "Invalid callee")
955
+
956
+ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
957
+ """Statement translation: A single Python statement is mapped into a
958
+ sequence of IR statements.
959
+ """
960
+ if isinstance(node, ast.Assign):
961
+ return self._translate_assign_stmt(node)
962
+ if isinstance(node, ast.AnnAssign):
963
+ return self._translate_assign_stmt(node)
964
+ if isinstance(node, ast.Return):
965
+ if index_of_stmt is not None:
966
+ return self._translate_return_stmt(node)
967
+ raise ValueError(
968
+ self._message(
969
+ node, "Return statements are not permitted inside control-flow statements."
970
+ )
971
+ )
972
+ if isinstance(node, ast.If):
973
+ return self._translate_if_stmt(node)
974
+ if isinstance(node, (ast.For, ast.While)):
975
+ return self._translate_loop_stmt(node)
976
+ if ast_utils.is_doc_string(node):
977
+ if index_of_stmt == 0:
978
+ return self._translate_docstring(node)
979
+ return None
980
+ if isinstance(node, ast.FunctionDef):
981
+ return self._translate_nested_function_def(node)
982
+ if ast_utils.is_print_call(node):
983
+ return None
984
+ raise ValueError(self._message(node, f"Unsupported statement type '{type(node)!r}'."))
985
+
986
+ def _translate_assign_stmt(self, stmt: Union[ast.Assign, ast.AnnAssign]) -> None:
987
+ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
988
+ if isinstance(lhs, ast.Name):
989
+ # Assignments of the form "x = SomeExpression"
990
+ info = self._source_of(lhs)
991
+ lhs = lhs.id
992
+ t = self._translate_expr(rhs, lhs).name
993
+ if isinstance(stmt, ast.AnnAssign):
994
+ typeinfo = self._eval_constant_expr(stmt.annotation)
995
+ else:
996
+ typeinfo = None
997
+ var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
998
+ self._bind(lhs, var)
999
+ elif isinstance(lhs, ast.Tuple):
1000
+ # Assignments of the form "x, y, z = op.SomeOp(...)"
1001
+ if not isinstance(rhs, ast.Call):
1002
+ self.fail(
1003
+ rhs,
1004
+ f"RHS must be a Call expression for unpacking, found: '{type(rhs)!r}'",
1005
+ )
1006
+ callee, inputs, attrs = self._translate_call_expr(rhs)
1007
+
1008
+ def generate_onnx_name(x: ast.AST):
1009
+ if not isinstance(x, ast.Name):
1010
+ self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'")
1011
+ onnx_name = self.generate_unique_name(x.id)
1012
+ self._bind(
1013
+ x.id,
1014
+ values.Dynamic(
1015
+ onnx_name, values.DynamicKind.Intermediate, self._source_of(x)
1016
+ ),
1017
+ )
1018
+ return onnx_name
1019
+
1020
+ outputs = [generate_onnx_name(x) for x in lhs.elts]
1021
+ self.emit(outputs, callee, inputs, attrs)
1022
+ else:
1023
+ self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'")
1024
+
1025
+ if isinstance(stmt, ast.Assign):
1026
+ targets = stmt.targets
1027
+ else:
1028
+ targets = [stmt.target]
1029
+ if len(targets) != 1:
1030
+ # Assignments of the form "x = y = SomeExpression"
1031
+ self.fail(stmt, "Multi-assignment not supported.")
1032
+ lhs = targets[0]
1033
+ rhs = stmt.value
1034
+ if isinstance(rhs, ast.Tuple):
1035
+ # Assignments of the form "... = Expression1, Expression2"
1036
+ if not isinstance(lhs, ast.Tuple):
1037
+ # Assignments of the form "single_var = Expression1, Expression2".
1038
+ # We do not support tuple-typed variables.
1039
+ self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.")
1040
+ # Parallel assignments of the form "x, y = Expression1, Expression2"
1041
+ if len(lhs.elts) != len(rhs.elts):
1042
+ self.fail(
1043
+ stmt, "Expected same number of elements on lhs and rhs of assignments."
1044
+ )
1045
+ for p, r in zip(lhs.elts, rhs.elts):
1046
+ assign(p, r)
1047
+ else:
1048
+ assign(lhs, rhs)
1049
+
1050
+ def _translate_return_stmt(self, stmt: ast.Return) -> None:
1051
+ def check_num_outputs(n):
1052
+ if self.returntype is not None:
1053
+ if n != len(self.returntype):
1054
+ raise SyntaxError(
1055
+ self._message(
1056
+ stmt,
1057
+ f"Mismatch in number of return values and types. Keyword "
1058
+ f"'return' cannot be used in a subgraph (test, loop). "
1059
+ f"returntype is {self.returntype!r}, num_outputs={n!r}.",
1060
+ )
1061
+ )
1062
+
1063
+ def ret(exp, i, suffix):
1064
+ preferred_name = f"return_val{suffix}"
1065
+ return_var = self._translate_expr(exp, preferred_name).name
1066
+ val = self._lookup(return_var, self._source_of(exp), False)
1067
+ if val and val.kind == values.DynamicKind.Input:
1068
+ # In ONNX, a graph-input cannot be an output of the graph.
1069
+ # We need to insert a copy.
1070
+ return_var = self._emit_copy(return_var, preferred_name)
1071
+ for prev_output in self._current_fn.outputs:
1072
+ if prev_output.name == return_var:
1073
+ # ONNX does not allow duplicate output names.
1074
+ return_var = self._emit_copy(return_var, f"{return_var}_copy")
1075
+ break
1076
+ if self.returntype is None:
1077
+ t = None
1078
+ else:
1079
+ t = self.returntype[i]
1080
+ self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt))
1081
+ return return_var
1082
+
1083
+ val = stmt.value
1084
+ assert val is not None, "Return statement without return-value not supported."
1085
+ if isinstance(val, ast.Tuple):
1086
+ check_num_outputs(len(val.elts))
1087
+ return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)]
1088
+ check_num_outputs(1)
1089
+ return ret(val, 0, "")
1090
+
1091
+ def _translate_if_stmt(self, stmt: ast.If) -> None:
1092
+ if hasattr(stmt, "live_out"):
1093
+ live_defs = list(
1094
+ stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message))
1095
+ )
1096
+ else:
1097
+ live_defs = list(analysis.assigned_vars(stmt, self._message))
1098
+ test = self._translate_expr(stmt.test, "cond").name
1099
+ lineno = self._source_of(stmt).lineno
1100
+ thenGraph, sub_fct_then = self._translate_block(
1101
+ stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt
1102
+ )
1103
+ thenAttr = self._make_onnx_attr("then_branch", thenGraph)
1104
+ elseGraph, sub_fct_else = self._translate_block(
1105
+ stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt
1106
+ )
1107
+ elseAttr = self._make_onnx_attr("else_branch", elseGraph)
1108
+
1109
+ def rename(x):
1110
+ r = self.generate_unique_name(x)
1111
+ self._bind(
1112
+ x,
1113
+ values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)),
1114
+ )
1115
+ return r
1116
+
1117
+ # no break condition
1118
+ renamed = [rename(x) for x in live_defs]
1119
+ if not renamed:
1120
+ self.fail(stmt, "A subgraph for a test do not have any output variable.")
1121
+
1122
+ sub_functions = {}
1123
+ sub_functions.update(sub_fct_then)
1124
+ sub_functions.update(sub_fct_else)
1125
+ if renamed == [test]:
1126
+ self.fail(stmt, f"Input and output cannot be the same {renamed!r}.")
1127
+ self.emit(
1128
+ renamed,
1129
+ values.Op(self.default_opset, "If"),
1130
+ [test],
1131
+ [thenAttr, elseAttr],
1132
+ sub_functions=sub_functions,
1133
+ )
1134
+
1135
+ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
1136
+ # loop-variable
1137
+ if isinstance(loop_stmt, ast.For):
1138
+ if not isinstance(loop_stmt.target, ast.Name):
1139
+ self.fail(loop_stmt, "For loop target must be a single variable.")
1140
+ p_loop_var = loop_stmt.target.id
1141
+ # iter
1142
+ iter = loop_stmt.iter
1143
+ assert isinstance(iter, ast.Call), "Loop bound not a call."
1144
+ if not isinstance(iter.func, ast.Name):
1145
+ self.fail(loop_stmt, f"Unsupported loop bound {iter.func!r}.")
1146
+ if iter.func.id != "range":
1147
+ self.fail(
1148
+ loop_stmt, "Unsupported loop bound, only function 'range' is allowed."
1149
+ )
1150
+ if not iter.args or len(iter.args) != 1:
1151
+ self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
1152
+ assert not iter.keywords, "Unsupported loop bound."
1153
+ o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name
1154
+ o_cond_var = self.generate_unique_name("cond_in")
1155
+ i_cond_var = o_cond_var
1156
+ cond_while = None
1157
+ o_loop_condition = "" # No condition for a for loop.
1158
+ elif isinstance(loop_stmt, ast.While):
1159
+ test = loop_stmt.test
1160
+ if not isinstance(test, ast.Name):
1161
+ self.fail(
1162
+ loop_stmt,
1163
+ "Unexpected condition type {type(loop_stmt)!r} for a while loop, "
1164
+ "it should be 'while <condition_name>:'.",
1165
+ )
1166
+ p_loop_var = "infinite_loop"
1167
+ o_loop_bound = ""
1168
+ i_cond_var = test.id
1169
+ cond_while = test.id
1170
+ o_cond_var = None
1171
+ o_loop_condition = self._translate_name_expr(test)
1172
+ # we need to go through all the instructions to see
1173
+ # which instruction defines the condition test.id
1174
+ else:
1175
+ self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.")
1176
+ # analyze loop body
1177
+ exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message)
1178
+ vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message)
1179
+ loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out)
1180
+ scan_outputs = set() # TODO
1181
+ outputs = list(loop_state_vars | scan_outputs)
1182
+
1183
+ # loop-condition:
1184
+ # o_loop_condition = self._emit_const(True, "true", self._source_of(loop_stmt))
1185
+
1186
+ # build loop_body
1187
+ self._enter_scope("loop_body", loop_stmt)
1188
+ o_loop_var = self.generate_unique_name(p_loop_var)
1189
+ self.ir_builder.add_input(
1190
+ self._current_fn,
1191
+ o_loop_var,
1192
+ onnx_types.INT64,
1193
+ self._source_of(loop_stmt),
1194
+ )
1195
+ self._bind(
1196
+ p_loop_var,
1197
+ values.Dynamic(o_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
1198
+ )
1199
+
1200
+ self.ir_builder.add_input(
1201
+ self._current_fn,
1202
+ i_cond_var,
1203
+ onnx_types.BOOL,
1204
+ self._source_of(loop_stmt),
1205
+ )
1206
+
1207
+ for pv in loop_state_vars:
1208
+ ov = self.generate_unique_name(pv)
1209
+ # TODO: retrieve the annotation for variable pv is any is specified.
1210
+ # typeinfo = self._eval_constant_expr(pv.annotation)
1211
+ typeinfo = None
1212
+ self.ir_builder.add_input(
1213
+ self._current_fn, ov, typeinfo, self._source_of(loop_stmt)
1214
+ )
1215
+ self._bind(
1216
+ pv,
1217
+ values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)),
1218
+ )
1219
+
1220
+ condition_name = None
1221
+ operator_name = "Identity"
1222
+ for i, s in enumerate(loop_stmt.body):
1223
+ # We first need to intercept a break instruction in test block.
1224
+ # It must be something like `if <condition_name>: break`.
1225
+ # This instruction must be the last of the loop body.
1226
+ if isinstance(s, ast.If) and len(s.body) == 1 and isinstance(s.body[0], ast.Break):
1227
+ if not isinstance(s.test, ast.Name):
1228
+ self.fail(
1229
+ s,
1230
+ f"Instruction break can be introduced with test but it must be "
1231
+ f"if <condition>: break. However condition is of type "
1232
+ f"{type(s.test)!r}.",
1233
+ )
1234
+ if i != len(loop_stmt.body) - 1:
1235
+ self.fail(s, "Instruction break must be the last one of the loop.")
1236
+
1237
+ current_scope = self._current_scope()
1238
+ if s.test.id not in current_scope:
1239
+ self.fail(
1240
+ loop_stmt,
1241
+ f"Unable to find condition variable {s.test.id!r} in known "
1242
+ f"variables {list(current_scope)!r}.",
1243
+ )
1244
+ condition_name = current_scope[s.test.id].value
1245
+ operator_name = "Not"
1246
+ continue
1247
+ self._translate_stmt(s)
1248
+
1249
+ o_cond_out = self.generate_unique_name("cond_out")
1250
+
1251
+ if cond_while is not None:
1252
+ # Loop while
1253
+ current_scope = self._current_scope()
1254
+ if cond_while not in current_scope:
1255
+ self.fail(
1256
+ loop_stmt,
1257
+ f"Unable to find condition variable {cond_while!r} in known "
1258
+ f"variables {list(current_scope)!r}.",
1259
+ )
1260
+ o_cond_var = current_scope[cond_while].value
1261
+
1262
+ self.emit(
1263
+ [o_cond_out],
1264
+ values.Op(self.default_opset, operator_name),
1265
+ [condition_name or o_cond_var],
1266
+ [],
1267
+ )
1268
+
1269
+ self.ir_builder.add_output(
1270
+ self._current_fn,
1271
+ o_cond_out,
1272
+ onnx_types.BOOL,
1273
+ self._source_of(loop_stmt),
1274
+ )
1275
+ for pv in loop_state_vars:
1276
+ ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name
1277
+ if ov not in self._current_fn.assigned_names:
1278
+ # When converting the loop-body into a graph, we need to handle
1279
+ # identity assignments of the form "x = y" inside the loop body
1280
+ # specially if y represents a value computed outside the loop body.
1281
+ # In this case, we create a copy of y, treating the statement as
1282
+ # shorthand for "x = op.Identity(y)".
1283
+ ov = self._emit_copy(ov, pv)
1284
+ # TODO: retrieve variable type for the annotation if any.
1285
+ typeinfo = None
1286
+ self.ir_builder.add_output(
1287
+ self._current_fn, ov, typeinfo, self._source_of(loop_stmt)
1288
+ )
1289
+ body = self._exit_scope()
1290
+ inputs = [o_loop_bound, o_loop_condition] + [
1291
+ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name
1292
+ for pv in loop_state_vars
1293
+ ]
1294
+ graph, sub_functions = body.to_graph_and_functions()
1295
+ attrs = [self._make_onnx_attr("body", graph)]
1296
+ info = self._source_of(loop_stmt)
1297
+
1298
+ def rename(x):
1299
+ r = self.generate_unique_name(x)
1300
+ self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info))
1301
+ return r
1302
+
1303
+ onnx_outputs = [rename(x) for x in outputs]
1304
+ self.emit(
1305
+ onnx_outputs,
1306
+ "Loop",
1307
+ inputs,
1308
+ attrs,
1309
+ sub_functions=sub_functions,
1310
+ )
1311
+
1312
+ def _translate_block(
1313
+ self,
1314
+ stmts: Sequence[ast.stmt],
1315
+ name: str,
1316
+ live_defs: Sequence[str],
1317
+ parent_stmt: ast.stmt,
1318
+ ):
1319
+ """Translation of a statement-block to GraphProto attribute."""
1320
+ info_stmt = stmts[0] if len(stmts) > 0 else parent_stmt
1321
+ source = self._source_of(info_stmt)
1322
+ self._enter_scope(name, None)
1323
+ for s in stmts:
1324
+ self._translate_stmt(s)
1325
+ for pvar in live_defs:
1326
+ if pvar in self._current_scope():
1327
+ pv_val = self._current_scope()[pvar]
1328
+ output = self._to_onnx_var(pv_val, pvar).name
1329
+ if output not in self._current_fn.assigned_names:
1330
+ # To return an outer-scope variable, an ONNX Graph has to
1331
+ # use an explicit copy via Identity.
1332
+ output = self._emit_copy(output, pvar)
1333
+ self.ir_builder.add_output(
1334
+ self._current_fn,
1335
+ output,
1336
+ pv_val.typeinfo,
1337
+ source,
1338
+ )
1339
+ else:
1340
+ pv_val = None
1341
+ for scope in self._locals: # TODO: skip _current_scope
1342
+ if pvar in scope:
1343
+ pv_val = scope[pvar]
1344
+ break
1345
+ if pv_val is None:
1346
+ self.fail(
1347
+ stmts[0],
1348
+ f"Variable {pvar} is not assigned a value along a conditional "
1349
+ f"branch, known variables: {list(self._locals)}.",
1350
+ )
1351
+ # introduce a copy
1352
+ ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar).name, pvar)
1353
+
1354
+ # TODO: retrieve the annotation if any.
1355
+ typeinfo = None
1356
+ self.ir_builder.add_output(self._current_fn, ovar, typeinfo, source)
1357
+ graph = self._exit_scope()
1358
+ return graph.to_graph_and_functions()
1359
+
1360
+ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
1361
+ """Translate a nested function definition."""
1362
+ self._enter_scope(fn.name, fn)
1363
+ self._translate_function_def_common(fn)
1364
+ function_ir = self._exit_scope()
1365
+ outer_scope_vars = analysis.outer_scope_variables(fn, self._message)
1366
+ function_ir.outer_scope_variables = [
1367
+ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
1368
+ ]
1369
+ self._bind(fn.name, function_ir)
1370
+ # TODO: Does not yet handle nested functions within nested functions.
1371
+ self._current_fn.add_nested_function(function_ir)
1372
+
1373
+ def _translate_function_signature_common(
1374
+ self, fn: ast.FunctionDef
1375
+ ) -> irbuilder.IRFunction:
1376
+ """Translate a function signature (top-level or nested)."""
1377
+ args = fn.args
1378
+ if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg:
1379
+ warn(f"{fn.name}: Unsupported feature in function signature.")
1380
+ for i, x in enumerate(args.args):
1381
+ arg_with_default_start_index = len(args.args) - len(args.defaults)
1382
+ if args.defaults and i >= arg_with_default_start_index:
1383
+ default_value = self._eval_constant_expr(
1384
+ args.defaults[i - arg_with_default_start_index]
1385
+ )
1386
+ else:
1387
+ default_value = None
1388
+ if x.annotation:
1389
+ typeinfo = self._eval_constant_expr(x.annotation)
1390
+ if not ta.is_valid_type(typeinfo):
1391
+ self.warn(
1392
+ x.annotation,
1393
+ f"Unsupported type annotation for argument {x.arg}.",
1394
+ )
1395
+ typeinfo = None
1396
+ else:
1397
+ # The code can only be exported as a function.
1398
+ typeinfo = None
1399
+ if typeinfo and ta.is_attr_type(typeinfo):
1400
+ self.ir_builder.add_attr_parameter(
1401
+ self._current_fn,
1402
+ x.arg,
1403
+ ta.pytype_to_attrtype(typeinfo),
1404
+ default_value,
1405
+ )
1406
+ self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
1407
+ else:
1408
+ self.ir_builder.add_input(
1409
+ self._current_fn, x.arg, typeinfo, self._source_of(x)
1410
+ )
1411
+ self._used_vars.add(x.arg)
1412
+ self._bind(
1413
+ x.arg,
1414
+ values.Dynamic(x.arg, values.DynamicKind.Input, self._source_of(x)),
1415
+ )
1416
+ if fn.returns:
1417
+ type_annotation = self._eval_constant_expr(fn.returns)
1418
+ self.returntype = ta.get_return_types(type_annotation)
1419
+ invalid = False
1420
+ for t in self.returntype:
1421
+ if not ta.is_valid_type(t):
1422
+ self.warn(
1423
+ fn.returns,
1424
+ f"Unsupported type annotation for return value {t}.",
1425
+ )
1426
+ invalid = True
1427
+ if invalid:
1428
+ self.returntype = None
1429
+ else:
1430
+ self.returntype = None
1431
+
1432
+ return self._current_fn
1433
+
1434
+ def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
1435
+ """Translate a function definition, including the signature and its body."""
1436
+ logger.debug("Converter:_translate_function_def_common:%s", fn.name)
1437
+ _ = self._translate_function_signature_common(fn)
1438
+ for i, s in enumerate(fn.body):
1439
+ self._translate_stmt(s, index_of_stmt=i)
1440
+ return self._current_fn
1441
+
1442
+ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
1443
+ if isinstance(stmt, ast.FunctionDef):
1444
+ self._init_function_translation()
1445
+ if self.default_opset_ is None:
1446
+ opset = self._find_onnx_opset(stmt)
1447
+ if opset:
1448
+ self._set_default_opset(opset, stmt)
1449
+ domain = self.this_module.domain
1450
+ self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
1451
+ analysis.do_liveness_analysis(stmt, self._message)
1452
+ fn_ir = self._translate_function_def_common(stmt)
1453
+ fn_ir.debug_print()
1454
+ self.this_module.add_function_def(fn_ir)
1455
+ return fn_ir
1456
+ raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
1457
+
1458
+ def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
1459
+ """Translate a (top-level) function signature."""
1460
+ domain = self.this_module.domain
1461
+ self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
1462
+ return self._translate_function_signature_common(fn)
pythonProject/.venv/Lib/site-packages/onnxscript/evaluator.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import abc
6
+ import contextlib
7
+ import pprint
8
+ from typing import (
9
+ Any,
10
+ Callable,
11
+ Mapping,
12
+ Optional,
13
+ Protocol,
14
+ Sequence,
15
+ TypeVar,
16
+ Union,
17
+ runtime_checkable,
18
+ )
19
+
20
+ import numpy as np
21
+ import onnx
22
+ import onnx.defs
23
+ import onnx.reference
24
+ from typing_extensions import TypeAlias
25
+
26
+ from onnxscript import irbuilder, onnx_opset, tensor, values
27
+ from onnxscript._internal import autocast, param_manipulation, utils
28
+
29
+ UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]
30
+
31
+ EagerModeValue: TypeAlias = Union[Optional["tensor.Tensor"], Sequence["EagerModeValue"]]
32
+
33
+ ExtendedModeValue: TypeAlias = Union[
34
+ Optional["tensor.Tensor"],
35
+ Sequence["ExtendedModeValue"],
36
+ np.ndarray,
37
+ int,
38
+ float,
39
+ bool,
40
+ str,
41
+ ]
42
+
43
+ _T = TypeVar("_T")
44
+
45
+
46
+ def _adapt_to_eager_mode(inputs: ExtendedModeValue) -> tuple[EagerModeValue, bool]:
47
+ """Adapts inputs into representation used by onnxscript eager mode.
48
+
49
+ This does the following transformations:
50
+ * It adds an onnxscript Tensor wrapper around numpy arrays, which
51
+ allows the use of overloaded operators like + to be controlled by onnxscript.
52
+ * It also provides a promotion of scalars into tensors as a convenience.
53
+ This is needed to complement the similar promotion supported by the
54
+ onnxscript converter (for example, when an attribute is promoted and used
55
+ as an input argument).
56
+
57
+ Args:
58
+ inputs: a list/tuple of inputs to an ONNX function
59
+
60
+ Returns:
61
+ a pair (wrapped_inputs, flag) where flag indicates whether any numpy array
62
+ was wrapped into a Tensor.
63
+ """
64
+ has_array = False
65
+
66
+ def adapt(input: ExtendedModeValue) -> EagerModeValue:
67
+ if isinstance(input, np.ndarray):
68
+ nonlocal has_array
69
+ has_array = True
70
+ return tensor.Tensor(input)
71
+ if isinstance(input, tensor.Tensor):
72
+ return input
73
+ if isinstance(input, (bool, float)):
74
+ return tensor.Tensor(np.array(input))
75
+ if isinstance(input, int):
76
+ return tensor.Tensor(np.array(input, dtype=np.int64))
77
+ if input is None:
78
+ return None
79
+ if isinstance(input, list):
80
+ return [adapt(elt) for elt in input]
81
+ if isinstance(input, tuple):
82
+ return tuple(adapt(elt) for elt in input)
83
+ raise TypeError(f"Unexpected input type {type(input)}.")
84
+
85
+ result = adapt(inputs)
86
+ return result, has_array
87
+
88
+
89
+ def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue:
90
+ """Unwraps Tensor wrapper around numpy arrays.
91
+
92
+ Args:
93
+ output: output of an ONNX function, which can be either a single
94
+ onnx value or a list/tuple of onnx values.
95
+
96
+ Returns:
97
+ unwrapped output
98
+ """
99
+ if isinstance(output, tensor.Tensor):
100
+ return output.value
101
+ if output is None:
102
+ return None
103
+ if isinstance(output, list):
104
+ return [_adapt_to_user_mode(elt) for elt in output]
105
+ if isinstance(output, tuple):
106
+ return tuple(_adapt_to_user_mode(elt) for elt in output)
107
+ if isinstance(output, np.ndarray):
108
+ return output
109
+ raise TypeError(f"Unexpected type {type(output)}.")
110
+
111
+
112
+ def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]:
113
+ """Unwrap tensors in a mapping to numpy arrays."""
114
+ new_kwargs = {}
115
+ for k, v in kwargs.items():
116
+ new_kwargs[k] = v
117
+ if isinstance(v, tensor.Tensor):
118
+ new_kwargs[k] = v.value
119
+
120
+ return new_kwargs
121
+
122
+
123
+ @runtime_checkable
124
+ class Evaluator(Protocol):
125
+ """Protocol for evaluating ONNX ops."""
126
+
127
+ def eval(
128
+ self,
129
+ schema: onnx.defs.OpSchema,
130
+ inputs: Sequence[ExtendedModeValue],
131
+ attributes: Mapping[str, Any],
132
+ ):
133
+ """Evaluates an ONNX op.
134
+
135
+ Args:
136
+ schema: The OpSchema of the operator to evaluate.
137
+ inputs: The ONNX inputs to the op.
138
+ attributes: The ONNX attributes to the op.
139
+ """
140
+
141
+ def eval_function(
142
+ self,
143
+ function: values.OnnxFunction,
144
+ args: Sequence[ExtendedModeValue],
145
+ kwargs: Mapping[str, ExtendedModeValue],
146
+ ):
147
+ """Evaluates an OnnxFunction.
148
+
149
+ Args:
150
+ function: The OnnxFunction to evaluate.
151
+ args: The positional arguments to the function.
152
+ kwargs: The keyword arguments to the function.
153
+ """
154
+
155
+
156
+ class BaseEvaluator(Evaluator, abc.ABC):
157
+ """Base class for evaluation of ONNX ops.
158
+
159
+ The execution of onnxscript functions in eager-mode is dispatched to an Evaluator
160
+ instance (or, more precisely, to the eval method of the Evaluator instance).
161
+ The evaluator is expected to transform the input/output/attribute representation
162
+ supported by onnxscript to those expected by a particular backend.
163
+ """
164
+
165
+ def __init__(self, ignore_unknown_function_kwargs: bool = False):
166
+ """Initializes a BaseEvaluator.
167
+
168
+ Args:
169
+ ignore_unknown_function_kwargs: Whether to ignore unknown keyword arguments
170
+ when evaluating an OnnxFunction. This is useful when using the
171
+ evaluator to validate operators programmatically, where
172
+ additional keyword arguments that is not part of the signature
173
+ may be provided to the function.
174
+ """
175
+ self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs
176
+
177
+ def eval(
178
+ self,
179
+ schema: onnx.defs.OpSchema,
180
+ inputs: Sequence[ExtendedModeValue],
181
+ attributes: Mapping[str, Any],
182
+ ):
183
+ """Evaluates an ONNX op.
184
+
185
+ Args:
186
+ schema: The OpSchema of the operator to evaluate.
187
+ inputs: The ONNX inputs to the op.
188
+ attributes: The ONNX attributes to the op.
189
+ """
190
+ attributes = _unwrap_tensors_in_kwargs(attributes)
191
+ attributes, closure = self.adapt_attributes(schema, attributes)
192
+ inputs = self.adapt_inputs(schema, inputs)
193
+ outputs = self._eval(schema, inputs, attributes, closure)
194
+ return self.adapt_outputs(schema, outputs)
195
+
196
+ def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]):
197
+ """Transform inputs to the expected format for the evaluator.
198
+
199
+ Enables some syntactic sugar, such as the use of Python scalars,
200
+ in a manner consistent with the translator. See autocast.py for details.
201
+ """
202
+ return autocast.dynamic_cast_inputs(schema, inputs)
203
+
204
+ def adapt_attributes(
205
+ self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue]
206
+ ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]:
207
+ """Transform attributes to the expected format for the evaluator.
208
+
209
+ Returns:
210
+ A closure that can be used to evaluate graph-valued attributes.
211
+ """
212
+ use_graph_attribute = self.use_graph_attribute(schema)
213
+ closure: dict[Any, Any] = {}
214
+ adapted_attributes = {}
215
+ for k, v in attributes.items():
216
+ if isinstance(v, values.OnnxClosure):
217
+ if use_graph_attribute:
218
+ adapted_attributes[k] = v.function_ir.to_graph_proto()
219
+ for pyvar, onnxvar in v.function_ir.outer_scope_variables:
220
+ closure[onnxvar.value] = v.frame.f_locals[pyvar]
221
+ else:
222
+ adapted_attributes[k] = v.function
223
+ elif callable(v):
224
+ raise TypeError(
225
+ f"Error: function-valued attribute {v.__name__} has no graph_proto"
226
+ "attribute. Did you forget to decorate it with @graph?"
227
+ )
228
+ else:
229
+ adapted_attributes[k] = v
230
+ return adapted_attributes, closure
231
+
232
+ def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]):
233
+ """Adapt evaluator's output to convention used in onnxscript.
234
+
235
+ Onnxscript uses a tuple/sequence only when number of outputs > 1.
236
+ """
237
+ del schema # unused
238
+ return outputs[0] if len(outputs) == 1 else outputs
239
+
240
+ def use_graph_attribute(self, schema: onnx.defs.OpSchema):
241
+ del schema # unused
242
+ return True
243
+
244
+ @abc.abstractmethod
245
+ def _eval(
246
+ self,
247
+ schema: onnx.defs.OpSchema,
248
+ inputs: Sequence[ExtendedModeValue],
249
+ attributes: Mapping[str, ExtendedModeValue],
250
+ closure: Mapping[str, ExtendedModeValue],
251
+ ) -> EagerModeValue:
252
+ """Evaluates an ONNX op given its schema and inputs/attributes.
253
+
254
+ Args:
255
+ schema: The schema of the op to evaluate.
256
+ inputs: The ONNX inputs to the op.
257
+ attributes: The ONNX attributes to the op.
258
+ closure: The closure to use when evaluating graph-valued attributes.
259
+ """
260
+
261
+ def eval_function(
262
+ self,
263
+ function: values.OnnxFunction,
264
+ args: Sequence[ExtendedModeValue],
265
+ kwargs: Mapping[str, ExtendedModeValue],
266
+ ):
267
+ """Evaluates a function in eager mode.
268
+
269
+ Override this function to change the evaluator's behavior for functions.
270
+
271
+ Args:
272
+ function: The OnnxFunction to evaluate.
273
+ args: The positional arguments to the function.
274
+ kwargs: The keyword arguments to the function.
275
+ """
276
+ param_schemas = function.param_schemas()
277
+ # Split happens in the evaluator instead of the OnnxFunction __call__ method
278
+ # so that evaluators can control behaviors like whether to fill in default values for attributes.
279
+ tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
280
+ param_schemas,
281
+ args,
282
+ kwargs,
283
+ fill_defaults=False,
284
+ allow_extra_kwargs=self._ignore_unknown_function_kwargs,
285
+ )
286
+
287
+ adapted_args: list[ExtendedModeValue] = []
288
+ adapted_kwargs: dict[str, ExtendedModeValue] = {}
289
+ has_array = False
290
+ for arg, param_schema in tagged_args:
291
+ if param_schema.is_input:
292
+ adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
293
+ has_array = has_array or has_array_
294
+ adapted_args.append(adapted_arg)
295
+ else:
296
+ adapted_args.append(arg)
297
+
298
+ for key, (arg, param_schema) in tagged_kwargs.items():
299
+ if param_schema.is_input:
300
+ adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
301
+ has_array = has_array or has_array_
302
+ adapted_kwargs[key] = adapted_arg
303
+ else:
304
+ adapted_kwargs[key] = arg
305
+
306
+ result = function.function(*adapted_args, **adapted_kwargs)
307
+
308
+ # We use a heuristic to decide whether to return output values as
309
+ # numpy arrays or tensor.Tensors. If the function has at least one
310
+ # numpy array as input, we return numpy arrays. Otherwise, we return
311
+ # tensor.Tensors. We could use a user-specified flag to control this
312
+ # or explicitly track whether this is a top-level function-call or
313
+ # a nested function-call.
314
+
315
+ return _adapt_to_user_mode(result) if has_array else result
316
+
317
+
318
+ # Utilities for evaluation using ORT:
319
+
320
+
321
+ class EagerModeError(RuntimeError):
322
+ pass
323
+
324
+
325
+ def _rename_io(prefix, i, arg):
326
+ if arg is None:
327
+ return ""
328
+ return f"{prefix}{i}"
329
+
330
+
331
+ def compute_num_outputs(
332
+ schema: onnx.defs.OpSchema, args: Sequence[Any], kwargs: Mapping[str, Any]
333
+ ) -> int:
334
+ """Returns the number of outputs expected."""
335
+
336
+ # TODO: Use ONNX type inference to replace the special-case handling below.
337
+ if schema.domain == "":
338
+ if schema.name == "BatchNormalization":
339
+ if not kwargs.get("training_mode", 0):
340
+ return 1
341
+ if schema.name == "LSTM":
342
+ return 3
343
+ if schema.name == "Split":
344
+ if len(args) == 1 and "num_outputs" not in kwargs:
345
+ raise EagerModeError(
346
+ "Operator Split: the number of expected outputs defines the split. "
347
+ "This information is unknown here."
348
+ )
349
+ if len(args) == 2: # has argument(split)
350
+ return len(args[1])
351
+ else: # no argument(split), check attribute(num_outputs)
352
+ return kwargs["num_outputs"]
353
+ if schema.name == "Scan":
354
+ scan_body = kwargs["body"]
355
+ return len(scan_body.output)
356
+ if schema.name == "Loop":
357
+ loop_body = kwargs["body"]
358
+ return len(loop_body.output) - 1
359
+ return len(schema.outputs)
360
+
361
+
362
+ def _onnxscript_to_numpy_value(v):
363
+ """Converts an onnxscript encoding of an ONNX value into the numpy encoding used by runtimes."""
364
+ if isinstance(v, tensor.Tensor):
365
+ return v.value
366
+ if isinstance(v, list):
367
+ return [_onnxscript_to_numpy_value(x) for x in v]
368
+ if isinstance(v, tuple):
369
+ if len(v) > 0 and type(v[0]) is int: # pylint: disable=unidiomatic-typecheck
370
+ return np.array(v, dtype=np.int64)
371
+ return np.array(v)
372
+ if v is None:
373
+ # Treated as a static-optional value.
374
+ # Dynamic optional None not yet supported.
375
+ return v
376
+ if isinstance(v, np.ndarray):
377
+ return v
378
+ raise TypeError(
379
+ f"Unexpected onnxscript value type '{type(v)}'."
380
+ "Valid value types are 'Tensor | list[Tensor] | None | np.ndarray | list[np.ndarray]'"
381
+ )
382
+
383
+
384
+ def _numpy_to_onnxscript_value(
385
+ v: np.ndarray | np.generic | list[np.ndarray] | list[np.generic],
386
+ ):
387
+ """Converts an ORT encoding of an ONNX value into the encoding used by onnxscript."""
388
+ if isinstance(v, np.ndarray):
389
+ # ORT may reuse buffers when the output numpy array is provided back as input.
390
+ # We need to make a copy to ensure that the tensor is not modified in-place.
391
+ return tensor.Tensor(v.copy())
392
+ if issubclass(type(v), np.generic):
393
+ # Numpy scalar types that are not ndarray
394
+ # https://numpy.org/doc/stable/reference/arrays.scalars.html
395
+ return tensor.Tensor(np.array(v))
396
+ if isinstance(v, list):
397
+ return [_numpy_to_onnxscript_value(x) for x in v]
398
+ if v is None:
399
+ raise TypeError("Dynamic optional values not yet supported.")
400
+ raise TypeError(
401
+ f"Unexpected runtime value type '{type(v)}'."
402
+ "Valid types are: 'np.ndarray | np.generic | list[np.ndarray] | list[np.generic]'"
403
+ )
404
+
405
+
406
+ def _prepare_model_and_inputs_for_eager(
407
+ schema: onnx.defs.OpSchema,
408
+ args: Sequence[Any],
409
+ kwargs: Mapping[str, Any],
410
+ implicit_args: Optional[Mapping[str, Any]],
411
+ ):
412
+ implicit_args = implicit_args or {}
413
+ # Convert input values to ORT representation-type:
414
+ args = [_onnxscript_to_numpy_value(x) for x in args]
415
+ implicit_args = {k: _onnxscript_to_numpy_value(v) for k, v in implicit_args.items()}
416
+
417
+ # Utility to convert kwarg to ONNX AttributeProto:
418
+ def make_attr(key: str, value: Any) -> onnx.AttributeProto:
419
+ def make_tensor_name() -> str:
420
+ return f"attr_{key}"
421
+
422
+ return autocast.pyvalue_to_onnx_attribute(
423
+ key, value, make_tensor_name, int(schema.attributes[key].type)
424
+ )
425
+
426
+ # Construct ONNX model with a single op call:
427
+ inputs = [_rename_io("input", i, arg) for i, arg in enumerate(args)]
428
+
429
+ num_outputs = compute_num_outputs(schema, args, kwargs)
430
+ outputs = [f"output{i}" for i in range(num_outputs)]
431
+
432
+ node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
433
+ node.attribute.extend(
434
+ make_attr(key, value) for key, value in kwargs.items() if value is not None
435
+ )
436
+ input_value_infos = utils.values_to_value_infos(zip(inputs, args))
437
+ implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
438
+ output_value_infos = [
439
+ onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
440
+ for name in outputs
441
+ ]
442
+
443
+ graph = onnx.helper.make_graph( # noqa: TID251
444
+ [node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
445
+ )
446
+ opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
447
+ model = onnx.helper.make_model( # noqa: TID251
448
+ graph,
449
+ opset_imports=[opset_id],
450
+ ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
451
+ )
452
+ model = onnx.shape_inference.infer_shapes(model)
453
+
454
+ session_run_input = {name: arg for name, arg in zip(inputs, args) if name != ""}
455
+ session_run_input.update(implicit_args)
456
+
457
+ return model, session_run_input, inputs
458
+
459
+
460
+ def _call_ort(
461
+ schema: onnx.defs.OpSchema,
462
+ args: Sequence[Any],
463
+ kwargs: Mapping[str, Any],
464
+ implicit_args: Optional[Mapping[str, Any]],
465
+ ):
466
+ # Delay import onnxruntime so that onnxscript can be used without
467
+ # installing onnxruntime.
468
+ import onnxruntime as ort # pylint: disable=import-outside-toplevel
469
+ from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel
470
+ Fail,
471
+ InvalidArgument,
472
+ InvalidGraph,
473
+ )
474
+
475
+ model, session_run_input, inputs = _prepare_model_and_inputs_for_eager(
476
+ schema, args, kwargs, implicit_args
477
+ )
478
+
479
+ try:
480
+ session = ort.InferenceSession(
481
+ model.SerializeToString(), providers=("CPUExecutionProvider",)
482
+ )
483
+ except (Fail, InvalidGraph, InvalidArgument) as e:
484
+ raise EagerModeError(
485
+ f"Unable to create onnxruntime InferenceSession "
486
+ f"for executing {schema.domain}.{schema.name} op "
487
+ f"with onnx model\n{onnx.printer.to_text(model)}"
488
+ ) from e
489
+
490
+ try:
491
+ result = session.run(None, session_run_input)
492
+ except (RuntimeError, Fail) as e:
493
+ raise EagerModeError(
494
+ f"Unable to execute model operator {schema.name!r} due to {e!r}"
495
+ f"\ninput types:\n"
496
+ f"{pprint.pformat({k: type(v) for k, v in zip(inputs, args)})}"
497
+ f"\nmodified input types:\n"
498
+ f"{pprint.pformat({k: type(v) for k, v in session_run_input.items()})}"
499
+ f"\ninputs:\n{pprint.pformat(session_run_input)}\n{model}"
500
+ ) from e
501
+
502
+ # Map ORT output values to the onnxscript representation-type.
503
+ return [_numpy_to_onnxscript_value(x) for x in result]
504
+
505
+
506
+ def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]:
507
+ return schema.name, schema.domain, schema.since_version
508
+
509
+
510
+ class ORTEvaluator(BaseEvaluator):
511
+ """Evaluates ONNX ops using ONNX Runtime."""
512
+
513
+ def _eval(self, schema, inputs, attributes, closure):
514
+ return _call_ort(schema, inputs, attributes, closure)
515
+
516
+
517
+ class OnnxReferenceRuntimeEvaluator(BaseEvaluator):
518
+ """Evaluates ONNX ops using ONNX Runtime."""
519
+
520
+ def _eval(self, schema, inputs, attributes, closure):
521
+ model, session_run_input, adapted_inputs = _prepare_model_and_inputs_for_eager(
522
+ schema, inputs, attributes, closure
523
+ )
524
+ session = onnx.reference.ReferenceEvaluator(model)
525
+ try:
526
+ result = session.run(None, session_run_input)
527
+ except RuntimeError as e:
528
+ raise EagerModeError(
529
+ f"Unable to execute model operator {schema.name!r} due to {e!r}"
530
+ f"\ninput types:\n"
531
+ f"{pprint.pformat({k: type(v) for k, v in zip(adapted_inputs, inputs)})}"
532
+ f"\nmodified input types:\n"
533
+ f"{pprint.pformat({k: type(v) for k, v in session_run_input.items()})}"
534
+ f"\ninputs:\n{pprint.pformat(session_run_input)}\n{model}"
535
+ ) from e
536
+
537
+ return [_numpy_to_onnxscript_value(x) for x in result]
538
+
539
+
540
+ ort_evaluator = ORTEvaluator()
541
+
542
+
543
+ class ORTMixedEvaluator(ORTEvaluator):
544
+ """Evaluates ONNX ops using ONNX Runtime, unless an overriding python implementation is registered.
545
+
546
+ This is useful for higher-order ops such as Scan and SequenceMap, allowing for
547
+ python-based debugging.
548
+ """
549
+
550
+ def __init__(self) -> None:
551
+ super().__init__()
552
+ self._python_ops: dict[tuple[str, str, int], Any] = {}
553
+
554
+ def use_graph_attribute(self, schema: onnx.defs.OpSchema) -> bool:
555
+ return _schema_id(schema) not in self._python_ops
556
+
557
+ def _eval(self, schema, inputs, attributes, closure):
558
+ schemaid = _schema_id(schema)
559
+ if schemaid in self._python_ops:
560
+ return self._python_ops[schemaid](inputs, attributes)
561
+ else:
562
+ return super()._eval(schema, inputs, attributes, closure)
563
+
564
+ def register(self, opset: values.Opset) -> Callable[[_T], _T]:
565
+ assert opset is not None
566
+
567
+ def decorator(function: _T) -> _T:
568
+ schema = opset[function.__name__]
569
+ self._python_ops[_schema_id(schema)] = function
570
+ return function
571
+
572
+ return decorator
573
+
574
+
575
+ ort_mixed_evaluator = ORTMixedEvaluator()
576
+
577
+
578
+ @ort_mixed_evaluator.register(opset=onnx_opset.opset18)
579
+ def SequenceMap(inputs: Sequence[Any], attributes: Mapping[str, Any]):
580
+ """Evaluates a SequenceMap op."""
581
+ fun = attributes["body"]
582
+
583
+ def get_input_of(input_index, iter_num):
584
+ input = inputs[input_index]
585
+ if isinstance(input, list):
586
+ return input[iter_num]
587
+ return input
588
+
589
+ def get_input(iter_num):
590
+ return [get_input_of(input_index, iter_num) for input_index in range(len(inputs))]
591
+
592
+ return [fun(*(get_input(i))) for i in range(len(inputs[0]))]
593
+
594
+
595
+ # Used to control the default evaluator instance. A simple approach for now.
596
+
597
+ _default_evaluator: Evaluator = ort_evaluator
598
+
599
+
600
+ def default() -> Evaluator:
601
+ """Returns the default Evaluator default."""
602
+ return _default_evaluator
603
+
604
+
605
+ def set_default(new_default: Evaluator) -> None:
606
+ """Sets the current Evaluator default."""
607
+ global _default_evaluator # pylint: disable=global-statement
608
+ _default_evaluator = new_default
609
+
610
+
611
+ @contextlib.contextmanager
612
+ def default_as(temp_default: Evaluator):
613
+ """Context manager that temporarily switches the default evaluator."""
614
+ old_default = _default_evaluator
615
+ set_default(temp_default)
616
+ try:
617
+ yield
618
+ finally:
619
+ set_default(old_default)
pythonProject/.venv/Lib/site-packages/onnxscript/irbuilder.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ # ruff: noqa: TID251
4
+ from __future__ import annotations
5
+
6
+ import dataclasses
7
+ import io
8
+ import logging
9
+ import warnings
10
+ from typing import Any, Optional, Protocol, Sequence, Union
11
+
12
+ import onnx
13
+ from onnx import ValueInfoProto, helper
14
+ from onnx.defs import onnx_opset_version
15
+
16
+ import onnxscript
17
+ from onnxscript import type_annotation as ta
18
+ from onnxscript import values
19
+ from onnxscript._internal import version_utils
20
+ from onnxscript.onnx_types import ONNXType
21
+ from onnxscript.sourceinfo import SourceInfo
22
+
23
+ # A simple IR (Function, Stmt, Attr, Var):
24
+
25
+ logger = logging.getLogger("onnxscript")
26
+
27
+
28
+ def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str):
29
+ """Formats a sequence of objects into a string."""
30
+ return prefix + sep.join([formatter(x) for x in seq]) + suffix
31
+
32
+
33
+ def select_ir_version(version: int, domain: str = "") -> int:
34
+ """Selects a suitable ONNX ir_version for a given opset version."""
35
+ if domain == "":
36
+ domain = "ai.onnx"
37
+ if (domain, version) not in helper.OP_SET_ID_VERSION_MAP:
38
+ return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx")
39
+ return helper.OP_SET_ID_VERSION_MAP[domain, version]
40
+
41
+
42
+ class IRType:
43
+ def __init__(self):
44
+ self.onnx_type = onnx.TypeProto()
45
+
46
+ def to_type_proto(self):
47
+ return self.onnx_type
48
+
49
+ def __repr__(self) -> str:
50
+ return "IRType()"
51
+
52
+
53
+ class IRTensorType(IRType):
54
+ def __init__(self, elem_type: onnx.TensorProto.DataType) -> None:
55
+ super().__init__()
56
+ self.onnx_type.tensor_type.elem_type = elem_type
57
+
58
+ def __repr__(self) -> str:
59
+ return f"IRTensorType({self.onnx_type.tensor_type.elem_type})"
60
+
61
+
62
+ class IRTypeLike(Protocol):
63
+ def to_type_proto(self) -> onnx.TypeProto:
64
+ """Converts IR type representation to onnx.TypeProto"""
65
+
66
+
67
+ class IRVar:
68
+ """A variable (representing a formal parameter)."""
69
+
70
+ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
71
+ if not isinstance(varname, str):
72
+ raise TypeError(f"varname must be a string not {type(varname)!r}.")
73
+ self.name = varname
74
+ self.info = sourceinfo
75
+ self.typeinfo = typeinfo
76
+
77
+ def __str__(self):
78
+ return self.name
79
+
80
+ def __repr__(self):
81
+ return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})"
82
+
83
+ def typed_str(self):
84
+ return f"{self.name} : {self.typeinfo}"
85
+
86
+ def to_value_info(self, use_default_type: bool = True):
87
+ """Converts the content of this class into :class:`onnx.ValueInfoProto`.
88
+
89
+ Args:
90
+ use_default_type: if True, use a default type if an explicit type
91
+ is not known. Otherwise, returns a ValueInfoProto without type.
92
+
93
+ Returns:
94
+ an instance of :class:`onnx.ValueInfoProto`
95
+ """
96
+ if self.name is None:
97
+ raise ValueError(self.info.msg("name cannot be None."))
98
+ value_info_proto = ValueInfoProto()
99
+ value_info_proto.name = self.name
100
+ if self.typeinfo is not None:
101
+ value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto())
102
+ elif use_default_type:
103
+ value_info_proto.type.CopyFrom(IRType().to_type_proto())
104
+ return value_info_proto
105
+
106
+
107
+ def _opt_var_to_str(x):
108
+ return "" if x is None else str(x)
109
+
110
+
111
+ class IRAttributeValue:
112
+ """An attribute value (representing an actual parameter).
113
+
114
+ Attributes:
115
+ name: The name of the attribute.
116
+ type: The type of the attribute.
117
+ attr_proto: The attribute proto.
118
+ """
119
+
120
+ def __init__(self, attrproto: onnx.AttributeProto) -> None:
121
+ self.attr_proto = attrproto
122
+
123
+ def __str__(self):
124
+ if self.attr_proto.HasField("ref_attr_name"):
125
+ return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}"
126
+ # self.name + " = " + self.value
127
+ return helper.printable_attribute(self.attr_proto)
128
+
129
+ @property
130
+ def name(self) -> str:
131
+ return self.attr_proto.name
132
+
133
+ @property
134
+ def type(self) -> onnx.AttributeProto.AttributeType:
135
+ return self.attr_proto.type
136
+
137
+
138
+ @dataclasses.dataclass(frozen=True)
139
+ class IRAttributeParameter:
140
+ """An attribute parameter (representing a formal parameter).
141
+
142
+ It may or may not carry a default value.
143
+
144
+ Attributes:
145
+ name: The name of the attribute.
146
+ type: The type of the attribute.
147
+ default_value: The default value of the attribute.
148
+ has_default: Whether the attribute has a default value.
149
+ attr_proto: The attribute proto.
150
+ """
151
+
152
+ name: str
153
+ type: onnx.AttributeProto.AttributeType
154
+ default_value: str | int | float | None = None
155
+
156
+ # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.
157
+
158
+ def __str__(self):
159
+ if self.has_default:
160
+ return helper.printable_attribute(self.attr_proto)
161
+ # TODO(justinchuby): Include a readable type name.
162
+ return self.name
163
+
164
+ @property
165
+ def has_default(self):
166
+ return self.default_value is not None
167
+
168
+ @property
169
+ def attr_proto(self) -> onnx.AttributeProto:
170
+ if not self.has_default:
171
+ raise ValueError(
172
+ "Attribute has no default value. Only attributes with default "
173
+ "values can be converted to AttributeProto."
174
+ )
175
+ if version_utils.onnx_older_than("1.15"):
176
+ # TODO(after 1.14 is deprecated): Remove this branch.
177
+ # Argument 'attr_type' was added after version 1.14.
178
+ return helper.make_attribute(self.name, self.default_value)
179
+ # pylint: disable=unexpected-keyword-arg
180
+ return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg]
181
+ # pylint: enable=unexpected-keyword-arg
182
+
183
+
184
+ class IRStmt:
185
+ def __init__(
186
+ self,
187
+ result: Sequence[str],
188
+ callee: values.Op,
189
+ args: Sequence[Optional[str]],
190
+ attrs: Sequence[IRAttributeValue],
191
+ sub_functions=None,
192
+ ) -> None:
193
+ if not isinstance(callee, values.Op):
194
+ raise TypeError(f"Unexpected type {type(callee)} for callee.")
195
+ self.result = result
196
+ self.callee = callee
197
+ self.args = args
198
+ self.attrs = attrs
199
+ self.functions = sub_functions or {}
200
+
201
+ def __str__(self):
202
+ if isinstance(self.result, str):
203
+ logger.debug("unexpected str type for self.result where type(self)=%r", type(self))
204
+ lhs = ", ".join(self.result)
205
+ attrs = ""
206
+ if self.attrs:
207
+ attrs = _format(self.attrs, "<", ", ", ">")
208
+
209
+ args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
210
+ domain = self.callee.opset.domain
211
+ opname = self.callee.name
212
+ callee = f"{domain}.{opname}" if (domain != "") else opname
213
+ return f"{lhs} = {callee} {attrs}{args}"
214
+
215
+ def debug_print(self):
216
+ if logger.isEnabledFor(logging.DEBUG):
217
+ logger.debug("%s: %s", type(self), str(self))
218
+
219
+ def to_node_proto(self, node_name: str) -> onnx.NodeProto:
220
+ n = helper.make_node(
221
+ self.callee.name,
222
+ [_opt_var_to_str(x) for x in self.args],
223
+ [str(x) for x in self.result],
224
+ domain=self.callee.opset.domain,
225
+ name=node_name,
226
+ )
227
+ for a in self.attrs:
228
+ n.attribute.append(a.attr_proto)
229
+ return n
230
+
231
+ @property
232
+ def output_names(self) -> Sequence[str]:
233
+ """Returns the list of variables assigned to by this statement."""
234
+ return [str(x) for x in self.result]
235
+
236
+
237
+ class IRFunction:
238
+ """Represents a function in the IR."""
239
+
240
+ def __init__(self, name: str, domain: str = "") -> None:
241
+ self.domain = domain
242
+ self.name = name
243
+ self.outputs: list[IRVar] = []
244
+ self.stmts: list[IRStmt] = []
245
+ self.called_functions: dict[str, onnx.FunctionProto] = {}
246
+ self.docstring: str = ""
247
+ # a dictionary of nested function-definitions
248
+ self.nested_functions: dict[str, IRFunction] = {}
249
+ self.outer_scope_variables: dict[Any, Any] = {}
250
+ self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []
251
+
252
+ @property
253
+ def assigned_names(self) -> Sequence[str]:
254
+ """Returns the list of variables assigned to by this function."""
255
+ return [v for stmt in self.stmts for v in stmt.output_names]
256
+
257
+ @property
258
+ def inputs(self) -> Sequence[IRVar]:
259
+ return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]
260
+
261
+ @property
262
+ def attrs(self) -> Sequence[IRAttributeParameter]:
263
+ return [
264
+ attr
265
+ for attr in self.ordered_inputs_and_attrs
266
+ if isinstance(attr, IRAttributeParameter)
267
+ ]
268
+
269
+ def __str__(self):
270
+ attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
271
+ inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
272
+ outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
273
+ stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
274
+ return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"
275
+
276
+ def append_docstring(self, docstring):
277
+ self.docstring += docstring
278
+
279
+ def append_stmt(self, stmt: IRStmt) -> None:
280
+ self.stmts.append(stmt)
281
+
282
+ def append_input(self, name: IRVar) -> None:
283
+ self.ordered_inputs_and_attrs.append(name)
284
+
285
+ def append_output(self, name: IRVar) -> None:
286
+ self.outputs.append(name)
287
+
288
+ def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
289
+ self.ordered_inputs_and_attrs.append(attr)
290
+
291
+ def debug_print(self):
292
+ if logger.isEnabledFor(logging.DEBUG):
293
+ st = io.StringIO()
294
+ for s in self.stmts:
295
+ for attr in s.attrs:
296
+ if attr.attr_proto.HasField("g"):
297
+ st.write(helper.printable_graph(attr.attr_proto.g))
298
+ st.write("\n")
299
+
300
+ def add_called_function(self, fun: values.OnnxFunction) -> None:
301
+ for name, fct in fun.function_ir.called_functions.items():
302
+ if name in self.called_functions:
303
+ continue
304
+ self.called_functions[name] = fct
305
+ if fun.name in self.called_functions:
306
+ # Already added.
307
+ return
308
+ try:
309
+ proto = fun.to_function_proto()
310
+ except (TypeError, AttributeError) as e:
311
+ raise TypeError(f"Issue with type f{type(fun)}.") from e
312
+ self.called_functions[fun.name] = proto
313
+
314
+ def add_nested_function(self, fun: IRFunction) -> None:
315
+ self.nested_functions[fun.name] = fun
316
+
317
+ def to_model_proto(
318
+ self,
319
+ functions=None,
320
+ io_types: Optional[ONNXType] = None,
321
+ input_types: Optional[Sequence[ONNXType]] = None,
322
+ output_types: Optional[Sequence[ONNXType]] = None,
323
+ value_infos: dict[str, ONNXType] | None = None,
324
+ **kwargs,
325
+ ) -> onnx.ModelProto:
326
+ """Converts this instance into a `onnx.ModelProto`.
327
+
328
+ Args:
329
+ functions: A list of functions to include in the model.
330
+ By default, all functions called at least once are included.
331
+ io_types: When specified, all the inputs/outputs of the model
332
+ are set to be of this type.
333
+ input_types: When specified, all the inputs of the model
334
+ are set to be of the corresponding type in this list.
335
+ output_types: When specified, all the outputs of the model
336
+ are set to be of the corresponding type in this list.
337
+ value_infos: A dictionary mapping intermediate variable names to ONNX types.
338
+ Used to set value_info for intermediate variables.
339
+ kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
340
+
341
+ Returns:
342
+ An instance of :class:`onnx.ModelProto`.
343
+ """
344
+ value_infos = (
345
+ [
346
+ onnx.helper.make_value_info(name, type.to_type_proto())
347
+ for name, type in value_infos.items()
348
+ ]
349
+ if value_infos
350
+ else None
351
+ )
352
+ graph, sub_functions = self.to_graph_and_functions(
353
+ use_default_type=False, value_infos=value_infos
354
+ )
355
+ if io_types is not None:
356
+ for input in graph.input:
357
+ if not input.HasField("type"):
358
+ input.type.CopyFrom(io_types.to_type_proto())
359
+ for output in graph.output:
360
+ if not output.HasField("type"):
361
+ output.type.CopyFrom(io_types.to_type_proto())
362
+ if input_types is not None:
363
+ for input, type in zip(graph.input, input_types):
364
+ input.type.CopyFrom(type.to_type_proto())
365
+ if output_types is not None:
366
+ for output, type in zip(graph.output, output_types):
367
+ output.type.CopyFrom(type.to_type_proto())
368
+ if functions is None:
369
+ functions = sub_functions.values()
370
+ else:
371
+
372
+ def to_proto(f):
373
+ if isinstance(f, onnx.FunctionProto):
374
+ return f
375
+ if isinstance(f, onnxscript.OnnxFunction):
376
+ return f.to_function_proto()
377
+ raise TypeError("Expected a value of type FunctionProto of OnnxFunction")
378
+
379
+ functions = [to_proto(f) for f in functions]
380
+
381
+ opsets = {}
382
+ for n in self.stmts:
383
+ if n.callee.opset.domain not in opsets:
384
+ opsets[n.callee.opset.domain] = n.callee.opset.version
385
+
386
+ for proto in functions:
387
+ if proto.domain not in opsets:
388
+ opsets[proto.domain] = 1
389
+ # TODO(rama): Handle conflicts with appropriate error/warning message.
390
+ for opset in proto.opset_import:
391
+ if opset.domain not in opsets:
392
+ opsets[opset.domain] = opset.version
393
+
394
+ if "" not in opsets:
395
+ # No operator is using the standard opset.
396
+ # A default value is given.
397
+ opsets[""] = onnx_opset_version()
398
+
399
+ if "ir_version" not in kwargs:
400
+ kwargs["ir_version"] = select_ir_version(opsets[""])
401
+ opset_imports = [
402
+ onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
403
+ ]
404
+
405
+ return helper.make_model(
406
+ graph, opset_imports=opset_imports, functions=functions, **kwargs
407
+ )
408
+
409
+ def to_graph_and_functions(
410
+ self,
411
+ use_default_type: bool = True,
412
+ value_infos: Sequence[ValueInfoProto] | None = None,
413
+ ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
414
+ """Converts this instance into a `onnx.GraphProto` and a map from
415
+ function-name to `onnx.FunctionProto`.
416
+
417
+ Args:
418
+ use_default_type: if True, the function uses a default type
419
+ for inputs and outputs that do not have a type
420
+ value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
421
+ to the graph.
422
+
423
+ Returns:
424
+ a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto`
425
+ """
426
+ called_functions: dict[str, onnx.FunctionProto] = {}
427
+ for s in self.stmts:
428
+ called_functions.update(s.functions)
429
+ called_functions.update(self.called_functions)
430
+ graph = helper.make_graph(
431
+ [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)],
432
+ self.name,
433
+ [x.to_value_info(use_default_type) for x in self.inputs],
434
+ [y.to_value_info(use_default_type) for y in self.outputs],
435
+ value_info=value_infos,
436
+ )
437
+ return graph, called_functions
438
+
439
+ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
440
+ """Converts this instance into a `onnx.GraphProto`.
441
+
442
+ Args:
443
+ use_default_type: if True, the function uses a default type
444
+ for inputs and outputs that do not have a type
445
+
446
+ Returns:
447
+ an instance of :class:`onnx.GraphProto`
448
+ """
449
+ graph, _ = self.to_graph_and_functions(use_default_type=use_default_type)
450
+ return graph
451
+
452
+ def get_opset_import(self) -> dict[str, int]:
453
+ func_opset_imports = {}
454
+ for s in self.stmts:
455
+ if s.callee.opset.domain not in func_opset_imports:
456
+ func_opset_imports[s.callee.opset.domain] = s.callee.opset.version
457
+ elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version:
458
+ warnings.warn(
459
+ f"There is a version conflict in domain: {s.callee.opset.domain!r}, "
460
+ f"with {self.name!r}.",
461
+ category=UserWarning,
462
+ stacklevel=1,
463
+ )
464
+ return func_opset_imports
465
+
466
+ def to_function_proto(self) -> onnx.FunctionProto:
467
+ """Converts this instance into a `onnx.FunctionProto`.
468
+
469
+ Note: Default values for attributes are an experimental feature in ONNX.
470
+ Conversion ignores default values for attributes if the ONNX version installed
471
+ doesn't support it.
472
+ """
473
+ opsets = self.get_opset_import()
474
+ nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)]
475
+ for n in nodes:
476
+ if n.domain not in opsets:
477
+ opsets[n.domain] = 1 # TODO: how to get n.version?
478
+ opset_imports = [
479
+ onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
480
+ ]
481
+
482
+ attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
483
+
484
+ f = helper.make_function(
485
+ self.domain,
486
+ self.name,
487
+ inputs=[x.name for x in self.inputs],
488
+ outputs=[y.name for y in self.outputs],
489
+ nodes=nodes,
490
+ opset_imports=opset_imports, # TODO
491
+ attributes=attribute_names,
492
+ doc_string=self.docstring,
493
+ )
494
+ # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead
495
+ if hasattr(f, "attribute_proto"):
496
+ f.attribute_proto.extend(
497
+ [attr.attr_proto for attr in self.attrs if attr.has_default]
498
+ )
499
+ return f
500
+
501
+
502
+ # IRBuilder: abstracts out details of the IR in the python-to-IR converter
503
+
504
+
505
+ class IRBuilder:
506
+ def __init__(self):
507
+ self.functions = {}
508
+
509
+ def new_function(self, name: str, domain: str = "", register: bool = False) -> IRFunction:
510
+ if register and (domain, name) in self.functions:
511
+ raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
512
+ function = IRFunction(name, domain)
513
+ if register:
514
+ self.functions[domain, name] = function
515
+ return function
516
+
517
+ def add_docstring(self, fn: IRFunction, docstring: str):
518
+ fn.append_docstring(docstring)
519
+
520
+ def add_stmt(
521
+ self,
522
+ fn: IRFunction,
523
+ results: Sequence[str],
524
+ callee: values.Op,
525
+ args: Sequence[Optional[str]],
526
+ attrs: Sequence[IRAttributeValue],
527
+ sub_functions=None,
528
+ ) -> None:
529
+ stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
530
+ fn.append_stmt(stmt)
531
+
532
+ def add_input(
533
+ self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
534
+ ) -> None:
535
+ var = IRVar(varname, type, info)
536
+ fn.append_input(var)
537
+
538
+ def add_attr_parameter(
539
+ self,
540
+ fn: IRFunction,
541
+ varname: str,
542
+ attribute_type: onnx.AttributeProto.AttributeType,
543
+ default_value: int | float | str | None,
544
+ ) -> None:
545
+ fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
546
+
547
+ def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
548
+ var = IRVar(varname, typeinfo, sourceinfo)
549
+ fn.append_output(var)
550
+
551
+ def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
552
+ return IRAttributeValue(attrproto)
553
+
554
+ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
555
+ proto = onnx.AttributeProto()
556
+ proto.name = attrname
557
+ proto.ref_attr_name = refname
558
+ attr_type = ta.pytype_to_attrtype(pytype)
559
+ assert attr_type is not None
560
+ proto.type = attr_type
561
+ return IRAttributeValue(proto)
pythonProject/.venv/Lib/site-packages/onnxscript/main.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ # pylint disable: protected-access
4
+ from __future__ import annotations
5
+
6
+ import ast
7
+ import inspect
8
+ import sys
9
+ from typing import Any, Callable, Optional, Sequence, TypeVar
10
+
11
+ from typing_extensions import ParamSpec
12
+
13
+ import onnxscript
14
+ from onnxscript import converter, ir, irbuilder, values
15
+ from onnxscript._internal import ast_utils
16
+
17
+ _R = TypeVar("_R")
18
+ _P = ParamSpec("_P")
19
+
20
+
21
+ def script_check(
22
+ f: ast.FunctionDef,
23
+ opset: values.Opset,
24
+ global_names: dict[str, Any],
25
+ source: str,
26
+ default_opset: Optional[values.Opset] = None,
27
+ ) -> irbuilder.IRFunction:
28
+ """Check that a function falls into the ONNXScript subset of Python."""
29
+ # See if conversion succeeds.
30
+ # TODO: cleanup Converter interface/API, separating checker from
31
+ # converter
32
+ convert = converter.Converter(
33
+ opset=opset,
34
+ global_names=global_names,
35
+ source=source,
36
+ default_opset=default_opset,
37
+ )
38
+ return convert.translate_function_def(f)
39
+
40
+
41
+ def script(
42
+ opset: Optional[values.Opset] = None,
43
+ default_opset: Optional[values.Opset] = None,
44
+ **kwargs: Any,
45
+ ) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]:
46
+ """Main decorator. Declares a function as an onnx function.
47
+
48
+ Args:
49
+ opset: Opset the function belongs to (see :ref:`l-api-opsets`).
50
+ default_opset: Opset to use for operators not in the function's opset.
51
+ kwargs: Additional keyword arguments.
52
+
53
+ Returns:
54
+ an instance of :class:`onnxscript.values.OnnxFunction`
55
+
56
+ Example:
57
+ ::
58
+
59
+ @script()
60
+ def log2(x):
61
+ one = op.Constant(value=make_tensor('one', TensorProto.FLOAT, [1], [1]))
62
+ return op.Div(op.Log(x), op.CastLike(op.Log(cst), x))
63
+
64
+ Or:
65
+
66
+ ::
67
+
68
+ from onnxscript.onnx_opset import opset16
69
+
70
+ @script(opset16)
71
+ def log2(x):
72
+ one = op.Constant(value=make_tensor('one', TensorProto.FLOAT, [1], [1]))
73
+ return op.Div(op.Log(x), op.CastLike(op.Log(cst), x))
74
+ """
75
+ opset = opset or values.Opset("this", 1)
76
+ if not isinstance(opset, values.Opset):
77
+ raise TypeError(
78
+ "Script parameter must be an opset. Did you use @script instead of @script()?"
79
+ )
80
+
81
+ def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]:
82
+ if not inspect.isfunction(f):
83
+ raise TypeError("The ONNXScript decorator should be applied to functions only.")
84
+
85
+ src, f_ast = ast_utils.get_src_and_ast(f)
86
+ # The script should be compiled using the globals/locals at the definition site.
87
+ # This allows the script to reference names defined outside the script,
88
+ # which is used for a few different purposes.
89
+ # The following is an approximate solution that works for normal use.
90
+ module = inspect.getmodule(f)
91
+ closure = inspect.getclosurevars(f)
92
+ env = module.__dict__.copy()
93
+ env.update(closure.nonlocals)
94
+ result = script_check(f_ast, opset, env, src, default_opset=default_opset)
95
+ # TODO: add transformations.
96
+ return onnxscript.OnnxFunction(opset, f, result, src, kwargs)
97
+
98
+ return transform
99
+
100
+
101
+ def graph() -> Callable[[Callable], values.OnnxClosure]:
102
+ """A parametric decorator used to annotate nested-functions that are used
103
+ as graph-attributes.
104
+
105
+ Returns:
106
+ A decorator that returns its input function, but attaches a graph_proto
107
+ attribute representing the input function. The translation is not
108
+ done at this time, but previously when the outer-level function
109
+ was translated to an OnnxFunction. The decorator just looks up
110
+ and retrieves the GraphProto representation previously generated.
111
+
112
+ Example:
113
+ ::
114
+
115
+ @script()
116
+ def cumulative_sum(X: INT64['N']):
117
+
118
+ # Translation of cumulative_sum by @script will also translate Sum
119
+ # into a GraphProto, which will be stored in the OnnxFunction generated
120
+ # for cumulative_sum. At run-time (in eager-mode), the @graph decorator
121
+ # retrieves the pre-computed GraphProto and attaches it to the Sum function.
122
+ @graph()
123
+ def Sum(sum_in, next):
124
+ sum_out = sum_in + next
125
+ scan_out = op.Identity(sum_out)
126
+ return sum_out, scan_out
127
+ zero = op.Constant(value_int=0)
128
+ # The call to higher-order operator Scan below uses the above function
129
+ # Sum as a graph-attribute.
130
+ all_sum, result = op.Scan (zero, X, body=Sum, num_scan_inputs=1)
131
+ return result
132
+
133
+ """
134
+ # This is a bit fragile. We want to get the ONNXFunction object representing
135
+ # the outer-scope ONNXScript function from the execution stack. The caller of
136
+ # @graph is the original script function (cumulative_sum in the above example),
137
+ # and the caller of that function is the wrapper function/method in the
138
+ # corresponding OnnxFunction object.
139
+ # Currently, there is no support for eager-mode execution of nested functions,
140
+ # so we don't need to handle doubly nested functions (e.g., a function defined
141
+ # inside Sum in the above example).
142
+
143
+ function_frame = sys._getframe(1) # pylint: disable=protected-access
144
+ wrapper_frame = sys._getframe(3) # pylint: disable=protected-access
145
+ onnx_function = wrapper_frame.f_locals["self"]
146
+ nested_functions = onnx_function.function_ir.nested_functions
147
+
148
+ def transform(f: Callable) -> values.OnnxClosure:
149
+ return values.OnnxClosure(nested_functions[f.__name__], function_frame, f)
150
+
151
+ return transform
152
+
153
+
154
+ def is_converted_fun(f: Any) -> bool:
155
+ """Return True if f is a function converted by onnxscript decorator."""
156
+ return isinstance(f, onnxscript.OnnxFunction)
157
+
158
+
159
+ def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) -> None:
160
+ # Since we don't yet have LibProto defined, we use a ModelProto as a temporary
161
+ # container for the list of functions exported as a library, with an empty graph
162
+ # and dummy opset_imports.
163
+
164
+ # TODO(justinchuby): This function is not well supported. We should consider removing it
165
+ model = ir.Model(
166
+ ir.Graph(
167
+ inputs=[],
168
+ outputs=[],
169
+ nodes=[],
170
+ opset_imports={"": 15},
171
+ ),
172
+ functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions],
173
+ ir_version=10,
174
+ producer_name="p2o",
175
+ )
176
+ ir.save(model, filename)
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_types.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from __future__ import annotations
5
+
6
+ import abc
7
+ from typing import ClassVar, Optional, Tuple, Union
8
+
9
+ import onnx
10
+ import onnx_ir as ir
11
+
12
+ _DType = ir.DataType
13
+ _DimType = Union[int, str, type(None)]
14
+ _ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
15
+
16
+ _tensor_type_shape_cache: dict[_DType, TensorType] = {}
17
+ tensor_type_registry: dict[_DType, TensorType] = {}
18
+
19
+
20
+ def _check_dim(dim):
21
+ if not isinstance(dim, (int, str, type(None))):
22
+ raise TypeError(f"Invalid dimension {dim}")
23
+
24
+
25
+ def _check_shape(shape):
26
+ if isinstance(shape, tuple):
27
+ for dim in shape:
28
+ _check_dim(dim)
29
+ elif shape != Ellipsis:
30
+ _check_dim(shape)
31
+
32
+
33
+ class TensorType(abc.ABC):
34
+ """ONNX Script representation of a tensor type supporting shape annotations.
35
+
36
+ A scalar-tensor of rank 0:
37
+ ::
38
+
39
+ tensor: FLOAT
40
+
41
+ A tensor of unknown rank:
42
+ ::
43
+
44
+ tensor: FLOAT[...]
45
+
46
+ A tensor of rank 2 of unknown dimensions, with symbolic names:
47
+ ::
48
+
49
+ tensor: FLOAT['M', 'N']
50
+
51
+ A tensor of rank 2 of known dimensions:
52
+ ::
53
+
54
+ tensor: FLOAT[128, 1024]
55
+ """
56
+
57
+ dtype: ClassVar[_DType]
58
+ shape: ClassVar[Optional[_ShapeType]]
59
+
60
+ def __new__(cls):
61
+ raise NotImplementedError("TensorTypes cannot be instantiated")
62
+
63
+ def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
64
+ cls.dtype = dtype
65
+ cls.shape = shape
66
+ if shape is None:
67
+ existing_cls = tensor_type_registry.get(dtype)
68
+ if existing_cls is not None:
69
+ raise ValueError(
70
+ f"Invalid usage: subclass {existing_cls!r} "
71
+ f"already defined for dtype={dtype}"
72
+ )
73
+ tensor_type_registry[dtype] = cls
74
+ else:
75
+ _check_shape(shape)
76
+
77
+ def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
78
+ if cls.shape is not None:
79
+ raise ValueError("Invalid usage: shape already specified.")
80
+ if shape is None:
81
+ # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
82
+ shape = (None,)
83
+ key = (cls.dtype, shape)
84
+ shaped_type = _tensor_type_shape_cache.get(key)
85
+ if shaped_type is None:
86
+ shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape)
87
+ _tensor_type_shape_cache[key] = shaped_type
88
+ return shaped_type
89
+
90
+ @classmethod
91
+ def to_type_proto(cls) -> onnx.TypeProto:
92
+ if cls.shape is None:
93
+ shape = () # "FLOAT" is treated as a scalar
94
+ elif cls.shape is Ellipsis:
95
+ shape = None # "FLOAT[...]" is a tensor of unknown rank
96
+ elif isinstance(cls.shape, tuple):
97
+ shape = cls.shape # example: "FLOAT[10,20]"
98
+ else:
99
+ shape = [cls.shape] # example: "FLOAT[10]"
100
+ return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251
101
+
102
+ @classmethod
103
+ def to_string(cls) -> str:
104
+ return f"tensor({cls.__name__.lower()})"
105
+
106
+
107
+ class FLOAT(TensorType, dtype=ir.DataType.FLOAT):
108
+ pass
109
+
110
+
111
+ class UINT8(TensorType, dtype=ir.DataType.UINT8):
112
+ pass
113
+
114
+
115
+ class INT8(TensorType, dtype=ir.DataType.INT8):
116
+ pass
117
+
118
+
119
+ class UINT16(TensorType, dtype=ir.DataType.UINT16):
120
+ pass
121
+
122
+
123
+ class INT16(TensorType, dtype=ir.DataType.INT16):
124
+ pass
125
+
126
+
127
+ class INT32(TensorType, dtype=ir.DataType.INT32):
128
+ pass
129
+
130
+
131
+ class INT64(TensorType, dtype=ir.DataType.INT64):
132
+ pass
133
+
134
+
135
+ class STRING(TensorType, dtype=ir.DataType.STRING):
136
+ pass
137
+
138
+
139
+ class BOOL(TensorType, dtype=ir.DataType.BOOL):
140
+ pass
141
+
142
+
143
+ class FLOAT16(TensorType, dtype=ir.DataType.FLOAT16):
144
+ pass
145
+
146
+
147
+ class DOUBLE(TensorType, dtype=ir.DataType.DOUBLE):
148
+ pass
149
+
150
+
151
+ class UINT32(TensorType, dtype=ir.DataType.UINT32):
152
+ pass
153
+
154
+
155
+ class UINT64(TensorType, dtype=ir.DataType.UINT64):
156
+ pass
157
+
158
+
159
+ class COMPLEX64(TensorType, dtype=ir.DataType.COMPLEX64):
160
+ pass
161
+
162
+
163
+ class COMPLEX128(TensorType, dtype=ir.DataType.COMPLEX128):
164
+ pass
165
+
166
+
167
+ class BFLOAT16(TensorType, dtype=ir.DataType.BFLOAT16):
168
+ pass
169
+
170
+
171
+ class FLOAT8E4M3FN(TensorType, dtype=ir.DataType.FLOAT8E4M3FN):
172
+ pass
173
+
174
+
175
+ class FLOAT8E4M3FNUZ(TensorType, dtype=ir.DataType.FLOAT8E4M3FNUZ):
176
+ pass
177
+
178
+
179
+ class FLOAT8E5M2(TensorType, dtype=ir.DataType.FLOAT8E5M2):
180
+ pass
181
+
182
+
183
+ class FLOAT8E5M2FNUZ(TensorType, dtype=ir.DataType.FLOAT8E5M2FNUZ):
184
+ pass
185
+
186
+
187
+ class INT4(TensorType, dtype=ir.DataType.INT4):
188
+ pass
189
+
190
+
191
+ class UINT4(TensorType, dtype=ir.DataType.UINT4):
192
+ pass
193
+
194
+
195
+ class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
196
+ pass
197
+
198
+
199
+ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
200
+ """Converts an onnx type into the string representation of the type in *onnxscript*.
201
+
202
+ Args:
203
+ onnx_type: an instance of onnx TypeProto
204
+ reversible: if True, the conversion produces only types that are
205
+ recognized by the onnxscript converter.
206
+
207
+ Returns:
208
+ The string representation of the type in onnxscript
209
+
210
+ Raises:
211
+ ...
212
+ """
213
+ if onnx_type.HasField("tensor_type"):
214
+ elem_type = onnx_type.tensor_type.elem_type
215
+ name = onnx.TensorProto.DataType.Name(elem_type)
216
+ if onnx_type.tensor_type.HasField("shape"):
217
+ shape = []
218
+ for d in onnx_type.tensor_type.shape.dim:
219
+ if d.HasField("dim_value"):
220
+ shape.append(str(d.dim_value))
221
+ elif d.HasField("dim_param"):
222
+ shape.append(repr(d.dim_param))
223
+ else:
224
+ shape.append("None")
225
+ if not shape:
226
+ return name
227
+ return f"{name}[{','.join(shape)}]"
228
+ return f"{name}[...]"
229
+ if not reversible:
230
+ if onnx_type.HasField("sequence_type"):
231
+ elem_type = onnx_type.sequence_type.elem_type
232
+ return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]"
233
+ raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")
234
+
235
+
236
+ # Currently, only tensor types are supported. Need to expand support for other ONNX types.
237
+ ONNXType = TensorType
pythonProject/.venv/Lib/site-packages/onnxscript/py.typed ADDED
@@ -0,0 +1 @@
 
 
1
+ # Marker file for PEP-561 (inline types)
pythonProject/.venv/Lib/site-packages/onnxscript/sourceinfo.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # -------------------------------------------------------------------------
5
+
6
+ """Source code information used for diagnostic messages."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import ast
11
+ from typing import Callable, Optional
12
+
13
+
14
+ class SourceInfo:
15
+ """Information about onnxscript source fragment, used for diagnostic messages."""
16
+
17
+ def __init__(
18
+ self,
19
+ ast_node: ast.AST,
20
+ code: Optional[str] = None,
21
+ function_name: Optional[str] = None,
22
+ ):
23
+ self.ast_node = ast_node
24
+ self.code = code
25
+ self.function_name = function_name
26
+
27
+ @property
28
+ def lineno(self):
29
+ return self.ast_node.lineno
30
+
31
+ def msg(self, error_message: str) -> str:
32
+ lineno = self.lineno
33
+ if self.function_name:
34
+ source_loc = f"Function '{self.function_name}', line {lineno}"
35
+ else:
36
+ source_loc = f"Line {lineno}"
37
+
38
+ if self.code:
39
+ lines = self.code.split("\n")
40
+ line = lines[lineno - 1]
41
+ marker_prefix = " " * (self.ast_node.col_offset)
42
+ source_line = f"{line}\n{marker_prefix}^\n"
43
+ else:
44
+ source_line = ""
45
+
46
+ return f"ERROR: {error_message}\nat: {source_loc}\n{source_line}"
47
+
48
+ def __str__(self) -> str:
49
+ raise ValueError("Cannot happen!")
50
+
51
+
52
+ Formatter = Callable[[ast.AST, str], str]
53
+
54
+
55
+ def formatter(source_code: Optional[str]) -> Formatter:
56
+ def format(node: ast.AST, message: str) -> str:
57
+ return SourceInfo(node, source_code).msg(message)
58
+
59
+ return format
pythonProject/.venv/Lib/site-packages/onnxscript/tensor.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Any, Optional
7
+
8
+ import numpy as np
9
+
10
+ from onnxscript import ir, onnx_opset
11
+ from onnxscript._internal import autocast
12
+
13
+
14
+ class Tensor:
15
+ """An implementation of ONNX Tensors, based on a wrapper around numpy arrays.
16
+ Serves to define overloaded ops with an ONNX/ONNXScript semantics.
17
+ """
18
+
19
+ def __init__(self, nparray: Optional[np.ndarray], opset=None):
20
+ if nparray is not None and not isinstance(nparray, np.ndarray):
21
+ raise TypeError(
22
+ f"Unexpected type {type(nparray)}. It must be a numpy array or None."
23
+ )
24
+
25
+ self._nparray = nparray
26
+ # FIXME(justinhuby): Create a better way to determine the opset version
27
+ self._opset: Any = opset or onnx_opset.opset18
28
+
29
+ @property
30
+ def value(self) -> np.ndarray:
31
+ if self._nparray is None:
32
+ raise ValueError("Tensor does not have a value.")
33
+ return self._nparray
34
+
35
+ @property
36
+ def rank(self) -> int:
37
+ return len(self.value.shape)
38
+
39
+ @property
40
+ def is_scalar(self) -> bool:
41
+ return self.rank == 0
42
+
43
+ @property
44
+ def shape(self) -> tuple[int, ...]:
45
+ return self.value.shape
46
+
47
+ @property
48
+ def dtype(self) -> np.dtype:
49
+ return self.value.dtype
50
+
51
+ @property
52
+ def onnx_dtype(self) -> int:
53
+ return ir.DataType.from_numpy(self.dtype)
54
+
55
+ def __repr__(self) -> str:
56
+ return f"{self.__class__.__name__}({self.value!r})"
57
+
58
+ def __bool__(self) -> bool:
59
+ return bool(self.value)
60
+
61
+ def __int__(self) -> int:
62
+ return int(self.value)
63
+
64
+ def __float__(self) -> float:
65
+ return float(self.value)
66
+
67
+ def __len__(self) -> int:
68
+ return self.shape[0]
69
+
70
+ def __index__(self) -> int:
71
+ return self.value.__index__()
72
+
73
+ def __getitem__(self, index):
74
+ op = self._opset
75
+ if op.version < 13:
76
+ raise RuntimeError("Indexing requires opset 13 or later.")
77
+ if not isinstance(index, tuple):
78
+ # Normalize representation to a tuple.
79
+ # A single index-value is equivalent to a tuple with a single element.
80
+ index = (index,)
81
+ if len(index) > self.rank:
82
+ raise ValueError(
83
+ f"Number of indices {len(index)} is greater than rank {self.rank}"
84
+ )
85
+
86
+ # Promote integer indices to tensors of rank 0
87
+ index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]
88
+ # Process all elements in index
89
+ shape = self.shape
90
+ sliced_indices = []
91
+ scalar_indices = []
92
+ to_squeeze = []
93
+ non_scalar_indices = []
94
+ for axis_, s in enumerate(index):
95
+ if isinstance(s, slice):
96
+ if s.start is None and s.stop is None and s.step is None:
97
+ continue
98
+ if s.step is None or s.step > 0:
99
+ sliced_indices.append(
100
+ [
101
+ s.start or 0,
102
+ s.stop if s.stop is not None else shape[axis_],
103
+ axis_,
104
+ s.step or 1,
105
+ ]
106
+ )
107
+ else:
108
+ sliced_indices.append(
109
+ [
110
+ s.start if s.start is not None else (shape[axis_] - 1),
111
+ s.stop if s.stop is not None else -(shape[axis_] + 1),
112
+ axis_,
113
+ s.step,
114
+ ]
115
+ )
116
+ elif isinstance(s, Tensor):
117
+ if s.is_scalar:
118
+ scalar_indices.append([s, s + 1, axis_, 1])
119
+ to_squeeze.append(axis_)
120
+ else:
121
+ non_scalar_indices.append((axis_, s))
122
+ else:
123
+ raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")
124
+
125
+ # Non-scalar-indexing requires the use of ONNX Gather operation.
126
+ # Slicing can be implemented efficiently using ONNX's Slice operation.
127
+ # Scalar-indexing can be implemented using either Gather or with the Slice operation.
128
+ # We map scalar-indexing into the Slice operation, except in the special case
129
+ # of a single scalar-index (with no other sliced_index), which we map directly
130
+ # to a Gather.
131
+
132
+ if not (sliced_indices or scalar_indices or non_scalar_indices):
133
+ # Edge case: no index specified. Eg. A[:, :]
134
+ return op.Identity(self)
135
+ if not sliced_indices and len(scalar_indices) == 1:
136
+ # Special case of indexing along a single axis: A[i], A[:, i], A[:, :, i] etc.
137
+ # promote integer input to tensor
138
+ axis = to_squeeze[0]
139
+ index_value = index[axis]
140
+ # use Gather to perform indexing
141
+ result = op.Gather(self, index_value, axis=axis)
142
+ elif sliced_indices or scalar_indices:
143
+ sliced_indices = sliced_indices + scalar_indices
144
+ indices = np.array(sliced_indices, dtype=np.int64).T
145
+ starts = Tensor(indices[0])
146
+ ends = Tensor(indices[1])
147
+ axes = Tensor(indices[2])
148
+ steps = Tensor(indices[3])
149
+ result = op.Slice(self, starts, ends, axes, steps)
150
+ if to_squeeze:
151
+ result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
152
+ else:
153
+ result = self
154
+ for axis, value in non_scalar_indices:
155
+ result = op.Gather(result, value, axis=axis)
156
+
157
+ return result
158
+
159
+ def __mod__(self, other):
160
+ if self.onnx_dtype in {
161
+ ir.DataType.FLOAT,
162
+ ir.DataType.DOUBLE,
163
+ ir.DataType.FLOAT16,
164
+ ir.DataType.BFLOAT16,
165
+ }:
166
+ return self._opset.Mod(self, other, fmod=1)
167
+ return self._opset.Mod(self, other)
168
+
169
+ def __ne__(self, other):
170
+ temp = self._opset.Equal(self, other)
171
+ return self._opset.Not(temp)
172
+
173
+ def __neg__(self):
174
+ return self._opset.Neg(self)
175
+
176
+ def __add__(self, other):
177
+ return self._opset.Add(self, other)
178
+
179
+ def __radd__(self, other):
180
+ return self._opset.Add(other, self)
181
+
182
+ def __and__(self, other):
183
+ return self._opset.And(self, other)
184
+
185
+ def __rand__(self, other):
186
+ return self._opset.And(other, self)
187
+
188
+ def __mul__(self, other):
189
+ return self._opset.Mul(self, other)
190
+
191
+ def __rmul__(self, other):
192
+ return self._opset.Mul(other, self)
193
+
194
+ def __matmul__(self, other):
195
+ return self._opset.MatMul(self, other)
196
+
197
+ def __or__(self, other):
198
+ return self._opset.Or(self, other)
199
+
200
+ def __pow__(self, other):
201
+ return self._opset.Pow(self, other)
202
+
203
+ def __sub__(self, other):
204
+ return self._opset.Sub(self, other)
205
+
206
+ def __rsub__(self, other):
207
+ return self._opset.Sub(other, self)
208
+
209
+ def __truediv__(self, other):
210
+ return self._opset.Div(self, other)
211
+
212
+ def __lt__(self, other):
213
+ return self._opset.Less(self, other)
214
+
215
+ def __le__(self, other):
216
+ return self._opset.LessOrEqual(self, other)
217
+
218
+ def __eq__(self, other):
219
+ return self._opset.Equal(self, other)
220
+
221
+ def __ge__(self, other):
222
+ return self._opset.GreaterOrEqual(self, other)
223
+
224
+ def __gt__(self, other):
225
+ return self._opset.Greater(self, other)
pythonProject/.venv/Lib/site-packages/onnxscript/type_annotation.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import collections
6
+ import inspect
7
+ import typing
8
+ from typing import Optional, Sequence, Union
9
+
10
+ import onnx
11
+
12
+ from onnxscript import onnx_types
13
+ from onnxscript._internal import version_utils
14
+
15
+ # TypeAnnotationValue represents the (value of) valid type-annotations recognized
16
+ # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
17
+ # - float, int, str (primitive attribute types)
18
+ # - Sequence[float], Sequence[int], Sequence[str] (attribute types)
19
+ # - Tensor types
20
+ # - Sequence[Tensor] types
21
+ # - Union of above 2
22
+ # - TypeVars with above bounds
23
+ # - Above types with annotation attached
24
+ TypeAnnotationValue = typing.Any
25
+
26
+ # Map from python type to corresponding ONNX AttributeProto type
27
+ _PYTYPE_TO_ATTRTYPE_MAP = {
28
+ float: onnx.AttributeProto.FLOAT,
29
+ int: onnx.AttributeProto.INT,
30
+ str: onnx.AttributeProto.STRING,
31
+ bool: onnx.AttributeProto.INT, # experimental
32
+ }
33
+
34
+ # Map from python type to corresponding ONNX AttributeProto type,
35
+ # for repeated (i.e., list of) values
36
+ _LISTTYPE_TO_ATTRTYPE_MAP = {
37
+ float: onnx.AttributeProto.FLOATS,
38
+ int: onnx.AttributeProto.INTS,
39
+ str: onnx.AttributeProto.STRINGS,
40
+ bool: onnx.AttributeProto.INTS, # experimental
41
+ }
42
+
43
+ _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])
44
+
45
+ # Map from ONNX AttributeProto type to its representation (in ONNX Script).
46
+ _ATTRTYPE_TO_REPR = {
47
+ onnx.AttributeProto.FLOAT: "float",
48
+ onnx.AttributeProto.INT: "int",
49
+ onnx.AttributeProto.STRING: "str",
50
+ onnx.AttributeProto.FLOATS: "Sequence[float]",
51
+ onnx.AttributeProto.INTS: "Sequence[int]",
52
+ onnx.AttributeProto.STRINGS: "Sequence[str]",
53
+ }
54
+
55
+
56
+ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
57
+ if attr_type not in _ATTRTYPE_TO_REPR:
58
+ supported = ", ".join(
59
+ f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
60
+ )
61
+ raise ValueError(f"Unsupported attribute type {attr_type}: only {supported} allowed.")
62
+ return _ATTRTYPE_TO_REPR[attr_type]
63
+
64
+
65
+ # A sorted list of all type strings used in an OpSchema
66
+ ALL_TENSOR_TYPE_STRINGS = tuple(
67
+ sorted(
68
+ tensor_type.to_string()
69
+ for tensor_type in onnx_types.tensor_type_registry.values()
70
+ # Skip FLOAT4E2M1 for versions older than 1.18
71
+ # TODO(after onnx requirement bump): Remove this check
72
+ if not (version_utils.onnx_older_than("1.18") and tensor_type == onnx_types.FLOAT4E2M1)
73
+ )
74
+ )
75
+
76
+
77
+ def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
78
+ """Remove Annotated wrapper if present, otherwise return typeinfo as is."""
79
+ if hasattr(typing, "Annotated"):
80
+ # Present in Python 3.9+
81
+ if typing.get_origin(typeinfo) is typing.Annotated:
82
+ return typing.get_args(typeinfo)[0]
83
+ return typeinfo
84
+
85
+
86
+ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
87
+ return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP
88
+
89
+
90
+ def pytype_to_attrtype(
91
+ pytype: TypeAnnotationValue,
92
+ ) -> Optional[onnx.AttributeProto.AttributeType]:
93
+ pytype = _remove_annotation(pytype)
94
+ if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
95
+ return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
96
+ type_constructor = typing.get_origin(pytype)
97
+ # Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
98
+ if type_constructor is typing.Union:
99
+ # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
100
+ args = [x for x in typing.get_args(pytype) if x is not type(None)]
101
+ if len(args) == 1:
102
+ return pytype_to_attrtype(args[0])
103
+ if type_constructor in _LIST_CONSTRUCTORS:
104
+ elt_type = typing.get_args(pytype)[0]
105
+ if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
106
+ return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
107
+ return None
108
+
109
+
110
+ def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
111
+ """Returns True if base type of pytype is bool, False otherwise."""
112
+ pytype = _remove_annotation(pytype)
113
+ if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
114
+ return pytype is bool
115
+ type_constructor = typing.get_origin(pytype)
116
+ if type_constructor in _LIST_CONSTRUCTORS:
117
+ element_type = typing.get_args(pytype)[0]
118
+ return element_type is bool
119
+ # Remove Optional wrapper if present:
120
+ if type_constructor is Optional or type_constructor is Union:
121
+ # In Python < 3.10, Optional[X] is represented as Union[X, type(None)]
122
+ # so we filter out type(None) if present
123
+ args = [x for x in typing.get_args(pytype) if x is not type(None)]
124
+ if len(args) == 1:
125
+ return base_type_is_bool(args[0])
126
+
127
+ return False
128
+
129
+
130
+ def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
131
+ if isinstance(typeinfo, onnx_types.TensorType):
132
+ return True
133
+ if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
134
+ return True
135
+ return False
136
+
137
+
138
+ def is_value_type(typeinfo: TypeAnnotationValue) -> bool:
139
+ """Returns True if typeinfo represents a value type, False if it is an attribute type.
140
+ Raises ValueError if typeinfo is not a supported type annotation.
141
+ """
142
+ typeinfo = _remove_annotation(typeinfo)
143
+ if _is_tensor_type(typeinfo):
144
+ return True
145
+ if _is_primitive_attr_type(typeinfo):
146
+ return False
147
+ type_constructor = typing.get_origin(typeinfo)
148
+ # Handle List-like type-constructor
149
+ # Eg. List[INT32] is a value type, while List[int] is an attribute type
150
+ if type_constructor in _LIST_CONSTRUCTORS:
151
+ elt_type = typing.get_args(typeinfo)[0]
152
+ return is_value_type(elt_type)
153
+ # Handle Union and Optional type-constructors
154
+ if type_constructor is typing.Union:
155
+ # Filter out None, since typing.Optional[X] evaluates to Union[X, None]
156
+ args = [x for x in typing.get_args(typeinfo) if x is not type(None)]
157
+ args_value_check = [is_value_type(x) for x in args]
158
+ if all(args_value_check):
159
+ # Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE]
160
+ return True
161
+ elif (len(args) == 1) and args_value_check[0] is False:
162
+ # Handle the case of optional attribute: eg. Optional[int]
163
+ # Note that we do not allow Union[int, float] for attributes.
164
+ return False
165
+ else:
166
+ raise ValueError(f"Unsupported type annotation '{typeinfo}'")
167
+ # Handle TypeVars:
168
+ if isinstance(typeinfo, typing.TypeVar):
169
+ if hasattr(typeinfo, "__bound__"):
170
+ bound = typeinfo.__bound__
171
+ return is_value_type(bound)
172
+ raise ValueError(f"Unsupported type annotation {typeinfo}")
173
+
174
+
175
+ def is_attr_type(pytype: TypeAnnotationValue):
176
+ return is_value_type(pytype) is False
177
+
178
+
179
+ def is_valid_type(typeinfo: TypeAnnotationValue):
180
+ try:
181
+ return is_value_type(typeinfo) in {True, False}
182
+ except ValueError:
183
+ return False
184
+
185
+
186
+ def is_optional(pytype) -> bool:
187
+ """Returns whether a pytype is an Optional."""
188
+ if typing.get_origin(pytype) is Union and type(None) in typing.get_args(pytype):
189
+ # Python < 3.10
190
+ return True
191
+ if typing.get_origin(pytype) is Optional:
192
+ # Python >= 3.10
193
+ return True
194
+ return False
195
+
196
+
197
+ def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]:
198
+ """Converts return-type annotation into a sequence of types.
199
+
200
+ The return type annotation can be either a single type (for a single output)
201
+ or a Tuple type (for multiple outputs). This function normalizes the
202
+ representation so that it is always a sequence of types, even for a single
203
+ output.
204
+ """
205
+ if isinstance(typeinfo, typing.Sequence):
206
+ return typeinfo
207
+ if typing.get_origin(typeinfo) is tuple:
208
+ return typing.get_args(typeinfo)
209
+ return (typeinfo,)
210
+
211
+
212
+ def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]:
213
+ """Returns a list of type-strings corresponding to a given type annotation.
214
+
215
+ Args:
216
+ pytype: A type annotation.
217
+
218
+ Returns:
219
+ A list of all supported input types for the given type annotation.
220
+ Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS.
221
+ """
222
+ if pytype is None:
223
+ return list(ALL_TENSOR_TYPE_STRINGS)
224
+ if pytype is onnx_types.TensorType:
225
+ return list(ALL_TENSOR_TYPE_STRINGS)
226
+ if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType):
227
+ return [pytype.to_string()]
228
+ if isinstance(pytype, onnx_types.TensorType):
229
+ return [pytype.to_string()]
230
+ if isinstance(pytype, typing.TypeVar):
231
+ constraints = pytype.__constraints__
232
+ if constraints:
233
+ return pytype_to_type_strings(Union.__getitem__(constraints)) # pylint: disable=unnecessary-dunder-call
234
+ bound = pytype.__bound__
235
+ if bound is None:
236
+ return list(ALL_TENSOR_TYPE_STRINGS)
237
+ return pytype_to_type_strings(bound)
238
+ if typing.get_origin(pytype) is Union:
239
+ options = []
240
+ subtypes = typing.get_args(pytype)
241
+ # A None type in a Union is equivalent to an optional type
242
+ optional = is_optional(pytype)
243
+ for subtype in subtypes:
244
+ if subtype is type(None):
245
+ # Skip None type because we are handling it with is_optional
246
+ continue
247
+ if optional:
248
+ options += [
249
+ *pytype_to_type_strings(subtype),
250
+ *[f"optional({s})" for s in pytype_to_type_strings(subtype)],
251
+ ]
252
+ else:
253
+ options += pytype_to_type_strings(subtype)
254
+ # Remove duplicates
255
+ return sorted(set(options))
256
+ if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
257
+ subtypes = typing.get_args(pytype)
258
+ return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])]
259
+
260
+ raise ValueError(f"Unsupported type: {pytype}")
261
+
262
+
263
+ def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]:
264
+ """Returns the name of the type constraint for a given type annotation.
265
+
266
+ Args:
267
+ pytype: A type annotation.
268
+
269
+ Returns:
270
+ The name of the type constraint if it is a TypeVar.
271
+ - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar].
272
+ - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
273
+ - Returns None if the type annotation does not have a type constraint.
274
+ """
275
+ if isinstance(pytype, typing.TypeVar):
276
+ return pytype.__name__
277
+ if is_optional(pytype):
278
+ subtypes = typing.get_args(pytype)
279
+ for subtype in subtypes:
280
+ if subtype is type(None):
281
+ continue
282
+ type_param_name = get_type_constraint_name(subtype)
283
+ return f"Optional_{type_param_name}" if type_param_name else None
284
+ if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
285
+ subtypes = typing.get_args(pytype)
286
+ type_param_name = get_type_constraint_name(subtypes[0])
287
+ return f"Sequence_{type_param_name}" if type_param_name else None
288
+ return None