Add files using upload-large-folder tool
- pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/METADATA +136 -0
- pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/RECORD +0 -0
- pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/WHEEL +5 -0
- pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/licenses/LICENSE +202 -0
- pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/top_level.txt +1 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/cpp/utf8_conversion_test.cc +27 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_conversion_test_base.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_downgrade_test.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_conversion_test_base.py +149 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_downgrade_test.py +106 -0
- pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_upgrade_test.py +1964 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/converter.py +1462 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/evaluator.py +619 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/irbuilder.py +561 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/main.py +176 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/onnx_types.py +237 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/py.typed +1 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/sourceinfo.py +59 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/tensor.py +225 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/type_annotation.py +288 -0
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/METADATA
ADDED
@@ -0,0 +1,136 @@
+Metadata-Version: 2.4
+Name: onnx
+Version: 1.19.0
+Summary: Open Neural Network Exchange
+Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
+License: Apache License v2.0
+Project-URL: Homepage, https://onnx.ai/
+Project-URL: Repository, https://github.com/onnx/onnx
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy>=1.22
+Requires-Dist: protobuf>=4.25.1
+Requires-Dist: typing_extensions>=4.7.1
+Requires-Dist: ml_dtypes
+Provides-Extra: reference
+Requires-Dist: Pillow; extra == "reference"
+Dynamic: license-file
+
+<!--
+Copyright (c) ONNX Project Contributors
+
+SPDX-License-Identifier: Apache-2.0
+-->
+
+<p align="center"><img width="40%" src="https://github.com/onnx/onnx/raw/main/docs/onnx-horizontal-color.png" /></p>
+
+[](https://pypi.org/project/onnx)
+[](https://github.com/onnx/onnx/actions/workflows/main.yml)
+[](https://bestpractices.coreinfrastructure.org/projects/3313)
+[](https://api.securityscorecards.dev/projects/github.com/onnx/onnx)
+[](https://api.reuse.software/info/github.com/onnx/onnx)
+[](https://github.com/astral-sh/ruff)
+[](https://github.com/psf/black)
+
+[Open Neural Network Exchange (ONNX)](https://onnx.ai) is an open ecosystem that empowers AI developers
+to choose the right tools as their project evolves. ONNX provides an open source format for AI models, both deep learning and traditional ML. It defines an extensible computation graph model, as well as definitions of built-in operators and standard
+data types. Currently we focus on the capabilities needed for inferencing (scoring).
+
+ONNX is [widely supported](http://onnx.ai/supported-tools) and can be found in many frameworks, tools, and hardware. Enabling interoperability between different frameworks and streamlining the path from research to production helps increase the speed of innovation in the AI community. We invite the community to join us and further evolve ONNX.
+
+# Use ONNX
+
+* [Documentation of ONNX Python Package](https://onnx.ai/onnx/)
+* [Tutorials for creating ONNX models](https://github.com/onnx/tutorials)
+* [Pre-trained ONNX models](https://github.com/onnx/models)
+
+# Learn about the ONNX spec
+
+* [Overview](https://github.com/onnx/onnx/blob/main/docs/Overview.md)
+* [ONNX intermediate representation spec](https://github.com/onnx/onnx/blob/main/docs/IR.md)
+* [Versioning principles of the spec](https://github.com/onnx/onnx/blob/main/docs/Versioning.md)
+* [Operators documentation](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
+* [Operators documentation](https://onnx.ai/onnx/operators/index.html) (latest release)
+* [Python API Overview](https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md)
+
+# Programming utilities for working with ONNX Graphs
+
+* [Shape and Type Inference](https://github.com/onnx/onnx/blob/main/docs/ShapeInference.md)
+* [Graph Optimization](https://github.com/onnx/optimizer)
+* [Opset Version Conversion](https://github.com/onnx/onnx/blob/main/docs/docsgen/source/api/version_converter.md)
+
+# Contribute
+
+ONNX is a community project and the open governance model is described [here](https://github.com/onnx/onnx/blob/main/community/readme.md). We encourage you to join the effort and contribute feedback, ideas, and code. You can participate in the [Special Interest Groups](https://github.com/onnx/onnx/blob/main/community/sigs.md) and [Working Groups](https://github.com/onnx/onnx/blob/main/community/working-groups.md) to shape the future of ONNX.
+
+Check out our [contribution guide](https://github.com/onnx/onnx/blob/main/CONTRIBUTING.md) to get started.
+
+If you think some operator should be added to ONNX specification, please read
+[this document](https://github.com/onnx/onnx/blob/main/docs/AddNewOp.md).
+
+# Community meetings
+
+The schedules of the regular meetings of the Steering Committee, the working groups and the SIGs can be found [here](https://onnx.ai/calendar)
+
+Community Meetups are held at least once a year. Content from previous community meetups are at:
+
+* 2020.04.09 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14091402/LF+AI+Day+-ONNX+Community+Virtual+Meetup+-+Silicon+Valley+-+2020+April+9>
+* 2020.10.14 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14092138/LF+AI+Day+-+ONNX+Community+Workshop+-+2020+October+14>
+* 2021.03.24 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14092424/Instructions+for+Event+Hosts+-+LF+AI+Data+Day+-+ONNX+Virtual+Community+Meetup+-+March+2021>
+* 2021.10.21 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14093194/LF+AI+Data+Day+ONNX+Community+Virtual+Meetup+-+October+2021>
+* 2022.06.24 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14093969/ONNX+Community+Day+-+2022+June+24>
+* 2023.06.28 <https://lf-aidata.atlassian.net/wiki/spaces/DL/pages/14094507/ONNX+Community+Day+2023+-+June+28>
+
+
+
+# Discuss
+
+We encourage you to open [Issues](https://github.com/onnx/onnx/issues), or use [Slack](https://lfaifoundation.slack.com/) (If you have not joined yet, please use this [link](https://join.slack.com/t/lfaifoundation/shared_invite/zt-o65errpw-gMTbwNr7FnNbVXNVFkmyNA) to join the group) for more real-time discussion.
+
+# Follow Us
+
+Stay up to date with the latest ONNX news. [[Facebook](https://www.facebook.com/onnxai/)] [[Twitter](https://twitter.com/onnxai)]
+
+# Roadmap
+
+A roadmap process takes place every year. More details can be found [here](https://github.com/onnx/steering-committee/tree/main/roadmap)
+
+# Installation
+
+ONNX released packages are published in PyPi.
+
+```sh
+pip install onnx # or pip install onnx[reference] for optional reference implementation dependencies
+```
+
+[ONNX weekly packages](https://pypi.org/project/onnx-weekly/) are published in PyPI to enable experimentation and early testing.
+
+Detailed install instructions, including Common Build Options and Common Errors can be found [here](https://github.com/onnx/onnx/blob/main/INSTALL.md)
+
+# Testing
+
+ONNX uses [pytest](https://docs.pytest.org) as test driver. In order to run tests, you will first need to install `pytest`:
+
+```sh
+pip install pytest
+```
+
+After installing pytest, use the following command to run tests.
+
+```sh
+pytest
+```
+
+# Development
+
+Check out the [contributor guide](https://github.com/onnx/onnx/blob/main/CONTRIBUTING.md) for instructions.
+
+# License
+
+[Apache License v2.0](LICENSE)
+
+# Code of Conduct
+
+[ONNX Open Source Code of Conduct](https://onnx.ai/codeofconduct.html)
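The README above names the checker, shape-inference, and opset-conversion utilities that the test files later in this commit exercise. A minimal end-to-end sketch of that workflow (illustrative only; the one-node Relu graph and target opset 21 are arbitrary choices, not part of this commit):

```python
import onnx
from onnx import TensorProto, helper, shape_inference, version_converter

# Build a one-node Relu model pinned to opset 13.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
graph = helper.make_graph([helper.make_node("Relu", ["x"], ["y"])], "demo", [x], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)

# Upgrade to opset 21, then re-validate and re-run shape inference.
converted = version_converter.convert_version(model, 21)
onnx.checker.check_model(converted)
shape_inference.infer_shapes(converted, strict_mode=True)
```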
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/RECORD
ADDED
The diff for this file is too large to render.
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/WHEEL
ADDED
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: setuptools (1.19.0)
+Root-Is-Purelib: false
+Tag: cp310-cp310-win_amd64
+
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
pythonProject/.venv/Lib/site-packages/onnx-1.19.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+onnx
pythonProject/.venv/Lib/site-packages/onnx/test/cpp/utf8_conversion_test.cc
ADDED
@@ -0,0 +1,27 @@
+// Copyright (c) ONNX Project Contributors
+
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#ifdef _WIN32
+#include <string>
+
+#include "gtest/gtest.h"
+#include "onnx/common/path.h"
+namespace ONNX_NAMESPACE::Test {
+
+TEST(UTF8Test, WideStringConvertion) {
+  std::string utf8_str(u8"世界,你好!");
+  EXPECT_EQ(ONNX_NAMESPACE::wstring_to_utf8str(ONNX_NAMESPACE::utf8str_to_wstring(utf8_str)), utf8_str);
+}
+
+TEST(UTF8Test, TryConvertUTF8) {
+  std::string utf8_str(u8"世界,你好!");
+  auto wstr = ONNX_NAMESPACE::utf8str_to_wstring(utf8_str);
+  auto wstr2 = ONNX_NAMESPACE::utf8str_to_wstring(
+      std::string(reinterpret_cast<const char*>(wstr.c_str()), sizeof(std::wstring::value_type) * wstr.size()), true);
+  EXPECT_EQ(wstr, wstr2);
+}
+} // namespace ONNX_NAMESPACE::Test
+#endif
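The two gtest cases above assert that converting a UTF-8 string to a Windows wide string and back is lossless. For readers without a Windows toolchain, the same round-trip invariant can be sketched in Python (an illustrative analogue, not part of the commit; Windows `wchar_t` strings are UTF-16, which the `utf-16-le` codec models):

```python
# Round-trip invariant mirrored from the C++ test:
# UTF-8 -> wide (UTF-16, as on Windows) -> UTF-8 must reproduce the input.
utf8_str = "世界,你好!"
wide_bytes = utf8_str.encode("utf-16-le")          # ~ utf8str_to_wstring
assert wide_bytes.decode("utf-16-le") == utf8_str  # ~ wstring_to_utf8str
```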
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_conversion_test_base.cpython-310.pyc
ADDED
Binary file (6.24 kB).
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/__pycache__/automatic_downgrade_test.cpython-310.pyc
ADDED
Binary file (3.28 kB).
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_conversion_test_base.py
ADDED
@@ -0,0 +1,149 @@
+# Copyright (c) ONNX Project Contributors
+
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import string
+import unittest
+from typing import TYPE_CHECKING, Any, cast
+
+import onnx
+from onnx import TensorProto, ValueInfoProto, helper, shape_inference, version_converter
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+LATEST_OPSET = onnx.defs.onnx_opset_version()
+
+
+class TestAutomaticConversion(unittest.TestCase):
+    def _test_model_conversion(
+        self, to_opset: int, model: str | onnx.ModelProto
+    ) -> None:
+        if isinstance(model, str):
+            model = onnx.parser.parse_model(model)
+        onnx.checker.check_model(model)
+        shape_inference.infer_shapes(model, strict_mode=True)
+
+        converted = version_converter.convert_version(model, to_opset)
+        onnx.checker.check_model(converted)
+        shape_inference.infer_shapes(converted, strict_mode=True)
+
+    def _test_model_conversion_fails(
+        self, to_opset: int, model: str | onnx.ModelProto
+    ) -> None:
+        if isinstance(model, str):
+            model = onnx.parser.parse_model(model)
+        onnx.checker.check_model(model)
+        shape_inference.infer_shapes(model, strict_mode=True)
+
+        with self.assertRaises(RuntimeError):
+            version_converter.convert_version(model, to_opset)
+
+    def _test_op_conversion(
+        self,
+        op: str,
+        from_opset: int,
+        input_shapes: Sequence[Sequence[int | None] | str] = ((3, 4, 5),),
+        output_shapes: Sequence[Sequence[int | None]] = ((3, 4, 5),),
+        input_types: Sequence[Any] | None = None,
+        output_types: Sequence[Any] | None = None,
+        initializer: Sequence[Any] = (),
+        attrs: dict[str, Any] | None = None,
+        seq_inputs: Sequence[int] = (),
+        seq_outputs: Sequence[int] = (),
+        optional_inputs: Sequence[int] = (),
+        optional_outputs: Sequence[int] = (),
+        is_upgrade: bool = True,
+    ) -> None:
+        """Test conversion.
+
+        Args:
+            op: A string representing the name of the operator to test.
+            from_opset: An integer representing the lowest opset version to convert.
+            input_shapes: A sequence of tuples or strings representing the shapes of the input tensors.
+                The default value is ((3, 4, 5),).
+            output_shapes: A sequence of tuples representing the shapes of the output tensors.
+                The default value is ((3, 4, 5),).
+            input_types: An optional sequence of types representing the data types of the input tensors.
+            output_types: An optional sequence of types representing the data types of the output tensors.
+            initializer: A sequence of values representing the initial values of the input tensors.
+            attrs: An optional dictionary of attributes for the operator.
+            seq_inputs: A sequence of integers representing the indices of the input tensors that are sequences.
+            seq_outputs: A sequence of integers representing the indices of the output tensors that are sequences.
+            optional_inputs: A sequence of integers representing the indices of the input tensors that are optional.
+            optional_outputs: A sequence of integers representing the indices of the output tensors that are optional.
+            is_upgrade: A boolean value indicating whether to run the version converter from from_opset to
+                the most recent opset version (True) or from the most recent opset version to from_opset (False).
+                The default value is True. In both cases, runs checker and shape inference on the final model.
+        """
+        if attrs is None:
+            attrs = {}
+
+        n_inputs = len(input_shapes)
+        letters = list(string.ascii_lowercase)[:n_inputs]
+        input_names = [
+            letter if shape != "" else ""
+            for (letter, shape) in zip(letters, input_shapes)
+        ]
+        if input_types is None:
+            input_types = [TensorProto.FLOAT] * n_inputs
+        is_sequence = [0 if id not in seq_inputs else 1 for id in range(n_inputs)]
+        is_optional = [0 if id not in optional_inputs else 1 for id in range(n_inputs)]
+        # turn empty strings into [0] to ease type analysis, even though those entries
+        # will be ignored
+        input_shapes_cast = cast(
+            "list[list[int]]",
+            [[0] if isinstance(shape, str) else shape for shape in input_shapes],
+        )
+        inputs: list[ValueInfoProto] = []
+        for name, ttype, shape, is_seq, is_opt in zip(
+            input_names, input_types, input_shapes_cast, is_sequence, is_optional
+        ):
+            if name != "":
+                if is_seq:
+                    inputs += [
+                        helper.make_tensor_sequence_value_info(name, ttype, shape)
+                    ]
+                elif is_opt:
+                    type_proto = helper.make_tensor_type_proto(ttype, shape)
+                    optional_type_proto = helper.make_optional_type_proto(type_proto)
+                    inputs += [helper.make_value_info(name, optional_type_proto)]
+                else:
+                    inputs += [helper.make_tensor_value_info(name, ttype, shape)]
+
+        n_outputs = len(output_shapes)
+        output_names = list(string.ascii_lowercase)[n_inputs : n_inputs + n_outputs]
+        if output_types is None:
+            output_types = [TensorProto.FLOAT] * n_outputs
+        is_sequence = [0 if id not in seq_outputs else 1 for id in range(n_outputs)]
+        is_optional = [
+            0 if id not in optional_outputs else 1 for id in range(n_outputs)
+        ]
+        output_shapes_cast = cast(
+            "list[list[int]]",
+            [[0] if isinstance(shape, str) else shape for shape in output_shapes],
+        )
+        outputs: list[ValueInfoProto] = []
+        for name, ttype, shape, is_seq, is_opt in zip(
+            output_names, output_types, output_shapes_cast, is_sequence, is_optional
+        ):
+            if is_seq:
+                outputs += [helper.make_tensor_sequence_value_info(name, ttype, shape)]
+            elif is_opt:
+                type_proto = helper.make_tensor_type_proto(ttype, shape)
+                optional_type_proto = helper.make_optional_type_proto(type_proto)
+                outputs += [helper.make_value_info(name, optional_type_proto)]
+            else:
+                outputs += [helper.make_tensor_value_info(name, ttype, shape)]
+
+        node = helper.make_node(op, input_names, output_names, **attrs)
+        graph = helper.make_graph([node], op, inputs, outputs, initializer)
+        start_opset = from_opset if is_upgrade else LATEST_OPSET
+        end_opset = LATEST_OPSET if is_upgrade else from_opset
+        original = helper.make_model(
+            graph,
+            producer_name="test",
+            opset_imports=[helper.make_opsetid("", start_opset)],
+        )
+        self._test_model_conversion(end_opset, original)
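The two test modules that follow subclass this base class. Stripped to its essentials, driving the harness looks like the sketch below (a hypothetical subclass for illustration — `Relu` and opset 13 are arbitrary, and `automatic_conversion_test_base` must be importable from the test directory):

```python
import unittest

import automatic_conversion_test_base


class TestReluUpgrade(automatic_conversion_test_base.TestAutomaticConversion):
    def test_relu(self) -> None:
        # Builds a single-node Relu graph at opset 13, upgrades it to the
        # latest opset, and runs checker + shape inference on both models.
        self._test_op_conversion("Relu", from_opset=13, is_upgrade=True)


if __name__ == "__main__":
    unittest.main()
```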
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_downgrade_test.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright (c) ONNX Project Contributors
+
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import unittest
+
+import automatic_conversion_test_base
+import numpy as np
+import parameterized
+
+import onnx
+from onnx import helper
+
+#####################################################################################
+# Every test calls _test_op_conversion to downgrade a model from the most recent opset version
+# to a early version and runs checker + shape inference on the downgraded model.
+####################################################################################
+
+
+class TestAutomaticDowngrade(automatic_conversion_test_base.TestAutomaticConversion):
+    def _test_op_downgrade(self, op: str, *args, **kwargs):
+        self._test_op_conversion(op, *args, **kwargs, is_upgrade=False)
+
+    @parameterized.parameterized.expand(
+        [
+            "ReduceL1",
+            "ReduceL2",
+            "ReduceLogSum",
+            "ReduceLogSumExp",
+            "ReduceMean",
+            "ReduceMax",
+            "ReduceMin",
+            "ReduceProd",
+            "ReduceSum",
+            "ReduceSumSquare",
+        ]
+    )
+    def test_reduce_ops(self, op) -> None:
+        # TODO: need to add test cases for missing axes input which depends on this pr:
+        # https://github.com/onnx/onnx/pull/5613
+        axes = helper.make_tensor(
+            "b", onnx.TensorProto.INT64, dims=[3], vals=np.array([0, 1, 2])
+        )
+        self._test_op_downgrade(
+            op,
+            from_opset=13,
+            input_shapes=[[3, 4, 5], [3]],
+            output_shapes=[[1, 1, 1]],
+            input_types=[onnx.TensorProto.FLOAT, onnx.TensorProto.INT64],
+            initializer=[axes],
+        )
+
+    def test_dft20_no_axis(self) -> None:
+        self._test_model_conversion(
+            to_opset=19,
+            model="""
+                <ir_version: 9, opset_import: [ "" : 20]>
+                dft_no_axis (float[N, M, 1] x) => (float[N, M, 2] y)
+                {
+                    y = DFT (x)
+                }
+            """,
+        )
+
+    def test_dft20_initializer_axis(self) -> None:
+        self._test_model_conversion(
+            to_opset=19,
+            model="""
+                <ir_version: 9, opset_import: [ "" : 20]>
+                dft_no_axis (float[N, M, 1] x, int64 dft_length) => (float[N, K, 2] y)
+                <int64 axis = {1}>
+                {
+                    y = DFT (x, dft_length, axis)
+                }
+            """,
+        )
+
+    def test_dft20_constant_axis(self) -> None:
+        self._test_model_conversion(
+            to_opset=19,
+            model="""
+                <ir_version: 9, opset_import: [ "" : 20]>
+                dft_no_axis (float[N, M, 1] x, int64 dft_length) => (float[N, K, 2] y)
+                {
+                    axis = Constant <value = int64{1}>()
+                    y = DFT (x, dft_length, axis)
+                }
+            """,
+        )
+
+    def test_dft20_unknown_axis(self) -> None:
+        self._test_model_conversion_fails(
+            to_opset=19,
+            model="""
+                <ir_version: 9, opset_import: [ "" : 20]>
+                dft_no_axis (float[N, M, 1] x, int64 dft_length, int64 axis) => (float[P, K, 2] y)
+                {
+                    y = DFT (x, dft_length, axis)
+                }
+            """,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
pythonProject/.venv/Lib/site-packages/onnx/test/version_converter/automatic_upgrade_test.py
ADDED
|
@@ -0,0 +1,1964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) ONNX Project Contributors
|
| 2 |
+
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import unittest
|
| 7 |
+
|
| 8 |
+
import automatic_conversion_test_base
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import onnx
|
| 12 |
+
from onnx import TensorProto, helper
|
| 13 |
+
|
| 14 |
+
#####################################################################################
|
| 15 |
+
# Every test calls _test_op_conversion to upgrade a model from an initial opset version
|
| 16 |
+
# to the most recent version and runs checker and shape inference on the final upgraded model.
|
| 17 |
+
####################################################################################
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestAutomaticUpgrade(automatic_conversion_test_base.TestAutomaticConversion):
|
| 21 |
+
@classmethod
|
| 22 |
+
def setUpClass(cls):
|
| 23 |
+
cls.tested_ops = []
|
| 24 |
+
|
| 25 |
+
def _test_op_upgrade(self, op, *args, **kwargs):
|
| 26 |
+
self.tested_ops.append(op)
|
| 27 |
+
self._test_op_conversion(op, *args, **kwargs, is_upgrade=True)
|
| 28 |
+
|
| 29 |
+
def test_Abs(self) -> None:
|
| 30 |
+
self._test_op_upgrade("Abs", 1, attrs={"consumed_inputs": [0]})
|
| 31 |
+
|
| 32 |
+
def test_Acosh(self) -> None:
|
| 33 |
+
self._test_op_upgrade("Acosh", 9)
|
| 34 |
+
|
| 35 |
+
def test_Acos(self) -> None:
|
| 36 |
+
self._test_op_upgrade("Acos", 7)
|
| 37 |
+
|
| 38 |
+
def test_And(self) -> None:
|
| 39 |
+
# 6->7 adapter is missing
|
| 40 |
+
self._test_op_upgrade(
|
| 41 |
+
"And",
|
| 42 |
+
7,
|
| 43 |
+
[[2, 3], [2, 3]],
|
| 44 |
+
[[2, 3]],
|
| 45 |
+
[TensorProto.BOOL, TensorProto.BOOL],
|
| 46 |
+
[TensorProto.BOOL],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def test_Asinh(self) -> None:
|
| 50 |
+
self._test_op_upgrade("Asinh", 9)
|
| 51 |
+
|
| 52 |
+
def test_Atanh(self) -> None:
|
| 53 |
+
self._test_op_upgrade("Atanh", 9)
|
| 54 |
+
|
| 55 |
+
def test_Add_1(self) -> None:
|
| 56 |
+
self._test_op_upgrade(
|
| 57 |
+
"Add", 1, [[3, 4, 5], [3, 4, 5]], attrs={"consumed_inputs": [0]}
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def test_Add_2(self) -> None:
|
| 61 |
+
self._test_op_upgrade(
|
| 62 |
+
"Add", 1, [[3, 4, 5], [5]], attrs={"consumed_inputs": [0], "broadcast": 1}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def test_Add_3(self) -> None:
|
| 66 |
+
self._test_op_upgrade(
|
| 67 |
+
"Add",
|
| 68 |
+
1,
|
| 69 |
+
[[3, 4, 5], [3]],
|
| 70 |
+
attrs={"consumed_inputs": [0], "broadcast": 1, "axis": 0},
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+

    def test_AffineGrid_2D(self) -> None:
        N, _, H, W = 2, 3, 5, 6
        self._test_op_upgrade("AffineGrid", 20, [[N, 2, 3], [4]], [[N, H, W, 2]])

    def test_AffineGrid_3D(self) -> None:
        N, _, D, H, W = 2, 3, 4, 5, 6
        self._test_op_upgrade("AffineGrid", 20, [[N, 3, 4], [5]], [[N, D, H, W, 3]])

    def test_ArgMax_1(self) -> None:
        self._test_op_upgrade(
            "ArgMax", 7, [[2, 3, 4]], [[1, 3, 4]], output_types=[TensorProto.INT64]
        )

    def test_ArgMax_2(self) -> None:
        self._test_op_upgrade(
            "ArgMax",
            7,
            [[2, 3, 4]],
            [[2, 1, 4]],
            output_types=[TensorProto.INT64],
            attrs={"axis": 1},
        )

    def test_ArgMin_1(self) -> None:
        self._test_op_upgrade(
            "ArgMin", 7, [[2, 3, 4]], [[1, 3, 4]], output_types=[TensorProto.INT64]
        )

    def test_ArgMin_2(self) -> None:
        self._test_op_upgrade(
            "ArgMin",
            7,
            [[2, 3, 4]],
            [[2, 1, 4]],
            output_types=[TensorProto.INT64],
            attrs={"axis": 1},
        )

    def test_Asin(self) -> None:
        self._test_op_upgrade("Asin", 7)

    def test_Atan(self) -> None:
        self._test_op_upgrade("Atan", 7)

    def test_Attention_1(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
            [[2, 3, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_Attention_2(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 9, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
            [[2, 9, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_Attention_3(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 10]],
            [[2, 3, 4, 10]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_Attention_4(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
            [[2, 3, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
            attrs={"scale": 2.0},
        )

    def test_Attention_5(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
            [[2, 3, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
            attrs={"is_causal": 1},
        )

    def test_Attention_6(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8], [4, 6]],
            [[2, 3, 4, 8]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
            ],
            [TensorProto.FLOAT],
        )

    def test_Attention_7(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8], [4, 6]],
            [[2, 3, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.BOOL],
            [TensorProto.FLOAT],
        )

    def test_Attention_8(self) -> None:
        self._test_op_upgrade(
            "Attention",
            23,
            [[2, 3, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]],
            [[2, 3, 4, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
            attrs={"softcap": 2.0},
        )
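
    # Shape convention assumed in the Attention tests above (opset 23):
    # Q is [batch, heads, q_seq, head_dim] and K/V are
    # [batch, heads, kv_seq, head_dim]; the optional fourth input is an
    # attention mask over [q_seq, kv_seq] (here [4, 6]), either additive
    # float (test 6) or boolean (test 7).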

    def test_AveragePool(self) -> None:
        self._test_op_upgrade(
            "AveragePool",
            1,
            [[1, 1, 5, 5]],
            [[1, 1, 4, 4]],
            attrs={"kernel_shape": [2, 2]},
        )

    def test_Bernoulli(self) -> None:
        self._test_op_upgrade("Bernoulli", 15)

    def test_BitShift(self) -> None:
        self._test_op_upgrade(
            "BitShift",
            11,
            [[2, 3], [2, 3]],
            [[2, 3]],
            [TensorProto.UINT8, TensorProto.UINT8],
            [TensorProto.UINT8],
            attrs={"direction": "RIGHT"},
        )

    def test_BatchNormalization_1(self) -> None:
        self._test_op_upgrade(
            "BatchNormalization",
            1,
            [[1, 3], [3], [3], [3], [3]],
            [[1, 3]],
            attrs={"consumed_inputs": [1, 1], "is_test": 1, "spatial": 1},
        )

    def test_BatchNormalization_2(self) -> None:
        self._test_op_upgrade(
            "BatchNormalization",
            14,
            [[1, 3], [3], [3], [3], [3]],
            [[1, 3], [3], [3]],
            attrs={"training_mode": 1},
        )
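
    # BatchNormalization_1 exercises dropping the opset-1
    # "consumed_inputs"/"is_test"/"spatial" attributes;
    # BatchNormalization_2 runs opset 14 in training mode, which adds the
    # running_mean and running_var outputs, hence the two extra [3] shapes.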

    def test_Cast(self) -> None:
        # 5->6 adapter is missing
        self._test_op_upgrade(
            "Cast", 6, [[2, 3]], [[2, 3]], [TensorProto.INT64], attrs={"to": 1}
        )

    def test_Ceil(self) -> None:
        self._test_op_upgrade("Ceil", 1, attrs={"consumed_inputs": [0]})

    def test_Celu(self) -> None:
        self._test_op_upgrade("Celu", 12)

    def test_Clip_1(self) -> None:
        self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0]})

    def test_Clip_2(self) -> None:
        self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0], "min": -1.4})

    def test_Clip_3(self) -> None:
        self._test_op_upgrade("Clip", 1, attrs={"consumed_inputs": [0], "max": 2.6})

    def test_Clip_4(self) -> None:
        self._test_op_upgrade(
            "Clip", 1, attrs={"consumed_inputs": [0], "min": -1.4, "max": 2.6}
        )
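
    # From opset 11 on, Clip takes min/max as optional inputs instead of
    # attributes, so upgrading the opset-1 nodes above means materializing
    # the legacy "min"/"max" values (-1.4 and 2.6) as constant inputs.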

    def test_Col2Im_4D(self) -> None:
        self._test_op_upgrade("Col2Im", 18, [[1, 5, 5], [2], [2]], [[1, 1, 5, 5]])

    def test_Col2Im_5D(self) -> None:
        self._test_op_upgrade("Col2Im", 18, [[1, 10, 12], [3], [3]], [[1, 2, 3, 4, 5]])

    def test_Compress(self) -> None:
        self._test_op_upgrade(
            "Compress",
            9,
            [[6, 7], [3]],
            [[3]],
            [TensorProto.FLOAT, TensorProto.BOOL],
            [TensorProto.FLOAT],
        )

    def test_Concat(self) -> None:
        self._test_op_upgrade("Concat", 1, [[2, 3], [2, 4]], [[2, 7]])

    def test_constant(self) -> None:
        value = helper.make_tensor(
            "Value",
            TensorProto.FLOAT,
            dims=[3, 4, 5],
            vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
            raw=True,
        )
        self._test_op_upgrade("Constant", 1, [], attrs={"value": value})

    def test_ConstantOfShape(self) -> None:
        self._test_op_upgrade("ConstantOfShape", 9, [[3]])

    def test_Conv_1(self) -> None:
        self._test_op_upgrade(
            "Conv", 1, [[1, 3, 5, 5], [4, 3, 2, 2], [4]], [[1, 4, 4, 4]]
        )

    def test_Conv_2(self) -> None:
        self._test_op_upgrade(
            "Conv", 1, [[1, 3, 5, 5], [4, 3, 2, 2], [4]], [[1, 4, 4, 4]]
        )

    def test_Conv_3(self) -> None:
        self._test_op_upgrade(
            "Conv",
            1,
            [[1, 3, 5, 5], [4, 1, 2, 2], [4]],
            [[1, 4, 3, 7]],
            attrs={
                "dilations": [1, 2],
                "group": 3,
                "pads": [0, 1, 2, 3],
                "strides": [2, 1],
            },
        )
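
    # Sanity check for test_Conv_3's expected output [1, 4, 3, 7]: each
    # spatial dim follows floor((in + pads - dilated_kernel) / stride) + 1,
    # so H = floor((5 + 0 + 2 - 2) / 2) + 1 = 3 and
    # W = floor((5 + 1 + 3 - 3) / 1) + 1 = 7.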

    def test_ConvInteger(self) -> None:
        self._test_op_upgrade(
            "ConvInteger",
            10,
            [[1, 3, 5, 5], [4, 3, 2, 2], [4]],
            [[1, 4, 4, 4]],
            [TensorProto.UINT8, TensorProto.UINT8, TensorProto.UINT8],
            [TensorProto.INT32],
        )

    def test_ConvTranspose(self) -> None:
        self._test_op_upgrade(
            "ConvTranspose", 1, [[1, 1, 5, 5], [1, 1, 3, 3]], [[1, 1, 7, 7]]
        )

    def test_DeformConv(self) -> None:
        self._test_op_upgrade(
            "DeformConv",
            19,
            [[1, 1, 3, 3], [1, 1, 2, 2], [1, 8, 2, 2]],
            [[1, 1, 2, 2]],
        )

    def test_Cosh(self) -> None:
        self._test_op_upgrade("Cosh", 9)

    def test_Cos(self) -> None:
        self._test_op_upgrade("Cos", 7)

    def test_Cumsum(self) -> None:
        self._test_op_upgrade(
            "CumSum",
            11,
            [[3, 4, 5], []],
            [[3, 4, 5]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_DepthToSpace(self) -> None:
        self._test_op_upgrade(
            "DepthToSpace", 1, [[1, 8, 3, 3]], [[1, 2, 6, 6]], attrs={"blocksize": 2}
        )

    def test_DequantizeLinear(self) -> None:
        self._test_op_upgrade(
            "DequantizeLinear",
            10,
            [[2, 3], [], []],
            [[2, 3]],
            [TensorProto.INT8, TensorProto.FLOAT, TensorProto.INT8],
        )

    def test_Det_1(self) -> None:
        self._test_op_upgrade("Det", 11, [[3, 5, 5]], [[3]])

    def test_Det_2(self) -> None:
        self._test_op_upgrade("Det", 11, [[5, 5]], [[]])

    def test_DynamicQuantizeLinear(self) -> None:
        self._test_op_upgrade(
            "DynamicQuantizeLinear",
            11,
            [[3, 4, 5]],
            [[3, 4, 5], [], []],
            output_types=[TensorProto.UINT8, TensorProto.FLOAT, TensorProto.UINT8],
        )

    def test_Div(self) -> None:
        self._test_op_upgrade(
            "Div", 1, [[3, 4, 5], [3, 1, 5]], attrs={"consumed_inputs": [0]}
        )

    def test_Dropout(self) -> None:
        self._test_op_upgrade(
            "Dropout", 1, attrs={"consumed_inputs": [0], "is_test": 1}
        )

    def test_Einsum_1(self) -> None:
        self._test_op_upgrade(
            "Einsum",
            12,
            [[3, 4, 5], [3, 5, 6]],
            [[3, 4, 6]],
            attrs={"equation": "bij, bjk -> bik"},
        )

    def test_Einsum_2(self) -> None:
        self._test_op_upgrade(
            "Einsum", 12, [[4, 5]], [[5, 4]], attrs={"equation": "ij->ji"}
        )

    def test_Elu(self) -> None:
        self._test_op_upgrade("Elu", 1, attrs={"consumed_inputs": [0]})

    def test_Equal(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "Equal", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
        )

    def test_Erf(self) -> None:
        self._test_op_upgrade("Erf", 9)

    def test_Exp(self) -> None:
        self._test_op_upgrade("Exp", 1, attrs={"consumed_inputs": [0]})

    def test_Expand(self) -> None:
        shape = helper.make_tensor(
            "b", TensorProto.INT64, dims=[4], vals=np.array([5, 2, 6, 4])
        )
        self._test_op_upgrade(
            "Expand",
            8,
            [[2, 1, 4], [4]],
            [[5, 2, 6, 4]],
            [TensorProto.FLOAT, TensorProto.INT64],
            initializer=[shape],
        )
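
    # Expand's shape input is provided as an initializer so that the checker
    # and shape inference can see the concrete broadcast target [5, 2, 6, 4];
    # a purely dynamic second input would leave the output shape unknown.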

    def test_EyeLike(self) -> None:
        self._test_op_upgrade("EyeLike", 9, [[4, 5]], [[4, 5]])

    def test_Flatten(self) -> None:
        self._test_op_upgrade("Flatten", 1, [[3, 4, 5]], [[3, 20]], attrs={"axis": 1})

    def test_Floor(self) -> None:
        self._test_op_upgrade("Floor", 1, attrs={"consumed_inputs": [0]})

    def test_Gather(self) -> None:
        self._test_op_upgrade(
            "Gather",
            1,
            [[3, 4, 5], [6, 7]],
            [[6, 7, 4, 5]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_GatherElements(self) -> None:
        self._test_op_upgrade(
            "GatherElements",
            11,
            [[3, 4, 5], [6, 7]],
            [[6, 7]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_GatherND(self) -> None:
        self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]])

    def test_Gelu_approximate_tanh(self) -> None:
        self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"})

    def test_Gelu(self) -> None:
        self._test_op_upgrade("Gelu", 20)

    def test_Gemm(self) -> None:
        self._test_op_upgrade("Gemm", 1, [[5, 4], [4, 3], [3]], [[5, 3]])

    def test_GlobalAveragePool(self) -> None:
        self._test_op_upgrade("GlobalAveragePool", 1, [[1, 3, 10, 10]], [[1, 3, 1, 1]])

    def test_GlobalMaxPool(self) -> None:
        self._test_op_upgrade("GlobalMaxPool", 1, [[1, 3, 10, 10]], [[1, 3, 1, 1]])

    def test_GlobalLpPool(self) -> None:
        # 1->2 adapter is missing
        self._test_op_upgrade("GlobalLpPool", 2, [[1, 3, 10, 10]], [[1, 3, 1, 1]])

    def test_Greater(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "Greater", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
        )

    def test_GreaterOrEqual(self) -> None:
        self._test_op_upgrade(
            "GreaterOrEqual",
            12,
            [[2, 3], [2, 3]],
            [[2, 3]],
            output_types=[TensorProto.BOOL],
        )

    def test_GridSample(self) -> None:
        self._test_op_upgrade(
            "GridSample",
            16,
            [[1, 1, 3, 3], [1, 3, 3, 2]],
            [[1, 1, 3, 3]],
            input_types=[TensorProto.FLOAT, TensorProto.FLOAT],
            output_types=[TensorProto.FLOAT],
            attrs={"mode": "nearest", "padding_mode": "border", "align_corners": 1},
        )

    def test_GRU_1(self) -> None:
        # 2->3, 6->7 adapters are missing
        self._test_op_upgrade(
            "GRU",
            7,
            [[5, 3, 4], [1, 18, 4], [1, 18, 4]],
            [[5, 1, 3, 6], [1, 3, 6]],
            attrs={"hidden_size": 6},
        )

    def test_GRU_2(self) -> None:
        # 2->3, 6->7 adapters are missing
        self._test_op_upgrade(
            "GRU",
            7,
            [[5, 3, 4], [2, 18, 4], [2, 18, 4]],
            [[5, 2, 3, 6], [2, 3, 6]],
            attrs={"hidden_size": 6, "direction": "bidirectional"},
        )

    def test_GRU_3(self) -> None:
        # 2->3, 6->7 adapters are missing
        self._test_op_upgrade(
            "GRU",
            7,
            [[5, 3, 4], [1, 18, 4], [1, 18, 4], [1, 24], [5], [1, 5, 6]],
            [[5, 1, 3, 6], [1, 3, 6]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
            ],
            attrs={"hidden_size": 6},
        )
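
    # GRU packs its three gates (z, r, h) into the weight tensors, so with
    # hidden_size=6 the stacked dimension is 3 * 6 = 18 ([1, 18, 4] per
    # direction); the bidirectional variants double num_directions to 2.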

    def test_HardSigmoid(self) -> None:
        self._test_op_upgrade("HardSigmoid", 1, attrs={"consumed_inputs": [0]})

    def test_HardSwish(self) -> None:
        self._test_op_upgrade("HardSwish", 14)

    def test_Hardmax(self) -> None:
        self._test_op_upgrade("Hardmax", 1)

    def test_Identity(self) -> None:
        self._test_op_upgrade("Identity", 1)

    def test_If(self) -> None:
        sub_output = [
            helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 4, 5])
        ]
        then_tensor = helper.make_tensor(
            "Value",
            TensorProto.FLOAT,
            dims=[3, 4, 5],
            vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
            raw=True,
        )
        then_node = helper.make_node("Constant", [], ["out"], value=then_tensor)
        then_graph = helper.make_graph([then_node], "then_graph", [], sub_output, [])
        else_tensor = helper.make_tensor(
            "Value",
            TensorProto.FLOAT,
            dims=[3, 4, 5],
            vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
            raw=True,
        )
        else_node = helper.make_node("Constant", [], ["out"], value=else_tensor)
        else_graph = helper.make_graph([else_node], "else_graph", [], sub_output, [])
        self._test_op_upgrade(
            "If",
            1,
            [[0]],
            [[3, 4, 5]],
            [TensorProto.BOOL],
            attrs={"then_branch": then_graph, "else_branch": else_graph},
        )
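
    # Both If branches are complete GraphProto attributes with no inputs and
    # a single "out" output; the converter has to recurse into such subgraph
    # attributes and upgrade their nodes along with the outer graph.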

    def test_ImageDecoder(self) -> None:
        self._test_op_upgrade(
            "ImageDecoder",
            20,
            [[None]],
            [[None, None, 3]],
            input_types=[TensorProto.UINT8],
            output_types=[TensorProto.UINT8],
        )

    def test_InstanceNormalization(self) -> None:
        self._test_op_upgrade(
            "InstanceNormalization",
            1,
            [[1, 3], [3], [3]],
            [[1, 3]],
            attrs={"consumed_inputs": [0]},
        )

    def test_IsInf(self) -> None:
        self._test_op_upgrade(
            "IsInf", 10, [[2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
        )

    def test_IsNaN(self) -> None:
        self._test_op_upgrade(
            "IsNaN", 9, [[2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
        )

    def test_LeakyRelu(self) -> None:
        self._test_op_upgrade("LeakyRelu", 1, attrs={"consumed_inputs": [0]})

    def test_Less(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "Less", 7, [[2, 3], [2, 3]], [[2, 3]], output_types=[TensorProto.BOOL]
        )

    def test_LessOrEqual(self) -> None:
        self._test_op_upgrade(
            "LessOrEqual",
            12,
            [[2, 3], [2, 3]],
            [[2, 3]],
            output_types=[TensorProto.BOOL],
        )

    def test_Log(self) -> None:
        self._test_op_upgrade("Log", 1, attrs={"consumed_inputs": [0]})

    def test_LogSoftmax(self) -> None:
        self._test_op_upgrade("LogSoftmax", 1)

    def test_Loop_1(self) -> None:
        iter_count = onnx.helper.make_tensor_value_info(
            "iter_count", onnx.TensorProto.INT64, []
        )
        cond_in = onnx.helper.make_tensor_value_info(
            "cond_in", onnx.TensorProto.BOOL, []
        )
        x_in = onnx.helper.make_tensor_value_info("x_in", onnx.TensorProto.FLOAT, [1])
        cond_out = onnx.helper.make_tensor_value_info(
            "cond_out", onnx.TensorProto.BOOL, []
        )
        x_out = onnx.helper.make_tensor_value_info("x_out", onnx.TensorProto.FLOAT, [1])
        x_scan = onnx.helper.make_tensor_value_info(
            "x_scan", onnx.TensorProto.FLOAT, [1]
        )
        const = onnx.helper.make_node(
            "Constant",
            inputs=[],
            outputs=["one"],
            value=onnx.helper.make_tensor(
                name="value",
                data_type=onnx.TensorProto.FLOAT,
                dims=[1],
                vals=np.array([1]).astype(np.float32).astype(float),
            ),
        )
        add = onnx.helper.make_node("Add", inputs=["x_in", "one"], outputs=["x_out"])
        id_1 = onnx.helper.make_node("Identity", inputs=["x_out"], outputs=["x_scan"])
        id_2 = onnx.helper.make_node(
            "Identity", inputs=["cond_in"], outputs=["cond_out"]
        )
        loop_body = onnx.helper.make_graph(
            [const, add, id_1, id_2],
            "loop_body",
            [iter_count, cond_in, x_in],
            [cond_out, x_out, x_scan],
        )
        self._test_op_upgrade(
            "Loop",
            1,
            [[], "", [1]],
            [[1], [5, 1]],
            [TensorProto.INT64, TensorProto.BOOL, TensorProto.FLOAT],
            attrs={"body": loop_body},
        )

    def test_Loop_2(self) -> None:
        iter_count = onnx.helper.make_tensor_value_info(
            "iter_count", onnx.TensorProto.INT64, []
        )
        cond_in = onnx.helper.make_tensor_value_info(
            "cond_in", onnx.TensorProto.BOOL, []
        )
        x_in = onnx.helper.make_tensor_value_info(
            "x_in", onnx.TensorProto.FLOAT, [2, 1]
        )
        cond_out = onnx.helper.make_tensor_value_info(
            "cond_out", onnx.TensorProto.BOOL, []
        )
        x_out = onnx.helper.make_tensor_value_info(
            "x_out", onnx.TensorProto.FLOAT, [2, 1]
        )
        squeeze = onnx.helper.make_node(
            "Squeeze", inputs=["x_in"], outputs=["squeeze_out"], axes=[1]
        )
        unsqueeze = onnx.helper.make_node(
            "Unsqueeze", inputs=["squeeze_out"], outputs=["x_out"], axes=[1]
        )
        identity = onnx.helper.make_node(
            "Identity", inputs=["cond_in"], outputs=["cond_out"]
        )
        loop_body = onnx.helper.make_graph(
            [squeeze, unsqueeze, identity],
            "loop_body",
            [iter_count, cond_in, x_in],
            [cond_out, x_out],
        )
        self._test_op_upgrade(
            "Loop",
            12,
            [[], "", [2, 1]],
            [[2, 1]],
            [TensorProto.INT64, TensorProto.BOOL, TensorProto.FLOAT],
            attrs={"body": loop_body},
        )
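
    # Loop's first two inputs are the optional max trip count (scalar INT64)
    # and stop condition; "" marks the omitted condition. In test_Loop_1 the
    # body both carries x and emits x_scan, so the op returns the final
    # carried value [1] plus a scan output stacked over 5 iterations [5, 1].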

    def test_LpNormalization(self) -> None:
        self._test_op_upgrade("LpNormalization", 1)

    def test_LpPool(self) -> None:
        # 1->2 adapter is missing
        self._test_op_upgrade(
            "LpPool", 2, [[1, 1, 5, 5]], [[1, 1, 4, 4]], attrs={"kernel_shape": [2, 2]}
        )

    def test_LRN_1(self) -> None:
        self._test_op_upgrade("LRN", 1, attrs={"size": 3})

    def test_LRN_2(self) -> None:
        self._test_op_upgrade(
            "LRN", 1, [[2, 3, 4, 5]], [[2, 3, 4, 5]], attrs={"size": 3}
        )

    def test_LSTM_1(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "LSTM",
            7,
            [[5, 3, 4], [1, 24, 4], [1, 24, 4]],
            [[5, 1, 3, 6], [1, 3, 6], [1, 3, 6]],
            attrs={"hidden_size": 6},
        )

    def test_LSTM_2(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "LSTM",
            7,
            [[5, 3, 4], [2, 24, 4], [2, 24, 4]],
            [[5, 2, 3, 6], [2, 3, 6], [2, 3, 6]],
            attrs={"hidden_size": 6, "direction": "bidirectional"},
        )

    def test_LSTM_3(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "LSTM",
            7,
            [
                [5, 3, 4],
                [1, 24, 4],
                [1, 24, 4],
                [1, 48],
                [5],
                [1, 5, 6],
                [1, 5, 6],
                [1, 18],
            ],
            [[5, 1, 3, 6], [1, 3, 6], [1, 3, 6]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
            ],
            attrs={"hidden_size": 6},
        )
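
    # LSTM packs four gates (i, o, f, c), so hidden_size=6 gives a stacked
    # weight dimension of 4 * 6 = 24 ([1, 24, 4] per direction) and a fused
    # W/R bias of 2 * 4 * 6 = 48 ([1, 48]); the INT64 input in test_LSTM_3
    # is sequence_lens.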

    def test_MatMul_1(self) -> None:
        self._test_op_upgrade("MatMul", 1, [[2, 3], [3, 4]], [[2, 4]])

    def test_MatMul_2(self) -> None:
        self._test_op_upgrade("MatMul", 1, [[5, 2, 3], [5, 3, 4]], [[5, 2, 4]])

    def test_MatMulInteger_1(self) -> None:
        self._test_op_upgrade(
            "MatMulInteger",
            10,
            [[2, 3], [3, 4]],
            [[2, 4]],
            [TensorProto.INT8, TensorProto.INT8],
            [TensorProto.INT32],
        )

    def test_MatMulInteger_2(self) -> None:
        self._test_op_upgrade(
            "MatMulInteger",
            10,
            [[2, 3], [3, 4], [], []],
            [[2, 4]],
            [TensorProto.INT8, TensorProto.INT8, TensorProto.INT8, TensorProto.INT8],
            [TensorProto.INT32],
        )

    def test_MatMulInteger_3(self) -> None:
        self._test_op_upgrade(
            "MatMulInteger",
            10,
            [[2, 3], [3, 4], [2], [4]],
            [[2, 4]],
            [TensorProto.INT8, TensorProto.INT8, TensorProto.INT8, TensorProto.INT8],
            [TensorProto.INT32],
        )

    def test_Max(self) -> None:
        self._test_op_upgrade(
            "Max",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_MaxPool_1(self) -> None:
        self._test_op_upgrade(
            "MaxPool", 1, [[1, 1, 5, 5]], [[1, 1, 4, 4]], attrs={"kernel_shape": [2, 2]}
        )

    def test_MaxPool_2(self) -> None:
        self._test_op_upgrade(
            "MaxPool",
            8,
            [[1, 1, 5, 5]],
            [[1, 1, 4, 4], [1, 1, 4, 4]],
            output_types=[TensorProto.FLOAT, TensorProto.INT64],
            attrs={"kernel_shape": [2, 2]},
        )

    def test_MaxRoiPool(self) -> None:
        self._test_op_upgrade(
            "MaxRoiPool",
            1,
            [[2, 3, 20, 20], [4, 5]],
            [[4, 3, 3, 3]],
            attrs={"pooled_shape": [3, 3]},
        )

    def test_MaxUnpool(self) -> None:
        self._test_op_upgrade(
            "MaxUnpool",
            9,
            [[1, 1, 5, 5], [1, 1, 5, 5]],
            [[1, 1, 6, 6]],
            [TensorProto.FLOAT, TensorProto.INT64],
            attrs={"kernel_shape": [2, 2]},
        )

    def test_Mean(self) -> None:
        self._test_op_upgrade(
            "Mean",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_MeanVarianceNormalization(self) -> None:
        self._test_op_upgrade("MeanVarianceNormalization", 9, attrs={"axes": [1, 2]})

    def test_Min(self) -> None:
        self._test_op_upgrade(
            "Min",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_Mish(self) -> None:
        self._test_op_upgrade("Mish", 18)

    def test_Mod_1(self) -> None:
        self._test_op_upgrade("Mod", 10, [[2, 3], [2, 3]], [[2, 3]])

    def test_Mod_2(self) -> None:
        self._test_op_upgrade("Mod", 10, [[2, 3], [2, 3]], [[2, 3]], attrs={"fmod": 1})

    def test_Mul(self) -> None:
        self._test_op_upgrade(
            "Mul",
            1,
            [[2, 3, 4], [2, 1, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_Multinomial(self) -> None:
        self._test_op_upgrade(
            "Multinomial",
            7,
            [[3, 5]],
            [[3, 7]],
            output_types=[TensorProto.INT32],
            attrs={"sample_size": 7},
        )

    def test_Neg(self) -> None:
        self._test_op_upgrade("Neg", 1, attrs={"consumed_inputs": [0]})

    def test_NegativeLogLikelihoodLoss_1(self) -> None:
        self._test_op_upgrade(
            "NegativeLogLikelihoodLoss",
            12,
            [[3, 4, 5], [3, 5]],
            [[]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_NegativeLogLikelihoodLoss_2(self) -> None:
        self._test_op_upgrade(
            "NegativeLogLikelihoodLoss",
            12,
            [[3, 4, 5], [3, 5], [4]],
            [[]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
        )

    def test_NonMaxSuppression(self) -> None:
        self._test_op_upgrade(
            "NonMaxSuppression",
            10,
            [[2, 3, 4], [3, 5, 6]],
            [[2, 3]],
            output_types=[TensorProto.INT64],
        )

    def test_NonZero(self) -> None:
        self._test_op_upgrade(
            "NonZero", 9, [[3, 3]], [[2, 4]], output_types=[TensorProto.INT64]
        )

    def test_Not(self) -> None:
        self._test_op_upgrade(
            "Not", 1, [[2, 3]], [[2, 3]], [TensorProto.BOOL], [TensorProto.BOOL]
        )

    def test_OneHot(self) -> None:
        self._test_op_upgrade("OneHot", 9, [[3, 4, 5], [], [2]], [[3, 4, 5, 6]])

    def test_Or(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "Or",
            7,
            [[2, 3], [2, 3]],
            [[2, 3]],
            [TensorProto.BOOL, TensorProto.BOOL],
            [TensorProto.BOOL],
        )

    def test_Pad(self) -> None:
        # 1->2 adapter is missing
        self._test_op_upgrade(
            "Pad", 2, [[3, 4]], [[5, 8]], attrs={"pads": [1, 2, 1, 2], "value": 1.5}
        )
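
    # Opset-2 Pad still keeps "pads"/"value" as attributes (they become
    # inputs at opset 11); [3, 4] padded by [1, 2, 1, 2] (per-axis begins
    # then ends) yields [3 + 1 + 1, 4 + 2 + 2] = [5, 8].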

    def test_Pow(self) -> None:
        self._test_op_upgrade("Pow", 1, [[2, 3, 4], [2, 3, 4]], [[2, 3, 4]])

    def test_PRelu(self) -> None:
        self._test_op_upgrade(
            "PRelu",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_QLinearConv(self) -> None:
        self._test_op_upgrade(
            "QLinearConv",
            10,
            [[1, 3, 5, 5], [], [], [4, 3, 2, 2], [], [], [], []],
            [[1, 4, 4, 4]],
        )

    def test_QLinearMatMul(self) -> None:
        self._test_op_upgrade(
            "QLinearMatMul", 10, [[2, 3], [], [], [3, 4], [], [], [], []], [[2, 4]]
        )

    def test_QuantizeLinear(self) -> None:
        self._test_op_upgrade(
            "QuantizeLinear",
            10,
            [[3, 4, 5], [], []],
            [[3, 4, 5]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.UINT8],
            [TensorProto.UINT8],
        )
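
    # In the QLinear* and QuantizeLinear tests, the [] entries are scalar
    # inputs (per-tensor scales and zero points); only the data and weight
    # tensors carry real shapes.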

    def test_RandomNormal(self) -> None:
        self._test_op_upgrade(
            "RandomNormal", 1, [], [[3, 4, 5]], attrs={"shape": [3, 4, 5]}
        )

    def test_RandomNormalLike(self) -> None:
        like = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[3, 4, 5],
            vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
            raw=True,
        )
        self._test_op_upgrade(
            "RandomNormalLike", 1, [[3, 4, 5]], [[3, 4, 5]], initializer=[like]
        )

    def test_RandomUniform(self) -> None:
        self._test_op_upgrade(
            "RandomUniform", 1, [], [[3, 4, 5]], attrs={"shape": [3, 4, 5]}
        )

    def test_RandomUniformLike(self) -> None:
        like = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[3, 4, 5],
            vals=np.random.rand(3, 4, 5).astype(np.float32).tobytes(),
            raw=True,
        )
        self._test_op_upgrade(
            "RandomUniformLike", 1, [[3, 4, 5]], [[3, 4, 5]], initializer=[like]
        )

    def test_Range(self) -> None:
        start = helper.make_tensor("a", TensorProto.FLOAT, dims=[], vals=np.array([0]))
        end = helper.make_tensor("b", TensorProto.FLOAT, dims=[], vals=np.array([12]))
        step = helper.make_tensor("c", TensorProto.FLOAT, dims=[], vals=np.array([2]))
        self._test_op_upgrade(
            "Range", 11, [[], [], []], [[6]], initializer=[start, end, step]
        )

    def test_Reciprocal(self) -> None:
        self._test_op_upgrade("Reciprocal", 1, attrs={"consumed_inputs": [0]})

    def test_ReduceL1(self) -> None:
        self._test_op_upgrade("ReduceL1", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceL2(self) -> None:
        self._test_op_upgrade("ReduceL2", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceLogSum(self) -> None:
        self._test_op_upgrade("ReduceLogSum", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceLogSumExp(self) -> None:
        self._test_op_upgrade("ReduceLogSumExp", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceMean(self) -> None:
        self._test_op_upgrade("ReduceMean", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceMax(self) -> None:
        self._test_op_upgrade("ReduceMax", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceMin(self) -> None:
        self._test_op_upgrade("ReduceMin", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceProd(self) -> None:
        self._test_op_upgrade("ReduceProd", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceSum(self) -> None:
        self._test_op_upgrade("ReduceSum", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_ReduceSumSquare(self) -> None:
        self._test_op_upgrade("ReduceSumSquare", 1, [[3, 4, 5]], [[1, 1, 1]])

    def test_Relu(self) -> None:
        self._test_op_upgrade("Relu", 1, attrs={"consumed_inputs": [0]})

    def test_Reshape(self) -> None:
        self._test_op_upgrade(
            "Reshape",
            1,
            [[3, 4, 5]],
            [[3, 10, 2]],
            attrs={"consumed_inputs": [0], "shape": [3, 10, 2]},
        )

    def test_Resize(self) -> None:
        self._test_op_upgrade("Resize", 10, [[3, 4, 5], [3]], [[3, 8, 15]])

    def test_ReverseSequence(self) -> None:
        self._test_op_upgrade(
            "ReverseSequence",
            10,
            [[3, 4, 5], [4]],
            [[3, 4, 5]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_RNN_1(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "RNN",
            7,
            [[5, 3, 4], [1, 6, 4], [1, 6, 4]],
            [[5, 1, 3, 6], [1, 3, 6]],
            attrs={"hidden_size": 6},
        )

    def test_RNN_2(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "RNN",
            7,
            [[5, 3, 4], [2, 6, 4], [2, 6, 4]],
            [[5, 2, 3, 6], [2, 3, 6]],
            attrs={"hidden_size": 6, "direction": "bidirectional"},
        )

    def test_RNN_3(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "RNN",
            7,
            [[5, 3, 4], [1, 6, 4], [1, 6, 4], [1, 12], [5], [1, 5, 6]],
            [[5, 1, 3, 6], [1, 3, 6]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
            ],
            attrs={"hidden_size": 6},
        )
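
    # Vanilla RNN has a single gate, so the stacked weight dimension equals
    # hidden_size: [1, 6, 4] per direction, or [2, 6, 4] bidirectional.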

    def test_RoiAlign_1(self) -> None:
        self._test_op_upgrade(
            "RoiAlign",
            10,
            [[2, 3, 20, 20], [10, 4], [10]],
            [[10, 3, 1, 1]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_RoiAlign_2(self) -> None:
        self._test_op_upgrade(
            "RoiAlign",
            16,
            [[2, 3, 20, 20], [10, 4], [10]],
            [[10, 3, 1, 1]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
            attrs={"coordinate_transformation_mode": "half_pixel"},
        )

    def test_RotaryEmbedding_1(self) -> None:
        self._test_op_upgrade(
            "RotaryEmbedding",
            23,
            [[2, 4, 3, 8], [2, 3, 4], [2, 3, 4]],
            [[2, 4, 3, 8]],
            [TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_RotaryEmbedding_2(self) -> None:
        self._test_op_upgrade(
            "RotaryEmbedding",
            23,
            [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
            [[2, 4, 3, 8]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            [TensorProto.FLOAT],
        )

    def test_RotaryEmbedding_3(self) -> None:
        self._test_op_upgrade(
            "RotaryEmbedding",
            23,
            [[2, 3, 32], [50, 4], [50, 4], [2, 3]],
            [[2, 3, 32]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            [TensorProto.FLOAT],
            attrs={"num_heads": 4},
        )

    def test_RotaryEmbedding_4(self) -> None:
        self._test_op_upgrade(
            "RotaryEmbedding",
            23,
            [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
            [[2, 4, 3, 8]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            [TensorProto.FLOAT],
            attrs={"interleaved": 1},
        )

    def test_RotaryEmbedding_5(self) -> None:
        self._test_op_upgrade(
            "RotaryEmbedding",
            23,
            [[2, 4, 3, 8], [50, 4], [50, 4], [2, 3]],
            [[2, 4, 3, 8]],
            [
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            [TensorProto.FLOAT],
            attrs={"rotary_embedding_dim": 4},
        )
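
    # RotaryEmbedding variants covered above (all opset 23, where the op
    # appears to have been introduced): cos/sin caches passed per position
    # ([2, 3, 4]) or gathered from [50, 4] tables via INT64 position_ids;
    # 3D input with an explicit num_heads; interleaved layout; and a
    # reduced rotary_embedding_dim.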

    def test_Round(self) -> None:
        self._test_op_upgrade("Round", 11)

    def test_RMSNormalization(self) -> None:
        self._test_op_upgrade(
            "RMSNormalization",
            23,
            [[2, 3, 4, 5], [4, 5]],
            [[2, 3, 4, 5]],
            input_types=[TensorProto.FLOAT, TensorProto.FLOAT],
            output_types=[TensorProto.FLOAT],
            attrs={"axis": 2},
        )

    def test_Scatter(self) -> None:
        self._test_op_upgrade(
            "Scatter",
            9,
            [[2, 3], [1, 2], [1, 2]],
            [[2, 3]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_ScatterElements_1(self) -> None:
        self._test_op_upgrade(
            "ScatterElements",
            11,
            [[2, 3], [1, 2], [1, 2]],
            [[2, 3]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_ScatterElements_2(self) -> None:
        self._test_op_upgrade(
            "ScatterElements",
            16,
            [[2, 3], [1, 2], [1, 2]],
            [[2, 3]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
            [TensorProto.FLOAT],
            attrs={"reduction": "add"},
        )

    def test_ScatterND_1(self) -> None:
        self._test_op_upgrade(
            "ScatterND",
            11,
            [[2, 3], [1, 2], [1, 2]],
            [[2, 3]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
            [TensorProto.FLOAT],
        )

    def test_ScatterND_2(self) -> None:
        self._test_op_upgrade(
            "ScatterND",
            16,
            [[2, 3], [1, 2], [1, 2]],
            [[2, 3]],
            [TensorProto.FLOAT, TensorProto.INT64, TensorProto.FLOAT],
            [TensorProto.FLOAT],
            attrs={"reduction": "mul"},
        )

    def test_Scan(self) -> None:
        sum_in = onnx.helper.make_tensor_value_info(
            "sum_in", onnx.TensorProto.FLOAT, [2]
        )
        next_in = onnx.helper.make_tensor_value_info(
            "next_in", onnx.TensorProto.FLOAT, [2]
        )
        sum_out = onnx.helper.make_tensor_value_info(
            "sum_out", onnx.TensorProto.FLOAT, [2]
        )
        scan_out = onnx.helper.make_tensor_value_info(
            "scan_out", onnx.TensorProto.FLOAT, [2]
        )
        add_node = onnx.helper.make_node(
            "Add", inputs=["sum_in", "next_in"], outputs=["sum_out"]
        )
        id_node = onnx.helper.make_node(
            "Identity", inputs=["sum_out"], outputs=["scan_out"]
        )
        body = onnx.helper.make_graph(
            [add_node, id_node], "scan_body", [sum_in, next_in], [sum_out, scan_out]
        )
        self._test_op_upgrade(
            "Scan",
            8,
            ["", [1, 2], [1, 3, 2]],
            [[1, 2], [1, 3, 2]],
            attrs={"body": body, "num_scan_inputs": 1},
        )
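
    # Scan at opset 8 still has the optional sequence_lens input (the
    # leading "") and a batch dimension on states and scan tensors, hence
    # the extra leading 1 in [1, 2] / [1, 3, 2]; opset 9 dropped both, so
    # the upgrade has to rewrite the node's signature.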

    def test_Selu(self) -> None:
        self._test_op_upgrade("Selu", 1, attrs={"consumed_inputs": [0]})

    def test_Shape(self) -> None:
        self._test_op_upgrade(
            "Shape", 1, [[3, 4, 5]], [[3]], output_types=[TensorProto.INT64]
        )

    def test_Shrink(self) -> None:
        self._test_op_upgrade("Shrink", 9)

    def test_Sigmoid(self) -> None:
        self._test_op_upgrade("Sigmoid", 1, attrs={"consumed_inputs": [0]})

    def test_Sign(self) -> None:
        self._test_op_upgrade("Sign", 9)

    def test_Sinh(self) -> None:
        self._test_op_upgrade("Sinh", 9)

    def test_Sin(self) -> None:
        self._test_op_upgrade("Sin", 7)

    def test_Size(self) -> None:
        self._test_op_upgrade(
            "Size", 1, [[3, 4, 5]], [[]], output_types=[TensorProto.INT64]
        )

    def test_Slice(self) -> None:
        self._test_op_upgrade(
            "Slice",
            1,
            [[3, 4, 5]],
            [[3, 2, 2]],
            attrs={"axes": [1, 2], "starts": [0, 1], "ends": [2, 3]},
        )

    def test_Softmax_0(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": 0})

    def test_Softmax_1(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": 1})

    def test_Softmax_2(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": 2})

    def test_Softmax_3(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": -1})

    def test_Softmax_4(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": -2})

    def test_Softmax_5(self) -> None:
        self._test_op_upgrade("Softmax", 1, attrs={"axis": -3})
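
    # Softmax changed meaning at opset 13, from "flatten to 2-D at axis and
    # normalize" to a true single-axis reduction, so the six axis variants
    # above exercise how the adapter handles positive and negative axes on
    # the default input shape.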

    def test_Softplus(self) -> None:
        self._test_op_upgrade("Softplus", 1)

    def test_Softsign(self) -> None:
        self._test_op_upgrade("Softsign", 1)

    def test_SoftmaxCrossEntropyLoss(self) -> None:
        self._test_op_upgrade(
            "SoftmaxCrossEntropyLoss",
            12,
            [[3, 4, 5, 6], [3, 6]],
            [[]],
            [TensorProto.FLOAT, TensorProto.INT64],
        )

    def test_SpaceToDepth(self) -> None:
        self._test_op_upgrade(
            "SpaceToDepth", 1, [[1, 3, 8, 8]], [[1, 12, 4, 4]], attrs={"blocksize": 2}
        )

    def test_Split(self) -> None:
        # 1->2 adapter is missing
        self._test_op_upgrade(
            "Split",
            2,
            [[3, 4, 7]],
            [[3, 4, 2], [3, 4, 1], [3, 4, 4]],
            attrs={"axis": 2, "split": [2, 1, 4]},
        )
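
    # The split sizes [2, 1, 4] sum to the axis-2 extent 7 of [3, 4, 7]; at
    # opset 13 the "split" attribute becomes an optional input, which is the
    # main rewrite the upgrade path performs here.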

    def test_Sqrt(self) -> None:
        self._test_op_upgrade("Sqrt", 1, attrs={"consumed_inputs": [0]})

    def test_Squeeze(self) -> None:
        self._test_op_upgrade("Squeeze", 1, [[2, 1, 3, 4, 1]], [[2, 3, 4]])

    def test_StringNormalizer(self) -> None:
        self._test_op_upgrade(
            "StringNormalizer",
            10,
            [[1, 3]],
            [[1, 3]],
            [TensorProto.STRING],
            [TensorProto.STRING],
            attrs={"case_change_action": "LOWER"},
        )

    def test_Sub(self) -> None:
        self._test_op_upgrade(
            "Sub",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_Sum(self) -> None:
        self._test_op_upgrade(
            "Sum",
            1,
            [[2, 3, 4], [2, 3, 4]],
            [[2, 3, 4]],
            attrs={"consumed_inputs": [0]},
        )

    def test_Swish(self) -> None:
        self._test_op_upgrade("Swish", 24, attrs={"alpha": 0.2})

    def test_Tanh(self) -> None:
        self._test_op_upgrade("Tanh", 1, attrs={"consumed_inputs": [0]})

    def test_Tan(self) -> None:
        self._test_op_upgrade("Tan", 7)

    def test_TfIdfVectorizer(self) -> None:
        self._test_op_upgrade(
            "TfIdfVectorizer",
            9,
            [[3]],
            [[5]],
            attrs={
                "max_gram_length": 3,
                "max_skip_count": 1,
                "min_gram_length": 2,
                "mode": "TFIDF",
                "ngram_counts": [0, 20],
                "ngram_indexes": [3, 4],
            },
        )

    def test_ThresholdedRelu(self) -> None:
        self._test_op_upgrade("ThresholdedRelu", 10)

    def test_Tile(self) -> None:
        # 5->6 adapter is missing
        repeats = helper.make_tensor(
            "b", TensorProto.INT64, dims=[3], vals=np.array([1, 2, 3])
        )
        self._test_op_upgrade(
            "Tile",
            6,
            [[3, 4, 5], [3]],
            [[3, 8, 15]],
            [TensorProto.FLOAT, TensorProto.INT64],
            initializer=[repeats],
        )

    def test_TopK(self) -> None:
        self._test_op_upgrade(
            "TopK",
            1,
            [[3, 4, 5]],
            [[3, 4, 2], [3, 4, 2]],
            output_types=[TensorProto.FLOAT, TensorProto.INT64],
            attrs={"k": 2},
        )

    def test_Transpose(self) -> None:
        self._test_op_upgrade(
            "Transpose",
            1,
            [[1, 2, 5, 3, 7]],
            [[1, 7, 5, 2, 3]],
            attrs={"perm": [0, 4, 2, 1, 3]},
        )

    def test_Trilu(self) -> None:
        self._test_op_upgrade("Trilu", 14)

    def test_Unique_1(self) -> None:
        self._test_op_upgrade("Unique", 11, [[3, 4, 5]], [[None]])

    def test_Unique_2(self) -> None:
        self._test_op_upgrade(
            "Unique", 11, [[3, 4, 5]], [[3, None, 5]], attrs={"axis": 1}
        )

    def test_Unsqueeze(self) -> None:
        self._test_op_upgrade(
            "Unsqueeze", 1, [[3, 4, 5]], [[3, 4, 1, 5]], attrs={"axes": [2]}
        )

    def test_Upsample(self) -> None:
        self._test_op_upgrade(
            "Upsample",
            1,
            [[1, 3, 4, 5]],
            [[1, 3, 6, 10]],
            attrs={"width_scale": 2.0, "height_scale": 1.5},
        )
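
    # Opset-1 Upsample used separate "width_scale"/"height_scale" attributes
    # (1.5 and 2.0 map [1, 3, 4, 5] to [1, 3, 6, 10]); later opsets moved to
    # a "scales" attribute, then a scales input, before Resize superseded
    # the op at opset 10.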

    def test_Where(self) -> None:
        self._test_op_upgrade(
            "Where",
            9,
            [[2, 3], [2, 3], [2, 3]],
            [[2, 3]],
            [TensorProto.BOOL, TensorProto.FLOAT, TensorProto.FLOAT],
        )

    def test_Xor(self) -> None:
        # 6->7 adapter is missing
        self._test_op_upgrade(
            "Xor",
            7,
            [[2, 3], [2, 3]],
            [[2, 3]],
            [TensorProto.BOOL, TensorProto.BOOL],
            [TensorProto.BOOL],
        )

    def test_CastLike(self) -> None:
        self._test_op_upgrade(
            "CastLike",
            15,
            [[2, 3, 4], [2, 1, 4]],
            [[2, 3, 4]],
            input_types=[TensorProto.FLOAT, TensorProto.FLOAT16],
            output_types=[TensorProto.FLOAT16],
        )

    def test_LayerNormalization(self) -> None:
        self._test_op_upgrade(
            "LayerNormalization",
            17,
            [[2, 3, 4, 5], [4, 5], [4, 5]],
            [[2, 3, 4, 5]],
            input_types=[TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.FLOAT],
            output_types=[TensorProto.FLOAT],
            attrs={"axis": 2},
        )

    def _test_window_function(self, window_function_name: str) -> None:
        size = helper.make_tensor("a", TensorProto.INT64, dims=[], vals=np.array([10]))
        self._test_op_upgrade(
            window_function_name,
            17,
            [[]],
            [[10]],
            [TensorProto.INT64],
            initializer=[size],
        )
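
    # The window-function ops take a scalar INT64 size, supplied as an
    # initializer so shape inference can resolve the [10] output length.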

    def test_BlackmanWindow(self) -> None:
        self._test_window_function("BlackmanWindow")

    def test_HannWindow(self) -> None:
        self._test_window_function("HannWindow")

    def test_HammingWindow(self) -> None:
        self._test_window_function("HammingWindow")

    def test_DFT(self) -> None:
        self._test_op_upgrade("DFT", 17, [[2, 16, 1], []], [[2, 16, 2]])
        self._test_op_upgrade("DFT", 17, [[2, 16, 2], []], [[2, 16, 2]])
        self._test_op_upgrade(
            "DFT", 17, [[2, 16, 1], []], [[2, 9, 2]], attrs={"onesided": 1}
        )
        self._test_op_upgrade(
            "DFT", 17, [[2, 16, 2], []], [[2, 9, 2]], attrs={"onesided": 1}
        )
        self._test_op_upgrade(
            "DFT", 17, [[2, 16, 1], []], [[2, 16, 2]], attrs={"inverse": 1}
        )
        self._test_op_upgrade(
            "DFT", 17, [[2, 16, 2], []], [[2, 16, 2]], attrs={"inverse": 1}
        )
        self._test_op_upgrade(
            "DFT", 17, [[2, 16, 2], []], [[2, 16, 2]], attrs={"inverse": 1, "axis": 0}
        )
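
    # DFT shape convention: a trailing dim of 1 marks real input and 2 marks
    # complex (interleaved re/im); with onesided=1 only 16 // 2 + 1 = 9
    # non-redundant frequency bins remain, hence [2, 9, 2].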

    def _test_short_time_fourier_transform(self, operator_name: str) -> None:
        # Real
        signal = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[2, 64],
            vals=np.random.rand(2, 64).astype(np.float32),
        )
        frame_step = helper.make_tensor(
            "b", TensorProto.INT64, dims=[1], vals=np.array([8])
        )
        window = helper.make_tensor(
            "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
        )
        self._test_op_upgrade(
            operator_name,
            17,
            [[2, 64], [1], [16]],
            [[2, 7, 16, 2]],
            [
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            initializer=[signal, frame_step, window],
        )

        # Real Onesided
        signal = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[2, 64],
            vals=np.random.rand(2, 64).astype(np.float32),
        )
        frame_step = helper.make_tensor(
            "b", TensorProto.INT64, dims=[1], vals=np.array([8])
        )
        window = helper.make_tensor(
            "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
        )
        self._test_op_upgrade(
            operator_name,
            17,
            [[2, 64], [1], [16]],
            [[2, 7, 9, 2]],
            [
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            attrs={"onesided": 1},
            initializer=[signal, frame_step, window],
        )

        # Complex
        signal = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[2, 64, 2],
            vals=np.random.rand(2, 64, 2).astype(np.float32),
        )
        frame_step = helper.make_tensor(
            "b", TensorProto.INT64, dims=[1], vals=np.array([8])
        )
        window = helper.make_tensor(
            "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
        )
        self._test_op_upgrade(
            operator_name,
            17,
            [[2, 64, 2], [1], [16]],
            [[2, 7, 16, 2]],
            [
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            initializer=[signal, frame_step, window],
        )

        # Complex Onesided
        signal = helper.make_tensor(
            "a",
            TensorProto.FLOAT,
            dims=[2, 64, 2],
            vals=np.random.rand(2, 64, 2).astype(np.float32),
        )
        frame_step = helper.make_tensor(
            "b", TensorProto.INT64, dims=[1], vals=np.array([8])
        )
        window = helper.make_tensor(
            "c", TensorProto.FLOAT, dims=[16], vals=np.ones(16).astype(np.float32)
        )
        frame_length = helper.make_tensor(
            "e", TensorProto.INT64, dims=[1], vals=np.array([16])
        )
        self._test_op_upgrade(
            operator_name,
            17,
            [[2, 64, 2], [1], [16]],
            [[2, 7, 9, 2]],
            [
                TensorProto.FLOAT,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.INT64,
            ],
            attrs={"onesided": 1},
            initializer=[signal, frame_step, window, frame_length],
        )
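
    # Frame-count check for the STFT cases above: signal length 64, window
    # length 16, frame_step 8 gives (64 - 16) // 8 + 1 = 7 frames; onesided
    # output keeps 16 // 2 + 1 = 9 bins, so [2, 7, 16, 2] vs. [2, 7, 9, 2].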

    def test_STFT(self) -> None:
        self._test_short_time_fourier_transform("STFT")

    def test_MelWeightMatrix(self) -> None:
        num_mel_bins = helper.make_tensor(
            "a", TensorProto.INT64, dims=[], vals=np.array([10])
        )
        dft_length = helper.make_tensor(
            "b", TensorProto.INT64, dims=[], vals=np.array([64])
        )
        sample_rate = helper.make_tensor(
            "c", TensorProto.INT64, dims=[], vals=np.array([0])
        )
        lower_edge_hertz = helper.make_tensor(
            "d", TensorProto.FLOAT, dims=[], vals=np.array([0])
        )
        upper_edge_hertz = helper.make_tensor(
            "e", TensorProto.FLOAT, dims=[], vals=np.array([1])
        )

        self._test_op_upgrade(
            "MelWeightMatrix",
            17,
            [[], [], [], [], []],
            [[33, 10]],
            [
                TensorProto.INT64,
                TensorProto.INT64,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
            ],
            initializer=[
                num_mel_bins,
                dft_length,
                sample_rate,
                lower_edge_hertz,
                upper_edge_hertz,
            ],
        )

        num_mel_bins = helper.make_tensor(
            "a", TensorProto.INT64, dims=[], vals=np.array([20])
        )
        dft_length = helper.make_tensor(
            "b", TensorProto.INT64, dims=[], vals=np.array([31])
        )
        sample_rate = helper.make_tensor(
            "c", TensorProto.INT64, dims=[], vals=np.array([0])
        )
        lower_edge_hertz = helper.make_tensor(
            "d", TensorProto.FLOAT, dims=[], vals=np.array([0])
        )
        upper_edge_hertz = helper.make_tensor(
            "e", TensorProto.FLOAT, dims=[], vals=np.array([1])
        )

        self._test_op_upgrade(
            "MelWeightMatrix",
            17,
            [[], [], [], [], []],
            [[16, 20]],
            [
                TensorProto.INT64,
                TensorProto.INT64,
                TensorProto.INT64,
                TensorProto.FLOAT,
                TensorProto.FLOAT,
            ],
            initializer=[
                num_mel_bins,
                dft_length,
                sample_rate,
                lower_edge_hertz,
                upper_edge_hertz,
            ],
        )
def test_CenterCropPad(self) -> None:
|
| 1835 |
+
input_ = helper.make_tensor(
|
| 1836 |
+
"input",
|
| 1837 |
+
TensorProto.FLOAT,
|
| 1838 |
+
dims=[2, 4],
|
| 1839 |
+
vals=np.array([1, 2, 3, 4, 5, 6, 7, 8]),
|
| 1840 |
+
)
|
| 1841 |
+
shape = helper.make_tensor(
|
| 1842 |
+
"shape", TensorProto.INT64, dims=[2], vals=np.array([3, 3])
|
| 1843 |
+
)
|
| 1844 |
+
self._test_op_upgrade(
|
| 1845 |
+
"CenterCropPad",
|
| 1846 |
+
18,
|
| 1847 |
+
[[], []],
|
| 1848 |
+
[[3, 3]],
|
| 1849 |
+
[TensorProto.FLOAT, TensorProto.INT64],
|
| 1850 |
+
initializer=[input_, shape],
|
| 1851 |
+
)
|
| 1852 |
+
|
| 1853 |
+
def test_BitwiseNot(self) -> None:
|
| 1854 |
+
self._test_op_upgrade(
|
| 1855 |
+
"BitwiseNot",
|
| 1856 |
+
18,
|
| 1857 |
+
[[2, 3]],
|
| 1858 |
+
[[2, 3]],
|
| 1859 |
+
[TensorProto.INT32],
|
| 1860 |
+
[TensorProto.INT32],
|
| 1861 |
+
)
|
| 1862 |
+
|
| 1863 |
+
def test_BitwiseAnd(self) -> None:
|
| 1864 |
+
self._test_op_upgrade(
|
| 1865 |
+
"BitwiseAnd",
|
| 1866 |
+
18,
|
| 1867 |
+
[[2, 3], [2, 3]],
|
| 1868 |
+
[[2, 3]],
|
| 1869 |
+
[TensorProto.INT16, TensorProto.INT16],
|
| 1870 |
+
[TensorProto.INT16],
|
| 1871 |
+
)
|
| 1872 |
+
|
| 1873 |
+
def test_BitwiseOr(self) -> None:
|
| 1874 |
+
self._test_op_upgrade(
|
| 1875 |
+
"BitwiseOr",
|
| 1876 |
+
18,
|
| 1877 |
+
[[2, 3], [2, 3]],
|
| 1878 |
+
[[2, 3]],
|
| 1879 |
+
[TensorProto.INT16, TensorProto.INT16],
|
| 1880 |
+
[TensorProto.INT16],
|
| 1881 |
+
)
|
| 1882 |
+
|
| 1883 |
+
def test_BitwiseXor(self) -> None:
|
| 1884 |
+
self._test_op_upgrade(
|
| 1885 |
+
"BitwiseXor",
|
| 1886 |
+
18,
|
| 1887 |
+
[[2, 3], [2, 3]],
|
| 1888 |
+
[[2, 3]],
|
| 1889 |
+
[TensorProto.INT16, TensorProto.INT16],
|
| 1890 |
+
[TensorProto.INT16],
|
| 1891 |
+
)
|
| 1892 |
+
|
| 1893 |
+
def test_GroupNormalization(self) -> None:
|
| 1894 |
+
self._test_op_upgrade(
|
| 1895 |
+
"GroupNormalization",
|
| 1896 |
+
21,
|
| 1897 |
+
[[3, 4, 2, 2], [4], [4]],
|
| 1898 |
+
[[3, 4, 2, 2]],
|
| 1899 |
+
attrs={"epsilon": 1e-5, "num_groups": 2},
|
| 1900 |
+
)
|
| 1901 |
+
|
| 1902 |
+
def test_StringConcat(self) -> None:
|
| 1903 |
+
self._test_op_upgrade(
|
| 1904 |
+
"StringConcat",
|
| 1905 |
+
20,
|
| 1906 |
+
[[2, 3], [2, 3]],
|
| 1907 |
+
[[2, 3]],
|
| 1908 |
+
)
|
| 1909 |
+
|
| 1910 |
+
def test_RegexFullMatch(self) -> None:
|
| 1911 |
+
self._test_op_upgrade(
|
| 1912 |
+
"RegexFullMatch",
|
| 1913 |
+
20,
|
| 1914 |
+
[[2, 3]],
|
| 1915 |
+
[[2, 3]],
|
| 1916 |
+
[TensorProto.STRING],
|
| 1917 |
+
[TensorProto.BOOL],
|
| 1918 |
+
)
|
| 1919 |
+
|
| 1920 |
+
def test_TensorScatter(self) -> None:
|
| 1921 |
+
self._test_op_upgrade(
|
| 1922 |
+
"TensorScatter",
|
| 1923 |
+
24,
|
| 1924 |
+
[
|
| 1925 |
+
[2, 3, 4, 5],
|
| 1926 |
+
[2, 3, 2, 5],
|
| 1927 |
+
[
|
| 1928 |
+
2,
|
| 1929 |
+
],
|
| 1930 |
+
],
|
| 1931 |
+
[[2, 3, 4, 5]],
|
| 1932 |
+
[TensorProto.FLOAT, TensorProto.FLOAT, TensorProto.INT64],
|
| 1933 |
+
[TensorProto.FLOAT],
|
| 1934 |
+
)
|
| 1935 |
+
|
| 1936 |
+
def test_ops_tested(self) -> None:
|
| 1937 |
+
# NOTE: This test is order dependent and needs to run last in this class
|
| 1938 |
+
all_schemas = onnx.defs.get_all_schemas()
|
| 1939 |
+
all_op_names = {schema.name for schema in all_schemas if schema.domain == ""}
|
| 1940 |
+
excluded_ops = {
|
| 1941 |
+
# Sequence-based and Optional-based ops disabled because
|
| 1942 |
+
# the version converter doesn't play nicely with sequences
|
| 1943 |
+
"ConcatFromSequence",
|
| 1944 |
+
"SequenceAt",
|
| 1945 |
+
"SequenceConstruct",
|
| 1946 |
+
"SequenceEmpty",
|
| 1947 |
+
"SequenceErase",
|
| 1948 |
+
"SequenceInsert",
|
| 1949 |
+
"SequenceLength",
|
| 1950 |
+
"SequenceMap",
|
| 1951 |
+
"SplitToSequence",
|
| 1952 |
+
"Optional",
|
| 1953 |
+
"OptionalGetElement",
|
| 1954 |
+
"OptionalHasElement",
|
| 1955 |
+
"StringSplit",
|
| 1956 |
+
}
|
| 1957 |
+
expected_tested_ops = all_op_names - excluded_ops
|
| 1958 |
+
|
| 1959 |
+
untested_ops = expected_tested_ops - set(self.tested_ops)
|
| 1960 |
+
self.assertEqual(untested_ops, set())
|
| 1961 |
+
|
| 1962 |
+
|
| 1963 |
+
if __name__ == "__main__":
|
| 1964 |
+
unittest.main()
|
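
Note on what the `_test_op_upgrade` calls above exercise: each one builds a small model at the stated opset and runs it through the ONNX version converter. A minimal, hedged sketch of that round trip, using only the public `onnx` API (the node, names, and shapes below are illustrative, not taken from the test harness):

import onnx
from onnx import TensorProto, helper, version_converter

# Build a tiny opset-18 model with a single CenterCropPad node.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])
shape = helper.make_tensor_value_info("shape", TensorProto.INT64, [2])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 3])
node = helper.make_node("CenterCropPad", ["x", "shape"], ["y"])
graph = helper.make_graph([node], "crop_pad", [x, shape], [y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)

# Upgrade to the current opset; the converted model must still validate.
upgraded = version_converter.convert_version(model, onnx.defs.onnx_opset_version())
onnx.checker.check_model(upgraded)
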
pythonProject/.venv/Lib/site-packages/onnxscript/converter.py
ADDED
@@ -0,0 +1,1462 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import ast
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import onnx

import onnxscript
from onnxscript import irbuilder, onnx_types, sourceinfo, values
from onnxscript import type_annotation as ta
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation

logger = logging.getLogger("onnxscript")


# Python-to-IR converter:


def not_allowed(construct):
    return f"{construct} not supported."


class TranslationError(Exception):
    def __init__(self, *args: object) -> None:
        super().__init__(*args)


def warn(msg):
    logger.warning(msg)


def fail(msg) -> NoReturn:
    raise TranslationError(msg)


def fail_if(cond, msg):
    if cond:
        raise TranslationError(msg)


def ignore(cond, msg):
    if cond:
        warn(msg)


# map from python operators to ONNX ops
primop_map = {
    ast.Add: "Add",
    ast.And: "And",
    ast.BitAnd: "And",
    ast.BitOr: "Or",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Gt: "Greater",
    ast.GtE: "GreaterOrEqual",
    ast.Lt: "Less",
    ast.LtE: "LessOrEqual",
    ast.MatMult: "MatMul",
    ast.Mod: "Mod",
    ast.Mult: "Mul",
    ast.Not: "Not",
    ast.NotEq: "NotEqual",
    ast.Or: "Or",
    ast.Pow: "Pow",
    ast.Sub: "Sub",
    ast.USub: "Neg",
}
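
# Illustration (not part of the original file): how primop_map surfaces to
# users. Plain Python operators in an onnxscript function are lowered through
# this table when the function is converted to ONNX. A hedged sketch, assuming
# onnxscript's public API (`script`, `FLOAT`, the `opset18` module):
#
#     from onnxscript import script, FLOAT
#     from onnxscript import opset18 as op
#
#     @script(default_opset=op)
#     def axpy(a: FLOAT["N"], x: FLOAT["N"], y: FLOAT["N"]) -> FLOAT["N"]:
#         return a * x + y   # ast.Mult -> "Mul", ast.Add -> "Add"
#
#     model = axpy.to_model_proto()  # graph contains Mul and Add nodes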


class Variable:
    """Represents an ONNX variable.

    TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this
    converter.
    """

    def __init__(self, name: str, castable: bool = False):
        """Initialize the instance.

        Args:
            name: Name of the ONNX variable
            castable: Whether this variable is castable to a desired target type.
                Used for ONNX variables representing constants created from python values
                like 0 or 1 or 0.5 which are treated as polymorphic values castable to other
                types as needed.
        """
        self.name = name
        self.is_castable = castable

    def __str__(self) -> str:
        return self.name


if TYPE_CHECKING:
    # The type-alias LocalSymValue represents the types of values that local names in a
    # script-function may be bound to during translation (ONNX IR values).
    # TODO(rama): Rationalize this and values.SymbolValue

    LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction]

    # The type-alias PyValue is used to represent the types of python values that may be used
    # in an ONNX Script function.
    # TODO(rama): Flesh out the set of valid types here. These include values such as
    # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for
    # use as value-parameters or attribute-parameters in an ONNX call (Node).

    PyValue = Any

    # The type-alias SymValue denotes values that an identifier may be bound to during
    # translation. A local name will be bound to a LocalSymValue, while a global name
    # will be bound to a PyValue.

    SymValue = Union[LocalSymValue, PyValue]

    # PreferredName is a type-alias used to represent the preferred name used in the generated
    # ONNX for a value returned by an expression. There is no guarantee that the specified
    # name will be used exactly. The converter will modify the name (with a suffix),
    # if necessary, to ensure that it is unique (to ensure ONNX's SSA requirement).

    PreferredName = str

    # The type-alias OnnxVarName indicates variable names used in the generated ONNX.
    OnnxVarName = str


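# Illustration (not part of the original file): why Variable.castable matters.
# In an expression like `x + 1`, the python literal 1 is emitted as a Constant
# and returned as Variable(..., castable=True), so the autocast machinery may
# insert a Cast to match x's dtype; a named tensor (values.Dynamic) is wrapped
# as a non-castable Variable and its dtype is taken as-is.
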
class Converter:
    """Main class to translate python code into ONNX operators.

    Args:
        ir_builder: convert AST node into ONNX structures, if None,
            class :class:`onnxscript.irbuilder.IRBuilder` is used

    The class uses logger `onnxscript`. Logging can be enabled with the following code:

    ::

        import logging
        logging.basicConfig(level=logging.DEBUG)

    Or if you need to enable only the logger used by this module:

    ::

        import logging
        logger = logging.getLogger('onnxscript')
        logger.setLevel(logging.DEBUG)
        console = logging.StreamHandler()
        logger.addHandler(console)
    """

    def __init__(
        self,
        ir_builder: Optional[irbuilder.IRBuilder] = None,
        opset: Optional[values.Opset] = None,
        global_names: Optional[dict[str, Any]] = None,
        source: Optional[str] = None,
        default_opset: Optional[values.Opset] = None,
    ):
        self.ir_builder = ir_builder or irbuilder.IRBuilder()
        self.source = source
        if global_names is not None:
            # We make a copy in case function eval modifies it.
            self.globals = global_names.copy()
        self.this_module = opset
        self.default_opset_ = default_opset

        # States initialized by `_init_function_translation`
        self._outer: List[irbuilder.IRFunction] = []
        self._current_fn: Optional[irbuilder.IRFunction] = None
        self._nextvar: int = 0
        self._used_vars: set[str] = set()
        self._locals: List[Dict[str, LocalSymValue]] = [{}]

    @property
    def default_opset(self) -> values.Opset:
        if self.default_opset_ is None:
            raise RuntimeError(
                "default_opset must be specified in script for functions "
                "that do not contain any use of an ONNX opset."
            )
        return self.default_opset_

    def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
        if opset.domain != "":
            return
        if self.default_opset_ is not None:
            if (
                opset.domain != self.default_opset_.domain
                or opset.version != self.default_opset_.version
            ):
                self.fail(
                    node, f"Two distinct opsets were used ({opset} != {self.default_opset_})."
                )
        else:
            self.default_opset_ = opset

    def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
        """Find the (first) ONNX opset used in the function, if any."""
        # Search for a Call expression of form "op.OpName(...)"
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Attribute):
                opset_expr = node.func.value
                if isinstance(opset_expr, ast.Name):
                    if opset_expr.id in self.globals:
                        opset = self.globals[opset_expr.id]
                        if isinstance(opset, values.Opset) and opset.domain == "":
                            return opset
        for child in ast.iter_child_nodes(node):
            res = self._find_onnx_opset(child)
            if res is not None:
                return res
        return None

    def _init_function_translation(self) -> None:
        """Initialize self for translating a new (top-level) function."""
        self._outer = []
        self._current_fn: Optional[irbuilder.IRFunction] = None
        self._nextvar = 0
        self._used_vars = set()
        self._locals: List[Dict[str, LocalSymValue]] = [{}]

    def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
        return sourceinfo.SourceInfo(node, self.source, self._current_fn.name)

    def _message(self, node: ast.AST, error_msg: str) -> str:
        """Constructs an error message containing source information about an ast node."""
        return self._source_of(node).msg(error_msg)

    def warn(self, node: ast.AST, error_msg: str) -> None:
        warn(self._message(node, error_msg))

    def fail(self, node: ast.AST, error_msg: str) -> NoReturn:
        fail(self._message(node, error_msg))

    # Name resolution and namescopes: This component handles the following aspects:
    # * Name-scopes are different in Python and the generated ONNX:
    #   - Control-flow blocks (a loop body or the then-or-else block of an if-stmt)
    #     form part of the same name-scope in Python, but will be mapped to a nested
    #     name-scope (as a sub-graph) in ONNX.
    # * Script-time name-value tracking: Name lookup during script-time returns
    #   statically-known information about the value the name will have at runtime.
    def _enter_scope(self, name: str, parent_node: ast.AST):
        """Enter a control-flow block (a loop body or if-then-else branch).

        The block is translated into a nested-scope in ONNX.
        """
        self._outer.insert(0, self._current_fn)
        self._current_fn = self.ir_builder.new_function(name)
        self._locals.insert(0, {})
        logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node))

    def _exit_scope(self) -> irbuilder.IRFunction:
        """Exit from a control-flow block (a loop body or if-then-else branch)."""
        logger.debug("Converter:_exit_scope:%d", len(self._locals))
        graph = self._current_fn
        self._current_fn = self._outer.pop(0)
        self._locals.pop(0)
        return graph

    def _current_scope(self) -> Dict[str, LocalSymValue]:
        return self._locals[0]

    def _bind(self, name: str, val: LocalSymValue) -> None:
        logger.debug("Converter:_bind:%s", name)
        self._locals[0][name] = val

    def _lookup(
        self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True
    ) -> SymValue:
        for scope in self._locals:
            if name in scope:
                return scope[name]
        if name in self.globals:
            return self.globals[name]
        if raise_exception:
            raise ValueError(info.msg(f"Unbound name: {name}."))
        return None

    def generate_unique_name(self, candidate: str = "tmp") -> str:
        # TODO(justinchuby): Can we reduce the O complexity of this function?
        r = candidate
        while r in self._used_vars:
            r = f"{candidate}_{self._nextvar}"
            self._nextvar = self._nextvar + 1
        self._used_vars.add(r)
        return r

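    # Illustration (not part of the original file): how generate_unique_name
    # keeps the generated ONNX value names SSA-unique. On a fresh converter,
    # repeated requests for the same candidate yield suffixed names:
    #   generate_unique_name("tmp")  -> "tmp"
    #   generate_unique_name("tmp")  -> "tmp_0"
    #   generate_unique_name("tmp")  -> "tmp_1"
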
    def _make_onnx_attr(
        self, attrname: str, attrval: Any, attrtype: int | None = None
    ) -> irbuilder.IRAttributeValue:
        def tensor_name_generator() -> str:
            """Return name to be used for tensor, if we need to create one."""
            return self.generate_unique_name(f"attr_{attrname}")

        proto = autocast.pyvalue_to_onnx_attribute(
            attrname, attrval, tensor_name_generator, attrtype
        )
        return self.ir_builder.make_attr(proto)

    def _to_onnx_attr_ref(
        self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
    ) -> irbuilder.IRAttributeValue:
        pytype = val.typeinfo
        attrtype = ta.pytype_to_attrtype(pytype)
        attrname = None
        if attrtype is onnx.AttributeProto.FLOAT:
            attrname = "value_float"
        elif attrtype is onnx.AttributeProto.INT:
            attrname = "value_int"
        elif attrtype is onnx.AttributeProto.STRING:
            attrname = "value_string"
        elif attrtype is onnx.AttributeProto.INTS:
            attrname = "value_ints"
        else:
            msg = f"Unsupported attribute type {pytype!r}."
            fail(info.msg(msg) if info else msg)
        return self.ir_builder.make_attr_ref(attrname, val.value, pytype)

    def _to_onnx_var(
        self,
        val: values.SymbolValue | PyValue,
        target: Optional[PreferredName] = None,
        info: Optional[sourceinfo.SourceInfo] = None,
    ) -> Variable:
        if isinstance(val, values.AttrRef):
            # promote attribute to value
            result = self.generate_unique_name(target or "tmp")
            attr = self._to_onnx_attr_ref(val, info)
            self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr])
            if ta.base_type_is_bool(val.typeinfo):
                # ONNX attributes use an int-encoding for bools, but ONNX tensor types
                # distinguish between int and bool. So we cast the int tensor to a bool tensor,
                # to promote a (python) bool attribute to an ONNX bool tensor.
                result_as_bool = self.generate_unique_name(result + "_as_bool")
                cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype)
                self.emit(
                    [result_as_bool],
                    values.Op(self.default_opset, "Cast"),
                    [result],
                    [cast_attr],
                )
                return Variable(result_as_bool, True)
            return Variable(result, True)
        if isinstance(val, values.Dynamic):
            return Variable(val.value)
        # Assume value is a python-value convertible to a tensor
        # TODO: check if value is convertible to a TensorProto, so that we can
        # produce a better error message otherwise
        return self._emit_const(val, target or "tmp", info)

    def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable:
        return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info)

    def emit(
        self,
        outputs: Sequence[str],
        callee: values.Op | str,
        inputs: Sequence[Optional[str]],
        attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None,
        sub_functions: Optional[dict[str, onnx.FunctionProto]] = None,
    ):
        if not isinstance(callee, values.Op):
            callee = values.Op(self.default_opset, callee)
        if attrs is None:
            attrs = []
        if sub_functions is None:
            sub_functions = {}
        self.ir_builder.add_stmt(
            self._current_fn,
            outputs,
            callee,
            inputs,
            attrs,
            sub_functions,
        )

    def _emit_const(
        self,
        pyvalue: PyValue,
        suggested_name: Optional[PreferredName],
        info: sourceinfo.SourceInfo,
    ) -> Variable:
        if suggested_name is None:
            if isinstance(pyvalue, int):
                if pyvalue >= 0:
                    suggested_name = f"int64_{pyvalue}"
                else:
                    suggested_name = f"int64_m{abs(pyvalue)}"
            elif (
                isinstance(pyvalue, list) and len(pyvalue) == 1 and isinstance(pyvalue[0], int)
            ):
                if pyvalue[0] >= 0:
                    suggested_name = f"int64_{pyvalue[0]}_1d"
                else:
                    suggested_name = f"int64_m{abs(pyvalue[0])}_1d"
            else:
                suggested_name = "const"
        ovar = self.generate_unique_name(suggested_name)
        try:
            tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue)
        except ValueError as e:
            fail(info.msg(str(e)))
        attr = self._make_onnx_attr("value", tensor)
        self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr])
        return Variable(ovar, True)

    def _emit_copy(self, original_var: str, suggested_name: str) -> str:
        """Emits a copy statement, using the ONNX Identity operator."""
        new_var = self.generate_unique_name(suggested_name)
        self.emit([new_var], "Identity", [original_var])
        return new_var

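    # Illustration (not part of the original file): names suggested by
    # _emit_const when no name is given. The python int 7 becomes a Constant
    # output named "int64_7", -7 becomes "int64_m7", and the one-element list
    # [7] becomes "int64_7_1d"; all are still routed through
    # generate_unique_name, so clashes get numeric suffixes.
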
    def _is_constant_expr(self, node: ast.AST) -> bool:
        if isinstance(node, ast.UnaryOp):
            return self._is_constant_expr(node.operand)
        if isinstance(
            node,
            (
                ast.Call,
                ast.BinOp,
                ast.UnaryOp,
                ast.Compare,
                ast.Attribute,
                ast.List,
                ast.Load,
                ast.Constant,
            ),
        ):
            return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node))
        return False

    def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
        """Evaluates a sub-expression that is assumed to represent a constant value.

        The expression can refer only to global names (inherited from the scope
        where the script is evaluated) and cannot refer to local names defined
        within the script. Further, these expressions are assumed to be constants.
        Thus, any subsequent mutation of any state/variables (used in computing
        this constant value) will potentially lead to unexpected behavior (such
        as divergence between eager-mode execution and evaluation of the ONNX
        function).
        """
        # TODO: assert (self._is_constant_expr(expr))
        # TODO: Refine types
        locals: dict[Any, Any] = {}
        expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)
        cpl = compile(expr, filename="<ast>", mode="eval")
        try:
            return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
        except NameError as e:
            raise NameError(
                self._message(
                    expr,
                    f"Missing names, globals contains {list(self.globals)!r}, "
                    f"locals {list(locals)!r}.",
                )
            ) from e

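    # Illustration (not part of the original file): what the two helpers above
    # treat as script-time constants. Literals such as `-5` or `[1, 2]` pass
    # _is_constant_expr (Constant, List, Load and UnaryOp nodes all qualify,
    # and their children are checked recursively); _eval_constant_expr then
    # computes the python value by compiling the sub-expression and evaluating
    # it against the script's globals. A bare name such as `x` fails the check
    # and is translated into ONNX nodes instead.
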
    def _translate_attr(
        self,
        attr_name: str,
        expr: ast.AST,
        attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None,
    ) -> Optional[irbuilder.IRAttributeValue]:
        """Translate an attribute-value specification of the form `attr_name=<expr>`
        in a call to an op. expr is an AST. The following cases are supported:
        * Expr evaluates to a script-time constant (a python-value) that can be mapped
          into an ONNX attribute value, or
        * Expr evaluates to None, in which case None is returned, or
        * Expr must be an attribute-reference, that is a name representing an
          attribute-parameter of a containing function.
        """

        if isinstance(expr, ast.Name):
            val = self._lookup(expr.id, self._source_of(expr))
            if isinstance(val, values.AttrRef):
                attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo)
                if attr_meta is not None and (attr_ref.type != attr_meta.type):
                    self.fail(
                        expr,
                        f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
                    )
                return attr_ref
            if isinstance(val, irbuilder.IRFunction):
                # Check that outer-scope variables referenced by function have same value
                # at function-definition site and use-as-attribute site, to avoid errors.
                for pyvar, previous in val.outer_scope_variables:
                    current = self._lookup(pyvar, self._source_of(expr))
                    if current.value != previous.value:
                        self.fail(
                            expr,
                            f"Outer scope variable '{pyvar}' referenced by function "
                            f"'{expr.id!r}' modified.",
                        )

                # Create GraphProto attribute
                val = val.to_graph_proto()
        else:
            val = self._eval_constant_expr(expr)

        # In ONNX, there is no way to explicitly specify a None value for an attribute.
        # Instead, the attribute must be omitted from the attribute list.
        # Hence, we do not create an attribute-proto if the value is None.
        # The caller is responsible for omitting such attribute-values from the list of attributes
        # in a NodeProto.
        if val is None:
            if attr_meta and attr_meta.required:
                self.fail(expr, f"Attribute '{attr_name}' is required.")
            return None
        attr_type = int(attr_meta.type) if attr_meta else None
        attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type)
        if attr_meta and (attr.type != attr_meta.type):
            self.fail(
                expr,
                f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'",
            )
        return attr

    def _translate_docstring(self, node: ast.Expr) -> None:
        if hasattr(node.value, "value"):
            # python 3.8+
            return self.ir_builder.add_docstring(self._current_fn, node.value.value)
        raise TypeError(
            f"Unexpected type {type(node)!r} for node. Unsupported version of Python."
        )

    def _translate_expr(
        self, node: ast.AST, target: Optional[PreferredName] = None
    ) -> Variable:
        """Expression-translation generates "IR statements/nodes" that compute the value of
        the expression into a target-variable, and returns the variable that is
        assigned this value.
        """
        if isinstance(node, ast.Call):
            r = self._translate_call_expr(node)
        elif isinstance(node, (ast.BinOp, ast.BitAnd, ast.BitOr)):
            r = self._translate_binary_op_expr(node)
        elif isinstance(node, ast.UnaryOp):
            r = self._translate_unary_op_expr(node)
        elif isinstance(node, ast.Compare):
            r = self._translate_compare_expr(node)
        elif isinstance(node, ast.Name):
            r = self._translate_name_expr(node)
        elif isinstance(node, ast.Subscript):
            r = self._translate_subscript_expr(node, target)
        elif self._is_constant_expr(node):
            r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node))
        else:
            raise ValueError(
                self._message(node, f"Unsupported expression type {type(node)!r}.")
            )
        if isinstance(r, Variable):
            return r
        callee, args, attrs = r
        target = "tmp" if target is None else target
        assert isinstance(target, str)
        result = self.generate_unique_name(target)
        self.emit([result], callee, args, attrs)
        return Variable(result)

    def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
        """Translation of an expression where "None" is permitted (e.g., for an optional argument).

        None is represented as a Constant in Python 3.9+.
        """
        if isinstance(node, ast.Constant) and (node.value is None):
            return None
        return self._translate_expr(node)

    def _translate_subscript_expr(
        self, node: ast.Subscript, target: Optional[PreferredName]
    ) -> Variable:
        """List of supported syntaxes is below.

        `A` is a tensor or an expression equivalent to a tensor.

        ::

            A[:, 1]
            A[:2, 0]
            A[:2, :1]
            A[2:0:-1]
            A[1:]
            A[:2]
            A[1:-1]
            A[1:2]
            A[-1]
            A[0]
            A[:0:-1]

        *i* is a tensor holding one integer.

        ::

            A[i]
            A[i+1:i+2]

        Fully supported for Python 3.9+.

        ::

            A[i:i+j, k]

        Not supported:

        ::

            A[::-1]
        """
        var = self._translate_expr(node.value)
        var_name = var.name
        if target is None:
            target = f"{var_name}_subscripted"
        target = self.generate_unique_name(target)
        indices = ast_utils.normalize_subscript_expr(node)
        info = self._source_of(node.slice)

        # Create cached int constants:
        # TODO: Do this at a graph-scope level.
        cached_int_consts = {}

        def const_1d(value, name: Optional[str] = None):
            nonlocal cached_int_consts
            if value not in cached_int_consts:
                cached_int_consts[value] = self._emit_const([value], name, info)
            return cached_int_consts[value]

        def one_1d():
            return const_1d(1)

        # Max/min 64-bit int values are used to represent default values for start/stop in Slice.
        maxint = (1 << 63) - 1
        minint = -(1 << 63)

        def translate_slice_component(
            node_arg, default_value: Optional[int] = None
        ) -> tuple[str, Optional[int]]:
            """Translate optional start/stop/step component of a Slice expression."""
            if node_arg is None:
                if default_value is None:
                    # TODO: Emit "Where(step > 0, pos_default, neg_default)"
                    raise RuntimeError(
                        "Default start/stop not supported when step direction is unknown."
                    )
                return const_1d(default_value), default_value

            if self._is_constant_expr(node_arg):
                cst = self._eval_constant_expr(node_arg)
                if isinstance(cst, int):
                    return const_1d(cst), cst
                else:
                    raise RuntimeError(f"Slice component type must be int, not {type(cst)}")
            else:
                name = self._translate_expr(node_arg).name
                reshaped = self.generate_unique_name(f"{name}_reshaped")
                self.emit(
                    [reshaped],
                    values.Op(self.default_opset, "Reshape"),
                    [name, one_1d().name],
                    [],
                )
                return reshaped, None

        def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
            """Translate slice-expression of the form from:to:step."""
            step_name, step = translate_slice_component(slice_expr.step, 1)
            if step is None:
                # Step direction unknown.
                # TODO: Handle default-values using runtime check on sign of step.
                lower_name, _ = translate_slice_component(slice_expr.lower, None)
                upper_name, _ = translate_slice_component(slice_expr.upper, None)
            elif step > 0:
                lower_name, _ = translate_slice_component(slice_expr.lower, 0)
                upper_name, _ = translate_slice_component(slice_expr.upper, maxint)
            else:
                lower_name, _ = translate_slice_component(slice_expr.lower, maxint)
                upper_name, _ = translate_slice_component(slice_expr.upper, minint)
            return (lower_name, upper_name, step_name)

        # An input like X[2] is translated into a Gather op.
        # An input like X[1:5:2] is translated into a Slice op.
        # An input like X[2, 3] is translated into a Slice + Squeeze (instead of two Gathers),
        # as an optimization.
        # An input like X[I, J] is translated into two Gathers (which is correct whatever the
        # rank of I and J)
        # To replace multiple Gathers by the Slice we need to know that the index-values
        # are scalars.

        # As the first step, we partition the index elements into four kinds: Slice (e.g., 1:5:2),
        # known-to-be-scalar (e.g., 2), other-tensor (e.g., I), skip/no-op (that is, just ":")
        sliced_indices: List[Tuple[int, ast.expr]] = []
        scalar_indices: List[Tuple[int, ast.expr]] = []
        non_scalar_indices: List[Tuple[int, ast.expr]] = []
        for axis, elt in enumerate(indices):
            if isinstance(elt, ast.Slice):
                # Add to sliced_indices, unless it is "::", which is a no-op.
                if not (elt.lower is None and elt.upper is None and elt.step is None):
                    sliced_indices.append((axis, elt))
            elif self._is_constant_expr(elt) and isinstance(
                self._eval_constant_expr(elt), int
            ):
                scalar_indices.append((axis, elt))
            else:
                non_scalar_indices.append((axis, elt))
        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Edge case: no index specified. E.g. A[:, :]
            self.emit([target], "Identity", [var_name])
            return Variable(target)
        if sliced_indices or len(scalar_indices) > 1:
            # We emit a Slice operation if we have any indices like 1:5:2 or if the number of
            # scalar indices (like 2) is more than 1.
            starts = []
            ends = []
            axes = []
            steps = []
            squeezed_axes = []
            for axis, expr in scalar_indices:
                # Treat a scalar index i as slice "i:i+1:1", but squeeze the axis finally.
                # TODO: handle negative i
                index = self._eval_constant_expr(expr)
                squeezed_axes.append(axis)
                kwargs = dict(
                    lineno=getattr(expr, "lineno", node.lineno),
                    col_offset=getattr(expr, "col_offset", node.col_offset),
                )
                element = ast.Slice(
                    ast.Constant(index, **kwargs),
                    ast.Constant(index + 1, **kwargs),
                    ast.Constant(1, **kwargs),
                )
                sliced_indices.append((axis, element))
            scalar_indices = []
            for axis, element in sliced_indices:
                axis_var = const_1d(axis)
                inputs = translate_slice(element)
                starts.append(inputs[0])
                ends.append(inputs[1])
                axes.append(axis_var.name)
                steps.append(inputs[2])

            if len(starts) > 1:
                axis_0_attr = self._make_onnx_attr("axis", 0)
                start_name = self.generate_unique_name(f"{var_name}_start")
                self.emit([start_name], "Concat", starts, [axis_0_attr])

                end_name = self.generate_unique_name(f"{var_name}_end")
                self.emit([end_name], "Concat", ends, [axis_0_attr])

                axes_name = self.generate_unique_name(f"{var_name}_axis")
                self.emit([axes_name], "Concat", axes, [axis_0_attr])

                steps_name = self.generate_unique_name(f"{var_name}_step")
                self.emit([steps_name], "Concat", steps, [axis_0_attr])
            else:
                start_name = starts[0]
                end_name = ends[0]
                axes_name = axes[0]
                steps_name = steps[0]

            if squeezed_axes:
                sliced_name = self.generate_unique_name(f"{var_name}_sliced")
                self.emit(
                    [sliced_name],
                    "Slice",
                    [var_name, start_name, end_name, axes_name, steps_name],
                )
                squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info)

                if non_scalar_indices:  # use temporary to store result of squeeze
                    result = self.generate_unique_name(f"{var_name}_squeezed")
                else:  # store squeezed result in final target
                    result = target

                self.emit([result], "Squeeze", [sliced_name, squeezed_axes])
            else:
                if non_scalar_indices:  # use temporary to store result of Slice
                    result = self.generate_unique_name(f"{var_name}_sliced")
                else:  # store result of Slice in final target
                    result = target
                slice_inputs = [var_name, start_name, end_name, axes_name, steps_name]
                self.emit([result], "Slice", slice_inputs)
        else:
            result = var_name
        non_scalar_indices.extend(scalar_indices)
        if non_scalar_indices:
            last_axis, _ = non_scalar_indices[-1]
        else:
            # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is False
            last_axis = None
        for axis, index_expr in non_scalar_indices:
            index_value = self._translate_expr(index_expr)
            axis_attr = self._make_onnx_attr("axis", axis)
            # use Gather to perform indexing
            # Assign gathered value to either temporary or final target
            if axis != last_axis:  # use temporary to store result of Gather
                gathered = self.generate_unique_name(f"{var_name}_axis_{axis}")
            else:  # store result of Gather in final target
                gathered = target
            self.emit([gathered], "Gather", [str(result), index_value], [axis_attr])
            result = gathered

        return Variable(result)

    def _translate_call_expr(self, node: ast.Call):
        """Translates a call-expression."""
        callee = self._translate_callee_expr(node.func)
        param_schemas = callee.param_schemas()
        # If the callee's schema is available, we use it to determine the inputs and attributes.
        # Otherwise, we map named arguments to attributes and positional arguments to inputs.
        if param_schemas:
            kwargs = {x.arg: x.value for x in node.keywords}
            args, attrs = param_manipulation.separate_input_attributes_from_arguments(
                param_schemas, node.args, kwargs, fill_defaults=False
            )
            args = [self._translate_opt_expr(x) for x in args]
            attrs = [
                self._translate_attr(x, y, callee.op_schema.attributes[x])
                for x, y in attrs.items()
            ]
        else:
            args = [self._translate_opt_expr(x) for x in node.args]
            attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords]
        args = autocast.static_cast_inputs(self, callee.op_schema, args)

        # In ONNX, there is no way to explicitly specify a None value for an attribute.
        # Instead, the attribute must be omitted from the attribute list.
        # Hence, we do not create an attribute-proto if the value is None.
        attrs = [attr for attr in attrs if attr is not None]
        return callee, args, attrs

    def _cast_like_binary_expression(self, op, left, right):
        schema = op.op_schema
        return autocast.static_cast_inputs(self, schema, (left, right))

    def _translate_binary_op_expr(self, node: ast.BinOp):
        op = type(node.op)
        if op not in primop_map:
            raise ValueError(self._message(node, f"Unsupported operator {op!r}."))

        attr = []
        if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right):
            # Specific case X % f where f is a float:
            # attribute fmod=1 is added in that case.
            cst = self._eval_constant_expr(node.right)
            if isinstance(cst, float):
                attr = [self._make_onnx_attr("fmod", 1)]

        op = values.Op(self.default_opset, primop_map[op])
        left, right = self._cast_like_binary_expression(
            op, self._translate_expr(node.left), self._translate_expr(node.right)
        )
        return op, [left, right], attr

    def _translate_unary_op_expr(self, node):
        op = type(node.op)
        if op not in primop_map:
            raise ValueError(self._message(node, f"Unsupported operator {op!r}."))
        if self._is_constant_expr(node.operand):
            # This function changes the constant node.operand
            # and returns it. The function calling this one
            # should intercept this call and replace node
            # by node.operand.
            # This mechanism does not handle something like `(-(-5))`.
            if hasattr(node.operand, "value"):
                # python 3.8+
                val = node.operand.value
            else:
                raise TypeError(
                    f"Unable to guess constant value from type {type(node.operand)!r} "
                    f"and attributes {dir(node.operand)!r}."
                )
            if op == ast.USub:
                cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset)
                return self._translate_expr(cst)
            if op == ast.UAdd:
                return self._translate_expr(node.operand)
        opname = primop_map[op]
        operand = self._translate_expr(node.operand)
        return values.Op(self.default_opset, opname), [operand], []

    def _translate_compare_expr(self, node):
        # TODO: handle multiple comparisons in one expression
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        op = type(node.ops[0])
        if op not in primop_map:
            raise ValueError(self._message(node, f"Unsupported operator {op!r}."))
        opname = primop_map[op]
        left = self._translate_expr(node.left)
        right = self._translate_expr(node.comparators[0])

        # NotEqual is not a standard ONNX op, and needs to be translated into
        # an Equal op/node followed by a Not op/node.
        op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal")
        left, right = self._cast_like_binary_expression(op, left, right)
        if opname == "NotEqual":
            tmp = self.generate_unique_name()
            self.emit([tmp], op, [left, right])
            not_op = values.Op(self.default_opset, "Not")
            return not_op, [tmp], []

        return op, [left, right], []

    def _translate_name_expr(self, node: ast.Name) -> Variable:
        return self._py_var_to_onnx_var(node.id, self._source_of(node))

    # pylint: disable=inconsistent-return-statements
    def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset:
        """Return an Opset"""
        if isinstance(node, ast.Name):
            val = self._lookup(node.id, self._source_of(node), raise_exception=False)
            if isinstance(val, values.Opset):
                return val
            self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.")
        elif isinstance(node, ast.Attribute):
            self.fail(node, "Nested module unimplemented.")  # TODO
        else:
            self.fail(node, "Invalid opset expression.")

    # pylint: enable=inconsistent-return-statements
    def _translate_callee_expr(self, node: ast.AST) -> values.Op:  # pylint: disable=R1710
        """Return an Op"""
        if isinstance(node, ast.Attribute):
            module = self._translate_opset_expr(node.value)
            self._set_default_opset(module, node)
            opname = node.attr
            if opname in module:
                return values.Op(module, node.attr)
            return values.Op(module, node.attr)
        if isinstance(node, ast.Name):
            function_name = node.id
            found = self._lookup(function_name, self._source_of(node), raise_exception=False)
            if isinstance(found, onnxscript.OnnxFunction):
                self._current_fn.add_called_function(found)
                return found
            if isinstance(found, values.Op):
                return found
            if not found:
                if function_name not in self.default_opset:
                    warn(
                        f"Unknown function name {function_name!r}. "
                        f"The ONNX graph may not work."
                    )
                return values.Op(self.default_opset, function_name)
        self.fail(node, "Invalid callee")

    def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
        """Statement translation: A single Python statement is mapped into a
        sequence of IR statements.
        """
        if isinstance(node, ast.Assign):
            return self._translate_assign_stmt(node)
        if isinstance(node, ast.AnnAssign):
            return self._translate_assign_stmt(node)
        if isinstance(node, ast.Return):
            if index_of_stmt is not None:
                return self._translate_return_stmt(node)
            raise ValueError(
                self._message(
                    node, "Return statements are not permitted inside control-flow statements."
                )
            )
        if isinstance(node, ast.If):
            return self._translate_if_stmt(node)
        if isinstance(node, (ast.For, ast.While)):
            return self._translate_loop_stmt(node)
        if ast_utils.is_doc_string(node):
            if index_of_stmt == 0:
                return self._translate_docstring(node)
            return None
        if isinstance(node, ast.FunctionDef):
            return self._translate_nested_function_def(node)
        if ast_utils.is_print_call(node):
            return None
        raise ValueError(self._message(node, f"Unsupported statement type '{type(node)!r}'."))

    def _translate_assign_stmt(self, stmt: Union[ast.Assign, ast.AnnAssign]) -> None:
        def assign(lhs: ast.AST, rhs: ast.AST) -> None:
            if isinstance(lhs, ast.Name):
                # Assignments of the form "x = SomeExpression"
                info = self._source_of(lhs)
                lhs = lhs.id
                t = self._translate_expr(rhs, lhs).name
                if isinstance(stmt, ast.AnnAssign):
                    typeinfo = self._eval_constant_expr(stmt.annotation)
                else:
                    typeinfo = None
                var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
                self._bind(lhs, var)
            elif isinstance(lhs, ast.Tuple):
                # Assignments of the form "x, y, z = op.SomeOp(...)"
                if not isinstance(rhs, ast.Call):
                    self.fail(
                        rhs,
                        f"RHS must be a Call expression for unpacking, found: '{type(rhs)!r}'",
                    )
                callee, inputs, attrs = self._translate_call_expr(rhs)

                def generate_onnx_name(x: ast.AST):
                    if not isinstance(x, ast.Name):
                        self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'")
                    onnx_name = self.generate_unique_name(x.id)
                    self._bind(
                        x.id,
                        values.Dynamic(
                            onnx_name, values.DynamicKind.Intermediate, self._source_of(x)
                        ),
                    )
                    return onnx_name

                outputs = [generate_onnx_name(x) for x in lhs.elts]
                self.emit(outputs, callee, inputs, attrs)
            else:
                self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'")

        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            targets = [stmt.target]
        if len(targets) != 1:
            # Assignments of the form "x = y = SomeExpression"
            self.fail(stmt, "Multi-assignment not supported.")
        lhs = targets[0]
        rhs = stmt.value
        if isinstance(rhs, ast.Tuple):
            # Assignments of the form "... = Expression1, Expression2"
            if not isinstance(lhs, ast.Tuple):
                # Assignments of the form "single_var = Expression1, Expression2".
                # We do not support tuple-typed variables.
                self.fail(lhs, f"Left term must be a tuple, not '{type(lhs)!r}'.")
            # Parallel assignments of the form "x, y = Expression1, Expression2"
            if len(lhs.elts) != len(rhs.elts):
                self.fail(
                    stmt, "Expected same number of elements on lhs and rhs of assignments."
                )
            for p, r in zip(lhs.elts, rhs.elts):
                assign(p, r)
        else:
            assign(lhs, rhs)

    def _translate_return_stmt(self, stmt: ast.Return) -> None:
        def check_num_outputs(n):
            if self.returntype is not None:
                if n != len(self.returntype):
                    raise SyntaxError(
                        self._message(
                            stmt,
                            f"Mismatch in number of return values and types. Keyword "
                            f"'return' cannot be used in a subgraph (test, loop). "
                            f"returntype is {self.returntype!r}, num_outputs={n!r}.",
                        )
                    )

        def ret(exp, i, suffix):
|
| 1064 |
+
preferred_name = f"return_val{suffix}"
|
| 1065 |
+
return_var = self._translate_expr(exp, preferred_name).name
|
| 1066 |
+
val = self._lookup(return_var, self._source_of(exp), False)
|
| 1067 |
+
if val and val.kind == values.DynamicKind.Input:
|
| 1068 |
+
# In ONNX, a graph-input cannot be an output of the graph.
|
| 1069 |
+
# We need to insert a copy.
|
| 1070 |
+
return_var = self._emit_copy(return_var, preferred_name)
|
| 1071 |
+
for prev_output in self._current_fn.outputs:
|
| 1072 |
+
if prev_output.name == return_var:
|
| 1073 |
+
# ONNX does not allow duplicate output names.
|
| 1074 |
+
return_var = self._emit_copy(return_var, f"{return_var}_copy")
|
| 1075 |
+
break
|
| 1076 |
+
if self.returntype is None:
|
| 1077 |
+
t = None
|
| 1078 |
+
else:
|
| 1079 |
+
t = self.returntype[i]
|
| 1080 |
+
self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt))
|
| 1081 |
+
return return_var
|
| 1082 |
+
|
| 1083 |
+
val = stmt.value
|
| 1084 |
+
assert val is not None, "Return statement without return-value not supported."
|
| 1085 |
+
if isinstance(val, ast.Tuple):
|
| 1086 |
+
check_num_outputs(len(val.elts))
|
| 1087 |
+
return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)]
|
| 1088 |
+
check_num_outputs(1)
|
| 1089 |
+
return ret(val, 0, "")
|
| 1090 |
+
|
| 1091 |
+
def _translate_if_stmt(self, stmt: ast.If) -> None:
|
| 1092 |
+
if hasattr(stmt, "live_out"):
|
| 1093 |
+
live_defs = list(
|
| 1094 |
+
stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message))
|
| 1095 |
+
)
|
| 1096 |
+
else:
|
| 1097 |
+
live_defs = list(analysis.assigned_vars(stmt, self._message))
|
| 1098 |
+
test = self._translate_expr(stmt.test, "cond").name
|
| 1099 |
+
lineno = self._source_of(stmt).lineno
|
| 1100 |
+
thenGraph, sub_fct_then = self._translate_block(
|
| 1101 |
+
stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt
|
| 1102 |
+
)
|
| 1103 |
+
thenAttr = self._make_onnx_attr("then_branch", thenGraph)
|
| 1104 |
+
elseGraph, sub_fct_else = self._translate_block(
|
| 1105 |
+
stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt
|
| 1106 |
+
)
|
| 1107 |
+
elseAttr = self._make_onnx_attr("else_branch", elseGraph)
|
| 1108 |
+
|
| 1109 |
+
def rename(x):
|
| 1110 |
+
r = self.generate_unique_name(x)
|
| 1111 |
+
self._bind(
|
| 1112 |
+
x,
|
| 1113 |
+
values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)),
|
| 1114 |
+
)
|
| 1115 |
+
return r
|
| 1116 |
+
|
| 1117 |
+
# no break condition
|
| 1118 |
+
renamed = [rename(x) for x in live_defs]
|
| 1119 |
+
if not renamed:
|
| 1120 |
+
self.fail(stmt, "A subgraph for a test do not have any output variable.")
|
| 1121 |
+
|
| 1122 |
+
sub_functions = {}
|
| 1123 |
+
sub_functions.update(sub_fct_then)
|
| 1124 |
+
sub_functions.update(sub_fct_else)
|
| 1125 |
+
if renamed == [test]:
|
| 1126 |
+
self.fail(stmt, f"Input and output cannot be the same {renamed!r}.")
|
| 1127 |
+
self.emit(
|
| 1128 |
+
renamed,
|
| 1129 |
+
values.Op(self.default_opset, "If"),
|
| 1130 |
+
[test],
|
| 1131 |
+
[thenAttr, elseAttr],
|
| 1132 |
+
sub_functions=sub_functions,
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
|
| 1136 |
+
# loop-variable
|
| 1137 |
+
if isinstance(loop_stmt, ast.For):
|
| 1138 |
+
if not isinstance(loop_stmt.target, ast.Name):
|
| 1139 |
+
self.fail(loop_stmt, "For loop target must be a single variable.")
|
| 1140 |
+
p_loop_var = loop_stmt.target.id
|
| 1141 |
+
# iter
|
| 1142 |
+
iter = loop_stmt.iter
|
| 1143 |
+
assert isinstance(iter, ast.Call), "Loop bound not a call."
|
| 1144 |
+
if not isinstance(iter.func, ast.Name):
|
| 1145 |
+
self.fail(loop_stmt, f"Unsupported loop bound {iter.func!r}.")
|
| 1146 |
+
if iter.func.id != "range":
|
| 1147 |
+
self.fail(
|
| 1148 |
+
loop_stmt, "Unsupported loop bound, only function 'range' is allowed."
|
| 1149 |
+
)
|
| 1150 |
+
if not iter.args or len(iter.args) != 1:
|
| 1151 |
+
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
|
| 1152 |
+
assert not iter.keywords, "Unsupported loop bound."
|
| 1153 |
+
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name
|
| 1154 |
+
o_cond_var = self.generate_unique_name("cond_in")
|
| 1155 |
+
i_cond_var = o_cond_var
|
| 1156 |
+
cond_while = None
|
| 1157 |
+
o_loop_condition = "" # No condition for a for loop.
|
| 1158 |
+
elif isinstance(loop_stmt, ast.While):
|
| 1159 |
+
test = loop_stmt.test
|
| 1160 |
+
if not isinstance(test, ast.Name):
|
| 1161 |
+
self.fail(
|
| 1162 |
+
loop_stmt,
|
| 1163 |
+
"Unexpected condition type {type(loop_stmt)!r} for a while loop, "
|
| 1164 |
+
"it should be 'while <condition_name>:'.",
|
| 1165 |
+
)
|
| 1166 |
+
p_loop_var = "infinite_loop"
|
| 1167 |
+
o_loop_bound = ""
|
| 1168 |
+
i_cond_var = test.id
|
| 1169 |
+
cond_while = test.id
|
| 1170 |
+
o_cond_var = None
|
| 1171 |
+
o_loop_condition = self._translate_name_expr(test)
|
| 1172 |
+
# we need to go through all the instructions to see
|
| 1173 |
+
# which instruction defines the condition test.id
|
| 1174 |
+
else:
|
| 1175 |
+
self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.")
|
| 1176 |
+
# analyze loop body
|
| 1177 |
+
exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message)
|
| 1178 |
+
vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message)
|
| 1179 |
+
loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out)
|
| 1180 |
+
scan_outputs = set() # TODO
|
| 1181 |
+
outputs = list(loop_state_vars | scan_outputs)
|
| 1182 |
+
|
| 1183 |
+
# loop-condition:
|
| 1184 |
+
# o_loop_condition = self._emit_const(True, "true", self._source_of(loop_stmt))
|
| 1185 |
+
|
| 1186 |
+
# build loop_body
|
| 1187 |
+
self._enter_scope("loop_body", loop_stmt)
|
| 1188 |
+
o_loop_var = self.generate_unique_name(p_loop_var)
|
| 1189 |
+
self.ir_builder.add_input(
|
| 1190 |
+
self._current_fn,
|
| 1191 |
+
o_loop_var,
|
| 1192 |
+
onnx_types.INT64,
|
| 1193 |
+
self._source_of(loop_stmt),
|
| 1194 |
+
)
|
| 1195 |
+
self._bind(
|
| 1196 |
+
p_loop_var,
|
| 1197 |
+
values.Dynamic(o_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
|
| 1198 |
+
)
|
| 1199 |
+
|
| 1200 |
+
self.ir_builder.add_input(
|
| 1201 |
+
self._current_fn,
|
| 1202 |
+
i_cond_var,
|
| 1203 |
+
onnx_types.BOOL,
|
| 1204 |
+
self._source_of(loop_stmt),
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
for pv in loop_state_vars:
|
| 1208 |
+
ov = self.generate_unique_name(pv)
|
| 1209 |
+
# TODO: retrieve the annotation for variable pv is any is specified.
|
| 1210 |
+
# typeinfo = self._eval_constant_expr(pv.annotation)
|
| 1211 |
+
typeinfo = None
|
| 1212 |
+
self.ir_builder.add_input(
|
| 1213 |
+
self._current_fn, ov, typeinfo, self._source_of(loop_stmt)
|
| 1214 |
+
)
|
| 1215 |
+
self._bind(
|
| 1216 |
+
pv,
|
| 1217 |
+
values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)),
|
| 1218 |
+
)
|
| 1219 |
+
|
| 1220 |
+
condition_name = None
|
| 1221 |
+
operator_name = "Identity"
|
| 1222 |
+
for i, s in enumerate(loop_stmt.body):
|
| 1223 |
+
# We first need to intercept a break instruction in test block.
|
| 1224 |
+
# It must be something like `if <condition_name>: break`.
|
| 1225 |
+
# This instruction must be the last of the loop body.
|
| 1226 |
+
if isinstance(s, ast.If) and len(s.body) == 1 and isinstance(s.body[0], ast.Break):
|
| 1227 |
+
if not isinstance(s.test, ast.Name):
|
| 1228 |
+
self.fail(
|
| 1229 |
+
s,
|
| 1230 |
+
f"Instruction break can be introduced with test but it must be "
|
| 1231 |
+
f"if <condition>: break. However condition is of type "
|
| 1232 |
+
f"{type(s.test)!r}.",
|
| 1233 |
+
)
|
| 1234 |
+
if i != len(loop_stmt.body) - 1:
|
| 1235 |
+
self.fail(s, "Instruction break must be the last one of the loop.")
|
| 1236 |
+
|
| 1237 |
+
current_scope = self._current_scope()
|
| 1238 |
+
if s.test.id not in current_scope:
|
| 1239 |
+
self.fail(
|
| 1240 |
+
loop_stmt,
|
| 1241 |
+
f"Unable to find condition variable {s.test.id!r} in known "
|
| 1242 |
+
f"variables {list(current_scope)!r}.",
|
| 1243 |
+
)
|
| 1244 |
+
condition_name = current_scope[s.test.id].value
|
| 1245 |
+
operator_name = "Not"
|
| 1246 |
+
continue
|
| 1247 |
+
self._translate_stmt(s)
|
| 1248 |
+
|
| 1249 |
+
o_cond_out = self.generate_unique_name("cond_out")
|
| 1250 |
+
|
| 1251 |
+
if cond_while is not None:
|
| 1252 |
+
# Loop while
|
| 1253 |
+
current_scope = self._current_scope()
|
| 1254 |
+
if cond_while not in current_scope:
|
| 1255 |
+
self.fail(
|
| 1256 |
+
loop_stmt,
|
| 1257 |
+
f"Unable to find condition variable {cond_while!r} in known "
|
| 1258 |
+
f"variables {list(current_scope)!r}.",
|
| 1259 |
+
)
|
| 1260 |
+
o_cond_var = current_scope[cond_while].value
|
| 1261 |
+
|
| 1262 |
+
self.emit(
|
| 1263 |
+
[o_cond_out],
|
| 1264 |
+
values.Op(self.default_opset, operator_name),
|
| 1265 |
+
[condition_name or o_cond_var],
|
| 1266 |
+
[],
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
self.ir_builder.add_output(
|
| 1270 |
+
self._current_fn,
|
| 1271 |
+
o_cond_out,
|
| 1272 |
+
onnx_types.BOOL,
|
| 1273 |
+
self._source_of(loop_stmt),
|
| 1274 |
+
)
|
| 1275 |
+
for pv in loop_state_vars:
|
| 1276 |
+
ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name
|
| 1277 |
+
if ov not in self._current_fn.assigned_names:
|
| 1278 |
+
# When converting the loop-body into a graph, we need to handle
|
| 1279 |
+
# identity assignments of the form "x = y" inside the loop body
|
| 1280 |
+
# specially if y represents a value computed outside the loop body.
|
| 1281 |
+
# In this case, we create a copy of y, treating the statement as
|
| 1282 |
+
# shorthand for "x = op.Identity(y)".
|
| 1283 |
+
ov = self._emit_copy(ov, pv)
|
| 1284 |
+
# TODO: retrieve variable type for the annotation if any.
|
| 1285 |
+
typeinfo = None
|
| 1286 |
+
self.ir_builder.add_output(
|
| 1287 |
+
self._current_fn, ov, typeinfo, self._source_of(loop_stmt)
|
| 1288 |
+
)
|
| 1289 |
+
body = self._exit_scope()
|
| 1290 |
+
inputs = [o_loop_bound, o_loop_condition] + [
|
| 1291 |
+
self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name
|
| 1292 |
+
for pv in loop_state_vars
|
| 1293 |
+
]
|
| 1294 |
+
graph, sub_functions = body.to_graph_and_functions()
|
| 1295 |
+
attrs = [self._make_onnx_attr("body", graph)]
|
| 1296 |
+
info = self._source_of(loop_stmt)
|
| 1297 |
+
|
| 1298 |
+
def rename(x):
|
| 1299 |
+
r = self.generate_unique_name(x)
|
| 1300 |
+
self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info))
|
| 1301 |
+
return r
|
| 1302 |
+
|
| 1303 |
+
onnx_outputs = [rename(x) for x in outputs]
|
| 1304 |
+
self.emit(
|
| 1305 |
+
onnx_outputs,
|
| 1306 |
+
"Loop",
|
| 1307 |
+
inputs,
|
| 1308 |
+
attrs,
|
| 1309 |
+
sub_functions=sub_functions,
|
| 1310 |
+
)
|
| 1311 |
+
|
| 1312 |
+
def _translate_block(
|
| 1313 |
+
self,
|
| 1314 |
+
stmts: Sequence[ast.stmt],
|
| 1315 |
+
name: str,
|
| 1316 |
+
live_defs: Sequence[str],
|
| 1317 |
+
parent_stmt: ast.stmt,
|
| 1318 |
+
):
|
| 1319 |
+
"""Translation of a statement-block to GraphProto attribute."""
|
| 1320 |
+
info_stmt = stmts[0] if len(stmts) > 0 else parent_stmt
|
| 1321 |
+
source = self._source_of(info_stmt)
|
| 1322 |
+
self._enter_scope(name, None)
|
| 1323 |
+
for s in stmts:
|
| 1324 |
+
self._translate_stmt(s)
|
| 1325 |
+
for pvar in live_defs:
|
| 1326 |
+
if pvar in self._current_scope():
|
| 1327 |
+
pv_val = self._current_scope()[pvar]
|
| 1328 |
+
output = self._to_onnx_var(pv_val, pvar).name
|
| 1329 |
+
if output not in self._current_fn.assigned_names:
|
| 1330 |
+
# To return an outer-scope variable, an ONNX Graph has to
|
| 1331 |
+
# use an explicit copy via Identity.
|
| 1332 |
+
output = self._emit_copy(output, pvar)
|
| 1333 |
+
self.ir_builder.add_output(
|
| 1334 |
+
self._current_fn,
|
| 1335 |
+
output,
|
| 1336 |
+
pv_val.typeinfo,
|
| 1337 |
+
source,
|
| 1338 |
+
)
|
| 1339 |
+
else:
|
| 1340 |
+
pv_val = None
|
| 1341 |
+
for scope in self._locals: # TODO: skip _current_scope
|
| 1342 |
+
if pvar in scope:
|
| 1343 |
+
pv_val = scope[pvar]
|
| 1344 |
+
break
|
| 1345 |
+
if pv_val is None:
|
| 1346 |
+
self.fail(
|
| 1347 |
+
stmts[0],
|
| 1348 |
+
f"Variable {pvar} is not assigned a value along a conditional "
|
| 1349 |
+
f"branch, known variables: {list(self._locals)}.",
|
| 1350 |
+
)
|
| 1351 |
+
# introduce a copy
|
| 1352 |
+
ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar).name, pvar)
|
| 1353 |
+
|
| 1354 |
+
# TODO: retrieve the annotation if any.
|
| 1355 |
+
typeinfo = None
|
| 1356 |
+
self.ir_builder.add_output(self._current_fn, ovar, typeinfo, source)
|
| 1357 |
+
graph = self._exit_scope()
|
| 1358 |
+
return graph.to_graph_and_functions()
|
| 1359 |
+
|
| 1360 |
+
def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
|
| 1361 |
+
"""Translate a nested function definition."""
|
| 1362 |
+
self._enter_scope(fn.name, fn)
|
| 1363 |
+
self._translate_function_def_common(fn)
|
| 1364 |
+
function_ir = self._exit_scope()
|
| 1365 |
+
outer_scope_vars = analysis.outer_scope_variables(fn, self._message)
|
| 1366 |
+
function_ir.outer_scope_variables = [
|
| 1367 |
+
(var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
|
| 1368 |
+
]
|
| 1369 |
+
self._bind(fn.name, function_ir)
|
| 1370 |
+
# TODO: Does not yet handle nested functions within nested functions.
|
| 1371 |
+
self._current_fn.add_nested_function(function_ir)
|
| 1372 |
+
|
| 1373 |
+
def _translate_function_signature_common(
|
| 1374 |
+
self, fn: ast.FunctionDef
|
| 1375 |
+
) -> irbuilder.IRFunction:
|
| 1376 |
+
"""Translate a function signature (top-level or nested)."""
|
| 1377 |
+
args = fn.args
|
| 1378 |
+
if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg:
|
| 1379 |
+
warn(f"{fn.name}: Unsupported feature in function signature.")
|
| 1380 |
+
for i, x in enumerate(args.args):
|
| 1381 |
+
arg_with_default_start_index = len(args.args) - len(args.defaults)
|
| 1382 |
+
if args.defaults and i >= arg_with_default_start_index:
|
| 1383 |
+
default_value = self._eval_constant_expr(
|
| 1384 |
+
args.defaults[i - arg_with_default_start_index]
|
| 1385 |
+
)
|
| 1386 |
+
else:
|
| 1387 |
+
default_value = None
|
| 1388 |
+
if x.annotation:
|
| 1389 |
+
typeinfo = self._eval_constant_expr(x.annotation)
|
| 1390 |
+
if not ta.is_valid_type(typeinfo):
|
| 1391 |
+
self.warn(
|
| 1392 |
+
x.annotation,
|
| 1393 |
+
f"Unsupported type annotation for argument {x.arg}.",
|
| 1394 |
+
)
|
| 1395 |
+
typeinfo = None
|
| 1396 |
+
else:
|
| 1397 |
+
# The code can only be exported as a function.
|
| 1398 |
+
typeinfo = None
|
| 1399 |
+
if typeinfo and ta.is_attr_type(typeinfo):
|
| 1400 |
+
self.ir_builder.add_attr_parameter(
|
| 1401 |
+
self._current_fn,
|
| 1402 |
+
x.arg,
|
| 1403 |
+
ta.pytype_to_attrtype(typeinfo),
|
| 1404 |
+
default_value,
|
| 1405 |
+
)
|
| 1406 |
+
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
|
| 1407 |
+
else:
|
| 1408 |
+
self.ir_builder.add_input(
|
| 1409 |
+
self._current_fn, x.arg, typeinfo, self._source_of(x)
|
| 1410 |
+
)
|
| 1411 |
+
self._used_vars.add(x.arg)
|
| 1412 |
+
self._bind(
|
| 1413 |
+
x.arg,
|
| 1414 |
+
values.Dynamic(x.arg, values.DynamicKind.Input, self._source_of(x)),
|
| 1415 |
+
)
|
| 1416 |
+
if fn.returns:
|
| 1417 |
+
type_annotation = self._eval_constant_expr(fn.returns)
|
| 1418 |
+
self.returntype = ta.get_return_types(type_annotation)
|
| 1419 |
+
invalid = False
|
| 1420 |
+
for t in self.returntype:
|
| 1421 |
+
if not ta.is_valid_type(t):
|
| 1422 |
+
self.warn(
|
| 1423 |
+
fn.returns,
|
| 1424 |
+
f"Unsupported type annotation for return value {t}.",
|
| 1425 |
+
)
|
| 1426 |
+
invalid = True
|
| 1427 |
+
if invalid:
|
| 1428 |
+
self.returntype = None
|
| 1429 |
+
else:
|
| 1430 |
+
self.returntype = None
|
| 1431 |
+
|
| 1432 |
+
return self._current_fn
|
| 1433 |
+
|
| 1434 |
+
def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
|
| 1435 |
+
"""Translate a function definition, including the signature and its body."""
|
| 1436 |
+
logger.debug("Converter:_translate_function_def_common:%s", fn.name)
|
| 1437 |
+
_ = self._translate_function_signature_common(fn)
|
| 1438 |
+
for i, s in enumerate(fn.body):
|
| 1439 |
+
self._translate_stmt(s, index_of_stmt=i)
|
| 1440 |
+
return self._current_fn
|
| 1441 |
+
|
| 1442 |
+
def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
|
| 1443 |
+
if isinstance(stmt, ast.FunctionDef):
|
| 1444 |
+
self._init_function_translation()
|
| 1445 |
+
if self.default_opset_ is None:
|
| 1446 |
+
opset = self._find_onnx_opset(stmt)
|
| 1447 |
+
if opset:
|
| 1448 |
+
self._set_default_opset(opset, stmt)
|
| 1449 |
+
domain = self.this_module.domain
|
| 1450 |
+
self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
|
| 1451 |
+
analysis.do_liveness_analysis(stmt, self._message)
|
| 1452 |
+
fn_ir = self._translate_function_def_common(stmt)
|
| 1453 |
+
fn_ir.debug_print()
|
| 1454 |
+
self.this_module.add_function_def(fn_ir)
|
| 1455 |
+
return fn_ir
|
| 1456 |
+
raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
|
| 1457 |
+
|
| 1458 |
+
def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
|
| 1459 |
+
"""Translate a (top-level) function signature."""
|
| 1460 |
+
domain = self.this_module.domain
|
| 1461 |
+
self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
|
| 1462 |
+
return self._translate_function_signature_common(fn)
|
pythonProject/.venv/Lib/site-packages/onnxscript/evaluator.py
ADDED
@@ -0,0 +1,619 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import abc
import contextlib
import pprint
from typing import (
    Any,
    Callable,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    TypeVar,
    Union,
    runtime_checkable,
)

import numpy as np
import onnx
import onnx.defs
import onnx.reference
from typing_extensions import TypeAlias

from onnxscript import irbuilder, onnx_opset, tensor, values
from onnxscript._internal import autocast, param_manipulation, utils

UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]

EagerModeValue: TypeAlias = Union[Optional["tensor.Tensor"], Sequence["EagerModeValue"]]

ExtendedModeValue: TypeAlias = Union[
    Optional["tensor.Tensor"],
    Sequence["ExtendedModeValue"],
    np.ndarray,
    int,
    float,
    bool,
    str,
]

_T = TypeVar("_T")


def _adapt_to_eager_mode(inputs: ExtendedModeValue) -> tuple[EagerModeValue, bool]:
    """Adapts inputs into representation used by onnxscript eager mode.

    This does the following transformations:
    * It adds an onnxscript Tensor wrapper around numpy arrays, which
      allows the use of overloaded operators like + to be controlled by onnxscript.
    * It also provides a promotion of scalars into tensors as a convenience.
      This is needed to complement the similar promotion supported by the
      onnxscript converter (for example, when an attribute is promoted and used
      as an input argument).

    Args:
        inputs: a list/tuple of inputs to an ONNX function

    Returns:
        a pair (wrapped_inputs, flag) where flag indicates whether any numpy array
        was wrapped into a Tensor.
    """
    has_array = False

    def adapt(input: ExtendedModeValue) -> EagerModeValue:
        if isinstance(input, np.ndarray):
            nonlocal has_array
            has_array = True
            return tensor.Tensor(input)
        if isinstance(input, tensor.Tensor):
            return input
        if isinstance(input, (bool, float)):
            return tensor.Tensor(np.array(input))
        if isinstance(input, int):
            return tensor.Tensor(np.array(input, dtype=np.int64))
        if input is None:
            return None
        if isinstance(input, list):
            return [adapt(elt) for elt in input]
        if isinstance(input, tuple):
            return tuple(adapt(elt) for elt in input)
        raise TypeError(f"Unexpected input type {type(input)}.")

    result = adapt(inputs)
    return result, has_array
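
Editorial illustration (not part of the file) of the adapter above: arrays and scalars are wrapped into onnxscript Tensors, Python ints are promoted to int64, and the returned flag records whether any ndarray was seen.

import numpy as np
from onnxscript import tensor

wrapped, saw_array = _adapt_to_eager_mode([np.array([1.0, 2.0]), 3])
assert saw_array
assert isinstance(wrapped[0], tensor.Tensor)
assert wrapped[1].value.dtype == np.int64  # scalar int promoted to an int64 tensor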

def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue:
    """Unwraps Tensor wrapper around numpy arrays.

    Args:
        output: output of an ONNX function, which can be either a single
            onnx value or a list/tuple of onnx values.

    Returns:
        unwrapped output
    """
    if isinstance(output, tensor.Tensor):
        return output.value
    if output is None:
        return None
    if isinstance(output, list):
        return [_adapt_to_user_mode(elt) for elt in output]
    if isinstance(output, tuple):
        return tuple(_adapt_to_user_mode(elt) for elt in output)
    if isinstance(output, np.ndarray):
        return output
    raise TypeError(f"Unexpected type {type(output)}.")


def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Unwrap tensors in a mapping to numpy arrays."""
    new_kwargs = {}
    for k, v in kwargs.items():
        new_kwargs[k] = v
        if isinstance(v, tensor.Tensor):
            new_kwargs[k] = v.value

    return new_kwargs


@runtime_checkable
class Evaluator(Protocol):
    """Protocol for evaluating ONNX ops."""

    def eval(
        self,
        schema: onnx.defs.OpSchema,
        inputs: Sequence[ExtendedModeValue],
        attributes: Mapping[str, Any],
    ):
        """Evaluates an ONNX op.

        Args:
            schema: The OpSchema of the operator to evaluate.
            inputs: The ONNX inputs to the op.
            attributes: The ONNX attributes to the op.
        """

    def eval_function(
        self,
        function: values.OnnxFunction,
        args: Sequence[ExtendedModeValue],
        kwargs: Mapping[str, ExtendedModeValue],
    ):
        """Evaluates an OnnxFunction.

        Args:
            function: The OnnxFunction to evaluate.
            args: The positional arguments to the function.
            kwargs: The keyword arguments to the function.
        """


class BaseEvaluator(Evaluator, abc.ABC):
    """Base class for evaluation of ONNX ops.

    The execution of onnxscript functions in eager-mode is dispatched to an Evaluator
    instance (or, more precisely, to the eval method of the Evaluator instance).
    The evaluator is expected to transform the input/output/attribute representation
    supported by onnxscript to those expected by a particular backend.
    """

    def __init__(self, ignore_unknown_function_kwargs: bool = False):
        """Initializes a BaseEvaluator.

        Args:
            ignore_unknown_function_kwargs: Whether to ignore unknown keyword arguments
                when evaluating an OnnxFunction. This is useful when using the
                evaluator to validate operators programmatically, where
                additional keyword arguments that are not part of the signature
                may be provided to the function.
        """
        self._ignore_unknown_function_kwargs = ignore_unknown_function_kwargs

    def eval(
        self,
        schema: onnx.defs.OpSchema,
        inputs: Sequence[ExtendedModeValue],
        attributes: Mapping[str, Any],
    ):
        """Evaluates an ONNX op.

        Args:
            schema: The OpSchema of the operator to evaluate.
            inputs: The ONNX inputs to the op.
            attributes: The ONNX attributes to the op.
        """
        attributes = _unwrap_tensors_in_kwargs(attributes)
        attributes, closure = self.adapt_attributes(schema, attributes)
        inputs = self.adapt_inputs(schema, inputs)
        outputs = self._eval(schema, inputs, attributes, closure)
        return self.adapt_outputs(schema, outputs)

    def adapt_inputs(self, schema: onnx.defs.OpSchema, inputs: Sequence[ExtendedModeValue]):
        """Transform inputs to the expected format for the evaluator.

        Enables some syntactic sugar, such as the use of Python scalars,
        in a manner consistent with the translator. See autocast.py for details.
        """
        return autocast.dynamic_cast_inputs(schema, inputs)

    def adapt_attributes(
        self, schema: onnx.defs.OpSchema, attributes: Mapping[str, ExtendedModeValue]
    ) -> tuple[dict[str, ExtendedModeValue], dict[str, ExtendedModeValue]]:
        """Transform attributes to the expected format for the evaluator.

        Returns:
            A pair (adapted_attributes, closure), where the closure can be used
            to evaluate graph-valued attributes.
        """
        use_graph_attribute = self.use_graph_attribute(schema)
        closure: dict[Any, Any] = {}
        adapted_attributes = {}
        for k, v in attributes.items():
            if isinstance(v, values.OnnxClosure):
                if use_graph_attribute:
                    adapted_attributes[k] = v.function_ir.to_graph_proto()
                    for pyvar, onnxvar in v.function_ir.outer_scope_variables:
                        closure[onnxvar.value] = v.frame.f_locals[pyvar]
                else:
                    adapted_attributes[k] = v.function
            elif callable(v):
                raise TypeError(
                    f"Error: function-valued attribute {v.__name__} has no graph_proto "
                    "attribute. Did you forget to decorate it with @graph?"
                )
            else:
                adapted_attributes[k] = v
        return adapted_attributes, closure

    def adapt_outputs(self, schema: onnx.defs.OpSchema, outputs: Sequence[EagerModeValue]):
        """Adapt evaluator's output to convention used in onnxscript.

        Onnxscript uses a tuple/sequence only when number of outputs > 1.
        """
        del schema  # unused
        return outputs[0] if len(outputs) == 1 else outputs

    def use_graph_attribute(self, schema: onnx.defs.OpSchema):
        del schema  # unused
        return True

    @abc.abstractmethod
    def _eval(
        self,
        schema: onnx.defs.OpSchema,
        inputs: Sequence[ExtendedModeValue],
        attributes: Mapping[str, ExtendedModeValue],
        closure: Mapping[str, ExtendedModeValue],
    ) -> EagerModeValue:
        """Evaluates an ONNX op given its schema and inputs/attributes.

        Args:
            schema: The schema of the op to evaluate.
            inputs: The ONNX inputs to the op.
            attributes: The ONNX attributes to the op.
            closure: The closure to use when evaluating graph-valued attributes.
        """

    def eval_function(
        self,
        function: values.OnnxFunction,
        args: Sequence[ExtendedModeValue],
        kwargs: Mapping[str, ExtendedModeValue],
    ):
        """Evaluates a function in eager mode.

        Override this function to change the evaluator's behavior for functions.

        Args:
            function: The OnnxFunction to evaluate.
            args: The positional arguments to the function.
            kwargs: The keyword arguments to the function.
        """
        param_schemas = function.param_schemas()
        # Split happens in the evaluator instead of the OnnxFunction __call__ method
        # so that evaluators can control behaviors like whether to fill in default values for attributes.
        tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
            param_schemas,
            args,
            kwargs,
            fill_defaults=False,
            allow_extra_kwargs=self._ignore_unknown_function_kwargs,
        )

        adapted_args: list[ExtendedModeValue] = []
        adapted_kwargs: dict[str, ExtendedModeValue] = {}
        has_array = False
        for arg, param_schema in tagged_args:
            if param_schema.is_input:
                adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
                has_array = has_array or has_array_
                adapted_args.append(adapted_arg)
            else:
                adapted_args.append(arg)

        for key, (arg, param_schema) in tagged_kwargs.items():
            if param_schema.is_input:
                adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
                has_array = has_array or has_array_
                adapted_kwargs[key] = adapted_arg
            else:
                adapted_kwargs[key] = arg

        result = function.function(*adapted_args, **adapted_kwargs)

        # We use a heuristic to decide whether to return output values as
        # numpy arrays or tensor.Tensors. If the function has at least one
        # numpy array as input, we return numpy arrays. Otherwise, we return
        # tensor.Tensors. We could use a user-specified flag to control this
        # or explicitly track whether this is a top-level function-call or
        # a nested function-call.

        return _adapt_to_user_mode(result) if has_array else result


# Utilities for evaluation using ORT:


class EagerModeError(RuntimeError):
    pass


def _rename_io(prefix, i, arg):
    if arg is None:
        return ""
    return f"{prefix}{i}"


def compute_num_outputs(
    schema: onnx.defs.OpSchema, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> int:
    """Returns the number of outputs expected."""

    # TODO: Use ONNX type inference to replace the special-case handling below.
    if schema.domain == "":
        if schema.name == "BatchNormalization":
            if not kwargs.get("training_mode", 0):
                return 1
        if schema.name == "LSTM":
            return 3
        if schema.name == "Split":
            if len(args) == 1 and "num_outputs" not in kwargs:
                raise EagerModeError(
                    "Operator Split: the number of expected outputs defines the split. "
                    "This information is unknown here."
                )
            if len(args) == 2:  # has argument(split)
                return len(args[1])
            else:  # no argument(split), check attribute(num_outputs)
                return kwargs["num_outputs"]
        if schema.name == "Scan":
            scan_body = kwargs["body"]
            return len(scan_body.output)
        if schema.name == "Loop":
            loop_body = kwargs["body"]
            return len(loop_body.output) - 1
    return len(schema.outputs)
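
Editorial example (not part of the file) of the Split special case above: the output count comes from the `split` input when given, otherwise from the `num_outputs` attribute.

import numpy as np
import onnx.defs

split_schema = onnx.defs.get_schema("Split", 18)
x = np.arange(6, dtype=np.float32)
assert compute_num_outputs(split_schema, [x, [2, 2, 2]], {}) == 3       # from split input
assert compute_num_outputs(split_schema, [x], {"num_outputs": 2}) == 2  # from attribute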

def _onnxscript_to_numpy_value(v):
    """Converts an onnxscript encoding of an ONNX value into the numpy encoding used by runtimes."""
    if isinstance(v, tensor.Tensor):
        return v.value
    if isinstance(v, list):
        return [_onnxscript_to_numpy_value(x) for x in v]
    if isinstance(v, tuple):
        if len(v) > 0 and type(v[0]) is int:  # pylint: disable=unidiomatic-typecheck
            return np.array(v, dtype=np.int64)
        return np.array(v)
    if v is None:
        # Treated as a static-optional value.
        # Dynamic optional None not yet supported.
        return v
    if isinstance(v, np.ndarray):
        return v
    raise TypeError(
        f"Unexpected onnxscript value type '{type(v)}'. "
        "Valid value types are 'Tensor | list[Tensor] | None | np.ndarray | list[np.ndarray]'"
    )


def _numpy_to_onnxscript_value(
    v: np.ndarray | np.generic | list[np.ndarray] | list[np.generic],
):
    """Converts an ORT encoding of an ONNX value into the encoding used by onnxscript."""
    if isinstance(v, np.ndarray):
        # ORT may reuse buffers when the output numpy array is provided back as input.
        # We need to make a copy to ensure that the tensor is not modified in-place.
        return tensor.Tensor(v.copy())
    if issubclass(type(v), np.generic):
        # Numpy scalar types that are not ndarray
        # https://numpy.org/doc/stable/reference/arrays.scalars.html
        return tensor.Tensor(np.array(v))
    if isinstance(v, list):
        return [_numpy_to_onnxscript_value(x) for x in v]
    if v is None:
        raise TypeError("Dynamic optional values not yet supported.")
    raise TypeError(
        f"Unexpected runtime value type '{type(v)}'. "
        "Valid types are: 'np.ndarray | np.generic | list[np.ndarray] | list[np.generic]'"
    )


def _prepare_model_and_inputs_for_eager(
    schema: onnx.defs.OpSchema,
    args: Sequence[Any],
    kwargs: Mapping[str, Any],
    implicit_args: Optional[Mapping[str, Any]],
):
    implicit_args = implicit_args or {}
    # Convert input values to ORT representation-type:
    args = [_onnxscript_to_numpy_value(x) for x in args]
    implicit_args = {k: _onnxscript_to_numpy_value(v) for k, v in implicit_args.items()}

    # Utility to convert kwarg to ONNX AttributeProto:
    def make_attr(key: str, value: Any) -> onnx.AttributeProto:
        def make_tensor_name() -> str:
            return f"attr_{key}"

        return autocast.pyvalue_to_onnx_attribute(
            key, value, make_tensor_name, int(schema.attributes[key].type)
        )

    # Construct ONNX model with a single op call:
    inputs = [_rename_io("input", i, arg) for i, arg in enumerate(args)]

    num_outputs = compute_num_outputs(schema, args, kwargs)
    outputs = [f"output{i}" for i in range(num_outputs)]

    node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)  # noqa: TID251
    node.attribute.extend(
        make_attr(key, value) for key, value in kwargs.items() if value is not None
    )
    input_value_infos = utils.values_to_value_infos(zip(inputs, args))
    implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
    output_value_infos = [
        onnx.helper.make_value_info(name, onnx.TypeProto())  # noqa: TID251
        for name in outputs
    ]

    graph = onnx.helper.make_graph(  # noqa: TID251
        [node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
    )
    opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)  # noqa: TID251
    model = onnx.helper.make_model(  # noqa: TID251
        graph,
        opset_imports=[opset_id],
        ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
    )
    model = onnx.shape_inference.infer_shapes(model)

    session_run_input = {name: arg for name, arg in zip(inputs, args) if name != ""}
    session_run_input.update(implicit_args)

    return model, session_run_input, inputs


def _call_ort(
    schema: onnx.defs.OpSchema,
    args: Sequence[Any],
    kwargs: Mapping[str, Any],
    implicit_args: Optional[Mapping[str, Any]],
):
    # Delay import onnxruntime so that onnxscript can be used without
    # installing onnxruntime.
    import onnxruntime as ort  # pylint: disable=import-outside-toplevel
    from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
        Fail,
        InvalidArgument,
        InvalidGraph,
    )

    model, session_run_input, inputs = _prepare_model_and_inputs_for_eager(
        schema, args, kwargs, implicit_args
    )

    try:
        session = ort.InferenceSession(
            model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
    except (Fail, InvalidGraph, InvalidArgument) as e:
        raise EagerModeError(
            f"Unable to create onnxruntime InferenceSession "
            f"for executing {schema.domain}.{schema.name} op "
            f"with onnx model\n{onnx.printer.to_text(model)}"
        ) from e

    try:
        result = session.run(None, session_run_input)
    except (RuntimeError, Fail) as e:
        raise EagerModeError(
            f"Unable to execute model operator {schema.name!r} due to {e!r}"
            f"\ninput types:\n"
            f"{pprint.pformat({k: type(v) for k, v in zip(inputs, args)})}"
            f"\nmodified input types:\n"
            f"{pprint.pformat({k: type(v) for k, v in session_run_input.items()})}"
            f"\ninputs:\n{pprint.pformat(session_run_input)}\n{model}"
        ) from e

    # Map ORT output values to the onnxscript representation-type.
    return [_numpy_to_onnxscript_value(x) for x in result]


def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]:
    return schema.name, schema.domain, schema.since_version


class ORTEvaluator(BaseEvaluator):
    """Evaluates ONNX ops using ONNX Runtime."""

    def _eval(self, schema, inputs, attributes, closure):
        return _call_ort(schema, inputs, attributes, closure)


class OnnxReferenceRuntimeEvaluator(BaseEvaluator):
    """Evaluates ONNX ops using the ONNX reference implementation (onnx.reference)."""

    def _eval(self, schema, inputs, attributes, closure):
        model, session_run_input, adapted_inputs = _prepare_model_and_inputs_for_eager(
            schema, inputs, attributes, closure
        )
        session = onnx.reference.ReferenceEvaluator(model)
        try:
            result = session.run(None, session_run_input)
        except RuntimeError as e:
            raise EagerModeError(
                f"Unable to execute model operator {schema.name!r} due to {e!r}"
                f"\ninput types:\n"
                f"{pprint.pformat({k: type(v) for k, v in zip(adapted_inputs, inputs)})}"
                f"\nmodified input types:\n"
                f"{pprint.pformat({k: type(v) for k, v in session_run_input.items()})}"
                f"\ninputs:\n{pprint.pformat(session_run_input)}\n{model}"
            ) from e

        return [_numpy_to_onnxscript_value(x) for x in result]


ort_evaluator = ORTEvaluator()


class ORTMixedEvaluator(ORTEvaluator):
    """Evaluates ONNX ops using ONNX Runtime, unless an overriding python implementation is registered.

    This is useful for higher-order ops such as Scan and SequenceMap, allowing for
    python-based debugging.
    """

    def __init__(self) -> None:
        super().__init__()
        self._python_ops: dict[tuple[str, str, int], Any] = {}

    def use_graph_attribute(self, schema: onnx.defs.OpSchema) -> bool:
        return _schema_id(schema) not in self._python_ops

    def _eval(self, schema, inputs, attributes, closure):
        schemaid = _schema_id(schema)
        if schemaid in self._python_ops:
            return self._python_ops[schemaid](inputs, attributes)
        else:
            return super()._eval(schema, inputs, attributes, closure)

    def register(self, opset: values.Opset) -> Callable[[_T], _T]:
        assert opset is not None

        def decorator(function: _T) -> _T:
            schema = opset[function.__name__]
            self._python_ops[_schema_id(schema)] = function
            return function

        return decorator


ort_mixed_evaluator = ORTMixedEvaluator()


@ort_mixed_evaluator.register(opset=onnx_opset.opset18)
def SequenceMap(inputs: Sequence[Any], attributes: Mapping[str, Any]):
    """Evaluates a SequenceMap op."""
    fun = attributes["body"]

    def get_input_of(input_index, iter_num):
        input = inputs[input_index]
        if isinstance(input, list):
            return input[iter_num]
        return input

    def get_input(iter_num):
        return [get_input_of(input_index, iter_num) for input_index in range(len(inputs))]

    return [fun(*(get_input(i))) for i in range(len(inputs[0]))]
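
Editorial sketch (not part of the file): the same `register` hook can override any op whose schema exists in the opset, following the SequenceMap pattern above. The `Scan` body here is a hypothetical placeholder, not a real implementation.

@ort_mixed_evaluator.register(opset=onnx_opset.opset18)
def Scan(inputs: Sequence[Any], attributes: Mapping[str, Any]):
    """Hypothetical Python override dispatched instead of ORT."""
    raise NotImplementedError  # a real override would interpret attributes["body"]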

# Used to control the default evaluator instance. A simple approach for now.

_default_evaluator: Evaluator = ort_evaluator


def default() -> Evaluator:
    """Returns the default Evaluator instance."""
    return _default_evaluator


def set_default(new_default: Evaluator) -> None:
    """Sets the current Evaluator default."""
    global _default_evaluator  # pylint: disable=global-statement
    _default_evaluator = new_default


@contextlib.contextmanager
def default_as(temp_default: Evaluator):
    """Context manager that temporarily switches the default evaluator."""
    old_default = _default_evaluator
    set_default(temp_default)
    try:
        yield
    finally:
        set_default(old_default)
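
Editorial usage sketch (not part of the file): the context manager restores the previous default even if the block raises.

from onnxscript import evaluator

with evaluator.default_as(evaluator.OnnxReferenceRuntimeEvaluator()):
    ...  # eager-mode onnxscript calls in this block run on onnx.reference
# on exit, the ORT-backed default evaluator is restored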
pythonProject/.venv/Lib/site-packages/onnxscript/irbuilder.py
ADDED
@@ -0,0 +1,561 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
import io
import logging
import warnings
from typing import Any, Optional, Protocol, Sequence, Union

import onnx
from onnx import ValueInfoProto, helper
from onnx.defs import onnx_opset_version

import onnxscript
from onnxscript import type_annotation as ta
from onnxscript import values
from onnxscript._internal import version_utils
from onnxscript.onnx_types import ONNXType
from onnxscript.sourceinfo import SourceInfo

# A simple IR (Function, Stmt, Attr, Var):

logger = logging.getLogger("onnxscript")


def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str):
    """Formats a sequence of objects into a string."""
    return prefix + sep.join([formatter(x) for x in seq]) + suffix


def select_ir_version(version: int, domain: str = "") -> int:
    """Selects a suitable ONNX ir_version for a given opset version."""
    if domain == "":
        domain = "ai.onnx"
    if (domain, version) not in helper.OP_SET_ID_VERSION_MAP:
        return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx")
    return helper.OP_SET_ID_VERSION_MAP[domain, version]
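
Editorial example (not part of the file; the exact ir_version depends on the installed onnx release): pairing select_ir_version with a model that imports the default domain at opset 18.

graph = helper.make_graph([], "empty_graph", [], [])
model = helper.make_model(
    graph,
    ir_version=select_ir_version(18),  # e.g. 8 for ("ai.onnx", 18)
    opset_imports=[helper.make_opsetid("", 18)],
)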

class IRType:
    def __init__(self):
        self.onnx_type = onnx.TypeProto()

    def to_type_proto(self):
        return self.onnx_type

    def __repr__(self) -> str:
        return "IRType()"


class IRTensorType(IRType):
    def __init__(self, elem_type: onnx.TensorProto.DataType) -> None:
        super().__init__()
        self.onnx_type.tensor_type.elem_type = elem_type

    def __repr__(self) -> str:
        return f"IRTensorType({self.onnx_type.tensor_type.elem_type})"


class IRTypeLike(Protocol):
    def to_type_proto(self) -> onnx.TypeProto:
        """Converts IR type representation to onnx.TypeProto"""


class IRVar:
    """A variable (representing a formal parameter)."""

    def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
        if not isinstance(varname, str):
            raise TypeError(f"varname must be a string not {type(varname)!r}.")
        self.name = varname
        self.info = sourceinfo
        self.typeinfo = typeinfo

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})"

    def typed_str(self):
        return f"{self.name} : {self.typeinfo}"

    def to_value_info(self, use_default_type: bool = True):
        """Converts the content of this class into :class:`onnx.ValueInfoProto`.

        Args:
            use_default_type: if True, use a default type if an explicit type
                is not known. Otherwise, returns a ValueInfoProto without type.

        Returns:
            an instance of :class:`onnx.ValueInfoProto`
        """
        if self.name is None:
            raise ValueError(self.info.msg("name cannot be None."))
        value_info_proto = ValueInfoProto()
        value_info_proto.name = self.name
        if self.typeinfo is not None:
            value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto())
        elif use_default_type:
            value_info_proto.type.CopyFrom(IRType().to_type_proto())
        return value_info_proto


def _opt_var_to_str(x):
    return "" if x is None else str(x)


class IRAttributeValue:
    """An attribute value (representing an actual parameter).

    Attributes:
        name: The name of the attribute.
        type: The type of the attribute.
        attr_proto: The attribute proto.
    """

    def __init__(self, attrproto: onnx.AttributeProto) -> None:
        self.attr_proto = attrproto

    def __str__(self):
        if self.attr_proto.HasField("ref_attr_name"):
            return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}"
        # self.name + " = " + self.value
        return helper.printable_attribute(self.attr_proto)

    @property
    def name(self) -> str:
        return self.attr_proto.name

    @property
    def type(self) -> onnx.AttributeProto.AttributeType:
        return self.attr_proto.type


@dataclasses.dataclass(frozen=True)
class IRAttributeParameter:
    """An attribute parameter (representing a formal parameter).

    It may or may not carry a default value.

    Attributes:
        name: The name of the attribute.
        type: The type of the attribute.
        default_value: The default value of the attribute.
        has_default: Whether the attribute has a default value.
        attr_proto: The attribute proto.
    """

    name: str
    type: onnx.AttributeProto.AttributeType
    default_value: str | int | float | None = None

    # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.

    def __str__(self):
        if self.has_default:
            return helper.printable_attribute(self.attr_proto)
        # TODO(justinchuby): Include a readable type name.
        return self.name
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
def has_default(self):
|
| 166 |
+
return self.default_value is not None
|
| 167 |
+
|
| 168 |
+
@property
|
| 169 |
+
def attr_proto(self) -> onnx.AttributeProto:
|
| 170 |
+
if not self.has_default:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
"Attribute has no default value. Only attributes with default "
|
| 173 |
+
"values can be converted to AttributeProto."
|
| 174 |
+
)
|
| 175 |
+
if version_utils.onnx_older_than("1.15"):
|
| 176 |
+
# TODO(after 1.14 is deprecated): Remove this branch.
|
| 177 |
+
# Argument 'attr_type' was added after version 1.14.
|
| 178 |
+
return helper.make_attribute(self.name, self.default_value)
|
| 179 |
+
# pylint: disable=unexpected-keyword-arg
|
| 180 |
+
return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg]
|
| 181 |
+
# pylint: enable=unexpected-keyword-arg
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class IRStmt:
|
| 185 |
+
def __init__(
|
| 186 |
+
self,
|
| 187 |
+
result: Sequence[str],
|
| 188 |
+
callee: values.Op,
|
| 189 |
+
args: Sequence[Optional[str]],
|
| 190 |
+
attrs: Sequence[IRAttributeValue],
|
| 191 |
+
sub_functions=None,
|
| 192 |
+
) -> None:
|
| 193 |
+
if not isinstance(callee, values.Op):
|
| 194 |
+
raise TypeError(f"Unexpected type {type(callee)} for callee.")
|
| 195 |
+
self.result = result
|
| 196 |
+
self.callee = callee
|
| 197 |
+
self.args = args
|
| 198 |
+
self.attrs = attrs
|
| 199 |
+
self.functions = sub_functions or {}
|
| 200 |
+
|
| 201 |
+
def __str__(self):
|
| 202 |
+
if isinstance(self.result, str):
|
| 203 |
+
logger.debug("unexpected str type for self.result where type(self)=%r", type(self))
|
| 204 |
+
lhs = ", ".join(self.result)
|
| 205 |
+
attrs = ""
|
| 206 |
+
if self.attrs:
|
| 207 |
+
attrs = _format(self.attrs, "<", ", ", ">")
|
| 208 |
+
|
| 209 |
+
args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
|
| 210 |
+
domain = self.callee.opset.domain
|
| 211 |
+
opname = self.callee.name
|
| 212 |
+
callee = f"{domain}.{opname}" if (domain != "") else opname
|
| 213 |
+
return f"{lhs} = {callee} {attrs}{args}"
|
| 214 |
+
|
| 215 |
+
def debug_print(self):
|
| 216 |
+
if logger.isEnabledFor(logging.DEBUG):
|
| 217 |
+
logger.debug("%s: %s", type(self), str(self))
|
| 218 |
+
|
| 219 |
+
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
|
| 220 |
+
n = helper.make_node(
|
| 221 |
+
self.callee.name,
|
| 222 |
+
[_opt_var_to_str(x) for x in self.args],
|
| 223 |
+
[str(x) for x in self.result],
|
| 224 |
+
domain=self.callee.opset.domain,
|
| 225 |
+
name=node_name,
|
| 226 |
+
)
|
| 227 |
+
for a in self.attrs:
|
| 228 |
+
n.attribute.append(a.attr_proto)
|
| 229 |
+
return n
|
| 230 |
+
|
| 231 |
+
@property
|
| 232 |
+
def output_names(self) -> Sequence[str]:
|
| 233 |
+
"""Returns the list of variables assigned to by this statement."""
|
| 234 |
+
return [str(x) for x in self.result]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class IRFunction:
|
| 238 |
+
"""Represents a function in the IR."""
|
| 239 |
+
|
| 240 |
+
def __init__(self, name: str, domain: str = "") -> None:
|
| 241 |
+
self.domain = domain
|
| 242 |
+
self.name = name
|
| 243 |
+
self.outputs: list[IRVar] = []
|
| 244 |
+
self.stmts: list[IRStmt] = []
|
| 245 |
+
self.called_functions: dict[str, onnx.FunctionProto] = {}
|
| 246 |
+
self.docstring: str = ""
|
| 247 |
+
# a dictionary of nested function-definitions
|
| 248 |
+
self.nested_functions: dict[str, IRFunction] = {}
|
| 249 |
+
self.outer_scope_variables: dict[Any, Any] = {}
|
| 250 |
+
self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []
|
| 251 |
+
|
| 252 |
+
@property
|
| 253 |
+
def assigned_names(self) -> Sequence[str]:
|
| 254 |
+
"""Returns the list of variables assigned to by this function."""
|
| 255 |
+
return [v for stmt in self.stmts for v in stmt.output_names]
|
| 256 |
+
|
| 257 |
+
@property
|
| 258 |
+
def inputs(self) -> Sequence[IRVar]:
|
| 259 |
+
return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def attrs(self) -> Sequence[IRAttributeParameter]:
|
| 263 |
+
return [
|
| 264 |
+
attr
|
| 265 |
+
for attr in self.ordered_inputs_and_attrs
|
| 266 |
+
if isinstance(attr, IRAttributeParameter)
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
def __str__(self):
|
| 270 |
+
attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
|
| 271 |
+
inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
|
| 272 |
+
outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
|
| 273 |
+
stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
|
| 274 |
+
return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"
|
| 275 |
+
|
| 276 |
+
def append_docstring(self, docstring):
|
| 277 |
+
self.docstring += docstring
|
| 278 |
+
|
| 279 |
+
def append_stmt(self, stmt: IRStmt) -> None:
|
| 280 |
+
self.stmts.append(stmt)
|
| 281 |
+
|
| 282 |
+
def append_input(self, name: IRVar) -> None:
|
| 283 |
+
self.ordered_inputs_and_attrs.append(name)
|
| 284 |
+
|
| 285 |
+
def append_output(self, name: IRVar) -> None:
|
| 286 |
+
self.outputs.append(name)
|
| 287 |
+
|
| 288 |
+
def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
|
| 289 |
+
self.ordered_inputs_and_attrs.append(attr)
|
| 290 |
+
|
| 291 |
+
def debug_print(self):
|
| 292 |
+
if logger.isEnabledFor(logging.DEBUG):
|
| 293 |
+
st = io.StringIO()
|
| 294 |
+
for s in self.stmts:
|
| 295 |
+
for attr in s.attrs:
|
| 296 |
+
if attr.attr_proto.HasField("g"):
|
| 297 |
+
st.write(helper.printable_graph(attr.attr_proto.g))
|
| 298 |
+
st.write("\n")
|
| 299 |
+
|
| 300 |
+
def add_called_function(self, fun: values.OnnxFunction) -> None:
|
| 301 |
+
for name, fct in fun.function_ir.called_functions.items():
|
| 302 |
+
if name in self.called_functions:
|
| 303 |
+
continue
|
| 304 |
+
self.called_functions[name] = fct
|
| 305 |
+
if fun.name in self.called_functions:
|
| 306 |
+
# Already added.
|
| 307 |
+
return
|
| 308 |
+
try:
|
| 309 |
+
proto = fun.to_function_proto()
|
| 310 |
+
except (TypeError, AttributeError) as e:
|
| 311 |
+
raise TypeError(f"Issue with type f{type(fun)}.") from e
|
| 312 |
+
self.called_functions[fun.name] = proto
|
| 313 |
+
|
| 314 |
+
def add_nested_function(self, fun: IRFunction) -> None:
|
| 315 |
+
self.nested_functions[fun.name] = fun
|
| 316 |
+
|
| 317 |
+
def to_model_proto(
|
| 318 |
+
self,
|
| 319 |
+
functions=None,
|
| 320 |
+
io_types: Optional[ONNXType] = None,
|
| 321 |
+
input_types: Optional[Sequence[ONNXType]] = None,
|
| 322 |
+
output_types: Optional[Sequence[ONNXType]] = None,
|
| 323 |
+
value_infos: dict[str, ONNXType] | None = None,
|
| 324 |
+
**kwargs,
|
| 325 |
+
) -> onnx.ModelProto:
|
| 326 |
+
"""Converts this instance into a `onnx.ModelProto`.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
functions: A list of functions to include in the model.
|
| 330 |
+
By default, all functions called at least once are included.
|
| 331 |
+
io_types: When specified, all the inputs/outputs of the model
|
| 332 |
+
are set to be of this type.
|
| 333 |
+
input_types: When specified, all the inputs of the model
|
| 334 |
+
are set to be of the corresponding type in this list.
|
| 335 |
+
output_types: When specified, all the outputs of the model
|
| 336 |
+
are set to be of the corresponding type in this list.
|
| 337 |
+
value_infos: A dictionary mapping intermediate variable names to ONNX types.
|
| 338 |
+
Used to set value_info for intermediate variables.
|
| 339 |
+
kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
An instance of :class:`onnx.ModelProto`.
|
| 343 |
+
"""
|
| 344 |
+
value_infos = (
|
| 345 |
+
[
|
| 346 |
+
onnx.helper.make_value_info(name, type.to_type_proto())
|
| 347 |
+
for name, type in value_infos.items()
|
| 348 |
+
]
|
| 349 |
+
if value_infos
|
| 350 |
+
else None
|
| 351 |
+
)
|
| 352 |
+
graph, sub_functions = self.to_graph_and_functions(
|
| 353 |
+
use_default_type=False, value_infos=value_infos
|
| 354 |
+
)
|
| 355 |
+
if io_types is not None:
|
| 356 |
+
for input in graph.input:
|
| 357 |
+
if not input.HasField("type"):
|
| 358 |
+
input.type.CopyFrom(io_types.to_type_proto())
|
| 359 |
+
for output in graph.output:
|
| 360 |
+
if not output.HasField("type"):
|
| 361 |
+
output.type.CopyFrom(io_types.to_type_proto())
|
| 362 |
+
if input_types is not None:
|
| 363 |
+
for input, type in zip(graph.input, input_types):
|
| 364 |
+
input.type.CopyFrom(type.to_type_proto())
|
| 365 |
+
if output_types is not None:
|
| 366 |
+
for output, type in zip(graph.output, output_types):
|
| 367 |
+
output.type.CopyFrom(type.to_type_proto())
|
| 368 |
+
if functions is None:
|
| 369 |
+
functions = sub_functions.values()
|
| 370 |
+
else:
|
| 371 |
+
|
| 372 |
+
def to_proto(f):
|
| 373 |
+
if isinstance(f, onnx.FunctionProto):
|
| 374 |
+
return f
|
| 375 |
+
if isinstance(f, onnxscript.OnnxFunction):
|
| 376 |
+
return f.to_function_proto()
|
| 377 |
+
raise TypeError("Expected a value of type FunctionProto of OnnxFunction")
|
| 378 |
+
|
| 379 |
+
functions = [to_proto(f) for f in functions]
|
| 380 |
+
|
| 381 |
+
opsets = {}
|
| 382 |
+
for n in self.stmts:
|
| 383 |
+
if n.callee.opset.domain not in opsets:
|
| 384 |
+
opsets[n.callee.opset.domain] = n.callee.opset.version
|
| 385 |
+
|
| 386 |
+
for proto in functions:
|
| 387 |
+
if proto.domain not in opsets:
|
| 388 |
+
opsets[proto.domain] = 1
|
| 389 |
+
# TODO(rama): Handle conflicts with appropriate error/warning message.
|
| 390 |
+
for opset in proto.opset_import:
|
| 391 |
+
if opset.domain not in opsets:
|
| 392 |
+
opsets[opset.domain] = opset.version
|
| 393 |
+
|
| 394 |
+
if "" not in opsets:
|
| 395 |
+
# No operator is using the standard opset.
|
| 396 |
+
# A default value is given.
|
| 397 |
+
opsets[""] = onnx_opset_version()
|
| 398 |
+
|
| 399 |
+
if "ir_version" not in kwargs:
|
| 400 |
+
kwargs["ir_version"] = select_ir_version(opsets[""])
|
| 401 |
+
opset_imports = [
|
| 402 |
+
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
|
| 403 |
+
]
|
| 404 |
+
|
| 405 |
+
return helper.make_model(
|
| 406 |
+
graph, opset_imports=opset_imports, functions=functions, **kwargs
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def to_graph_and_functions(
|
| 410 |
+
self,
|
| 411 |
+
use_default_type: bool = True,
|
| 412 |
+
value_infos: Sequence[ValueInfoProto] | None = None,
|
| 413 |
+
) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
|
| 414 |
+
"""Converts this instance into a `onnx.GraphProto` and a map from
|
| 415 |
+
function-name to `onnx.FunctionProto`.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
use_default_type: if True, the function uses a default type
|
| 419 |
+
for inputs and outputs that do not have a type
|
| 420 |
+
value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
|
| 421 |
+
to the graph.
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto`
|
| 425 |
+
"""
|
| 426 |
+
called_functions: dict[str, onnx.FunctionProto] = {}
|
| 427 |
+
for s in self.stmts:
|
| 428 |
+
called_functions.update(s.functions)
|
| 429 |
+
called_functions.update(self.called_functions)
|
| 430 |
+
graph = helper.make_graph(
|
| 431 |
+
[s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)],
|
| 432 |
+
self.name,
|
| 433 |
+
[x.to_value_info(use_default_type) for x in self.inputs],
|
| 434 |
+
[y.to_value_info(use_default_type) for y in self.outputs],
|
| 435 |
+
value_info=value_infos,
|
| 436 |
+
)
|
| 437 |
+
return graph, called_functions
|
| 438 |
+
|
| 439 |
+
def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
|
| 440 |
+
"""Converts this instance into a `onnx.GraphProto`.
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
use_default_type: if True, the function uses a default type
|
| 444 |
+
for inputs and outputs that do not have a type
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
an instance of :class:`onnx.GraphProto`
|
| 448 |
+
"""
|
| 449 |
+
graph, _ = self.to_graph_and_functions(use_default_type=use_default_type)
|
| 450 |
+
return graph
|
| 451 |
+
|
| 452 |
+
def get_opset_import(self) -> dict[str, int]:
|
| 453 |
+
func_opset_imports = {}
|
| 454 |
+
for s in self.stmts:
|
| 455 |
+
if s.callee.opset.domain not in func_opset_imports:
|
| 456 |
+
func_opset_imports[s.callee.opset.domain] = s.callee.opset.version
|
| 457 |
+
elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version:
|
| 458 |
+
warnings.warn(
|
| 459 |
+
f"There is a version conflict in domain: {s.callee.opset.domain!r}, "
|
| 460 |
+
f"with {self.name!r}.",
|
| 461 |
+
category=UserWarning,
|
| 462 |
+
stacklevel=1,
|
| 463 |
+
)
|
| 464 |
+
return func_opset_imports
|
| 465 |
+
|
| 466 |
+
def to_function_proto(self) -> onnx.FunctionProto:
|
| 467 |
+
"""Converts this instance into a `onnx.FunctionProto`.
|
| 468 |
+
|
| 469 |
+
Note: Default values for attributes are an experimental feature in ONNX.
|
| 470 |
+
Conversion ignores default values for attributes if the ONNX version installed
|
| 471 |
+
doesn't support it.
|
| 472 |
+
"""
|
| 473 |
+
opsets = self.get_opset_import()
|
| 474 |
+
nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)]
|
| 475 |
+
for n in nodes:
|
| 476 |
+
if n.domain not in opsets:
|
| 477 |
+
opsets[n.domain] = 1 # TODO: how to get n.version?
|
| 478 |
+
opset_imports = [
|
| 479 |
+
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
|
| 480 |
+
]
|
| 481 |
+
|
| 482 |
+
attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
|
| 483 |
+
|
| 484 |
+
f = helper.make_function(
|
| 485 |
+
self.domain,
|
| 486 |
+
self.name,
|
| 487 |
+
inputs=[x.name for x in self.inputs],
|
| 488 |
+
outputs=[y.name for y in self.outputs],
|
| 489 |
+
nodes=nodes,
|
| 490 |
+
opset_imports=opset_imports, # TODO
|
| 491 |
+
attributes=attribute_names,
|
| 492 |
+
doc_string=self.docstring,
|
| 493 |
+
)
|
| 494 |
+
# In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead
|
| 495 |
+
if hasattr(f, "attribute_proto"):
|
| 496 |
+
f.attribute_proto.extend(
|
| 497 |
+
[attr.attr_proto for attr in self.attrs if attr.has_default]
|
| 498 |
+
)
|
| 499 |
+
return f
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# IRBuilder: abstracts out details of the IR in the python-to-IR converter
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class IRBuilder:
|
| 506 |
+
def __init__(self):
|
| 507 |
+
self.functions = {}
|
| 508 |
+
|
| 509 |
+
def new_function(self, name: str, domain: str = "", register: bool = False) -> IRFunction:
|
| 510 |
+
if register and (domain, name) in self.functions:
|
| 511 |
+
raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
|
| 512 |
+
function = IRFunction(name, domain)
|
| 513 |
+
if register:
|
| 514 |
+
self.functions[domain, name] = function
|
| 515 |
+
return function
|
| 516 |
+
|
| 517 |
+
def add_docstring(self, fn: IRFunction, docstring: str):
|
| 518 |
+
fn.append_docstring(docstring)
|
| 519 |
+
|
| 520 |
+
def add_stmt(
|
| 521 |
+
self,
|
| 522 |
+
fn: IRFunction,
|
| 523 |
+
results: Sequence[str],
|
| 524 |
+
callee: values.Op,
|
| 525 |
+
args: Sequence[Optional[str]],
|
| 526 |
+
attrs: Sequence[IRAttributeValue],
|
| 527 |
+
sub_functions=None,
|
| 528 |
+
) -> None:
|
| 529 |
+
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
|
| 530 |
+
fn.append_stmt(stmt)
|
| 531 |
+
|
| 532 |
+
def add_input(
|
| 533 |
+
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
|
| 534 |
+
) -> None:
|
| 535 |
+
var = IRVar(varname, type, info)
|
| 536 |
+
fn.append_input(var)
|
| 537 |
+
|
| 538 |
+
def add_attr_parameter(
|
| 539 |
+
self,
|
| 540 |
+
fn: IRFunction,
|
| 541 |
+
varname: str,
|
| 542 |
+
attribute_type: onnx.AttributeProto.AttributeType,
|
| 543 |
+
default_value: int | float | str | None,
|
| 544 |
+
) -> None:
|
| 545 |
+
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
|
| 546 |
+
|
| 547 |
+
def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
|
| 548 |
+
var = IRVar(varname, typeinfo, sourceinfo)
|
| 549 |
+
fn.append_output(var)
|
| 550 |
+
|
| 551 |
+
def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
|
| 552 |
+
return IRAttributeValue(attrproto)
|
| 553 |
+
|
| 554 |
+
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
|
| 555 |
+
proto = onnx.AttributeProto()
|
| 556 |
+
proto.name = attrname
|
| 557 |
+
proto.ref_attr_name = refname
|
| 558 |
+
attr_type = ta.pytype_to_attrtype(pytype)
|
| 559 |
+
assert attr_type is not None
|
| 560 |
+
proto.type = attr_type
|
| 561 |
+
return IRAttributeValue(proto)
|
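
For orientation, here is a minimal sketch of how the builder API above fits together. It is illustrative only, not part of the packaged file: it assumes a standard onnxscript install, and it fabricates a trivial `SourceInfo` from a dummy AST node (real converters point it at the user's source).

```python
import ast

import onnx
from onnxscript import values
from onnxscript.irbuilder import IRBuilder, IRTensorType
from onnxscript.sourceinfo import SourceInfo

builder = IRBuilder()
fn = builder.new_function("identity_fn", domain="", register=True)

# SourceInfo normally points at the converted Python source; any AST node
# with position info is enough for a sketch.
info = SourceInfo(ast.parse("x = 0").body[0])
float_type = IRTensorType(onnx.TensorProto.FLOAT)

builder.add_input(fn, "X", float_type, info)
identity = values.Op(values.Opset("", 18), "Identity")  # default ONNX domain, opset 18
builder.add_stmt(fn, ["Y"], identity, ["X"], [])        # Y = Identity(X)
builder.add_output(fn, "Y", float_type, info)

# to_model_proto collects opset imports and picks an ir_version automatically.
model = fn.to_model_proto(producer_name="irbuilder-sketch")
onnx.checker.check_model(model)
```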
pythonProject/.venv/Lib/site-packages/onnxscript/main.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=protected-access
from __future__ import annotations

import ast
import inspect
import sys
from typing import Any, Callable, Optional, Sequence, TypeVar

from typing_extensions import ParamSpec

import onnxscript
from onnxscript import converter, ir, irbuilder, values
from onnxscript._internal import ast_utils

_R = TypeVar("_R")
_P = ParamSpec("_P")


def script_check(
    f: ast.FunctionDef,
    opset: values.Opset,
    global_names: dict[str, Any],
    source: str,
    default_opset: Optional[values.Opset] = None,
) -> irbuilder.IRFunction:
    """Check that a function falls into the ONNXScript subset of Python."""
    # See if conversion succeeds.
    # TODO: cleanup Converter interface/API, separating checker from
    # converter
    convert = converter.Converter(
        opset=opset,
        global_names=global_names,
        source=source,
        default_opset=default_opset,
    )
    return convert.translate_function_def(f)


def script(
    opset: Optional[values.Opset] = None,
    default_opset: Optional[values.Opset] = None,
    **kwargs: Any,
) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]:
    """Main decorator. Declares a function as an onnx function.

    Args:
        opset: Opset the function belongs to (see :ref:`l-api-opsets`).
        default_opset: Opset to use for operators not in the function's opset.
        kwargs: Additional keyword arguments.

    Returns:
        an instance of :class:`onnxscript.values.OnnxFunction`

    Example:
    ::

        @script()
        def log2(x):
            one = op.Constant(value=make_tensor('one', TensorProto.FLOAT, [1], [1]))
            return op.Div(op.Log(x), op.CastLike(op.Log(one), x))

    Or:

    ::

        from onnxscript.onnx_opset import opset16

        @script(opset16)
        def log2(x):
            one = op.Constant(value=make_tensor('one', TensorProto.FLOAT, [1], [1]))
            return op.Div(op.Log(x), op.CastLike(op.Log(one), x))
    """
    opset = opset or values.Opset("this", 1)
    if not isinstance(opset, values.Opset):
        raise TypeError(
            "Script parameter must be an opset. Did you use @script instead of @script()?"
        )

    def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]:
        if not inspect.isfunction(f):
            raise TypeError("The ONNXScript decorator should be applied to functions only.")

        src, f_ast = ast_utils.get_src_and_ast(f)
        # The script should be compiled using the globals/locals at the definition site.
        # This allows the script to reference names defined outside the script,
        # which is used for a few different purposes.
        # The following is an approximate solution that works for normal use.
        module = inspect.getmodule(f)
        closure = inspect.getclosurevars(f)
        env = module.__dict__.copy()
        env.update(closure.nonlocals)
        result = script_check(f_ast, opset, env, src, default_opset=default_opset)
        # TODO: add transformations.
        return onnxscript.OnnxFunction(opset, f, result, src, kwargs)

    return transform


def graph() -> Callable[[Callable], values.OnnxClosure]:
    """A parametric decorator used to annotate nested-functions that are used
    as graph-attributes.

    Returns:
        A decorator that returns its input function, but attaches a graph_proto
        attribute representing the input function. The translation is not
        done at this time, but previously when the outer-level function
        was translated to an OnnxFunction. The decorator just looks up
        and retrieves the GraphProto representation previously generated.

    Example:
    ::

        @script()
        def cumulative_sum(X: INT64['N']):

            # Translation of cumulative_sum by @script will also translate Sum
            # into a GraphProto, which will be stored in the OnnxFunction generated
            # for cumulative_sum. At run-time (in eager-mode), the @graph decorator
            # retrieves the pre-computed GraphProto and attaches it to the Sum function.
            @graph()
            def Sum(sum_in, next):
                sum_out = sum_in + next
                scan_out = op.Identity(sum_out)
                return sum_out, scan_out

            zero = op.Constant(value_int=0)
            # The call to higher-order operator Scan below uses the above function
            # Sum as a graph-attribute.
            all_sum, result = op.Scan(zero, X, body=Sum, num_scan_inputs=1)
            return result

    """
    # This is a bit fragile. We want to get the OnnxFunction object representing
    # the outer-scope ONNXScript function from the execution stack. The caller of
    # @graph is the original script function (cumulative_sum in the above example),
    # and the caller of that function is the wrapper function/method in the
    # corresponding OnnxFunction object.
    # Currently, there is no support for eager-mode execution of nested functions,
    # so we don't need to handle doubly nested functions (e.g., a function defined
    # inside Sum in the above example).

    function_frame = sys._getframe(1)  # pylint: disable=protected-access
    wrapper_frame = sys._getframe(3)  # pylint: disable=protected-access
    onnx_function = wrapper_frame.f_locals["self"]
    nested_functions = onnx_function.function_ir.nested_functions

    def transform(f: Callable) -> values.OnnxClosure:
        return values.OnnxClosure(nested_functions[f.__name__], function_frame, f)

    return transform


def is_converted_fun(f: Any) -> bool:
    """Return True if f is a function converted by onnxscript decorator."""
    return isinstance(f, onnxscript.OnnxFunction)


def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) -> None:
    # Since we don't yet have LibProto defined, we use a ModelProto as a temporary
    # container for the list of functions exported as a library, with an empty graph
    # and dummy opset_imports.

    # TODO(justinchuby): This function is not well supported. We should consider removing it
    model = ir.Model(
        ir.Graph(
            inputs=[],
            outputs=[],
            nodes=[],
            opset_imports={"": 15},
        ),
        functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions],
        ir_version=10,
        producer_name="p2o",
    )
    ir.save(model, filename)
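
A short usage sketch for the `@script()` decorator above. It is illustrative, not from the package; the eager call at the end additionally assumes the default evaluator backend (typically onnxruntime) is installed.

```python
import numpy as np

from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def double(x: FLOAT["N"]) -> FLOAT["N"]:
    # Body is translated to ONNX at decoration time.
    return op.Add(x, x)


# The decorated object is an OnnxFunction: it can be exported as a model...
model = double.to_model_proto()
# ...or executed eagerly on numpy inputs for debugging.
print(double(np.array([1.0, 2.0], dtype=np.float32)))  # [2. 4.]
```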
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_types.py
ADDED
@@ -0,0 +1,237 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

import abc
from typing import ClassVar, Optional, Tuple, Union

import onnx
import onnx_ir as ir

_DType = ir.DataType
_DimType = Union[int, str, type(None)]
_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]

_tensor_type_shape_cache: dict[_DType, TensorType] = {}
tensor_type_registry: dict[_DType, TensorType] = {}


def _check_dim(dim):
    if not isinstance(dim, (int, str, type(None))):
        raise TypeError(f"Invalid dimension {dim}")


def _check_shape(shape):
    if isinstance(shape, tuple):
        for dim in shape:
            _check_dim(dim)
    elif shape != Ellipsis:
        _check_dim(shape)


class TensorType(abc.ABC):
    """ONNX Script representation of a tensor type supporting shape annotations.

    A scalar-tensor of rank 0:
    ::

        tensor: FLOAT

    A tensor of unknown rank:
    ::

        tensor: FLOAT[...]

    A tensor of rank 2 of unknown dimensions, with symbolic names:
    ::

        tensor: FLOAT['M', 'N']

    A tensor of rank 2 of known dimensions:
    ::

        tensor: FLOAT[128, 1024]
    """

    dtype: ClassVar[_DType]
    shape: ClassVar[Optional[_ShapeType]]

    def __new__(cls):
        raise NotImplementedError("TensorTypes cannot be instantiated")

    def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
        cls.dtype = dtype
        cls.shape = shape
        if shape is None:
            existing_cls = tensor_type_registry.get(dtype)
            if existing_cls is not None:
                raise ValueError(
                    f"Invalid usage: subclass {existing_cls!r} "
                    f"already defined for dtype={dtype}"
                )
            tensor_type_registry[dtype] = cls
        else:
            _check_shape(shape)

    def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
        if cls.shape is not None:
            raise ValueError("Invalid usage: shape already specified.")
        if shape is None:
            # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
            shape = (None,)
        key = (cls.dtype, shape)
        shaped_type = _tensor_type_shape_cache.get(key)
        if shaped_type is None:
            shaped_type = type(cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape)
            _tensor_type_shape_cache[key] = shaped_type
        return shaped_type

    @classmethod
    def to_type_proto(cls) -> onnx.TypeProto:
        if cls.shape is None:
            shape = ()  # "FLOAT" is treated as a scalar
        elif cls.shape is Ellipsis:
            shape = None  # "FLOAT[...]" is a tensor of unknown rank
        elif isinstance(cls.shape, tuple):
            shape = cls.shape  # example: "FLOAT[10,20]"
        else:
            shape = [cls.shape]  # example: "FLOAT[10]"
        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)  # noqa: TID251

    @classmethod
    def to_string(cls) -> str:
        return f"tensor({cls.__name__.lower()})"


class FLOAT(TensorType, dtype=ir.DataType.FLOAT):
    pass


class UINT8(TensorType, dtype=ir.DataType.UINT8):
    pass


class INT8(TensorType, dtype=ir.DataType.INT8):
    pass


class UINT16(TensorType, dtype=ir.DataType.UINT16):
    pass


class INT16(TensorType, dtype=ir.DataType.INT16):
    pass


class INT32(TensorType, dtype=ir.DataType.INT32):
    pass


class INT64(TensorType, dtype=ir.DataType.INT64):
    pass


class STRING(TensorType, dtype=ir.DataType.STRING):
    pass


class BOOL(TensorType, dtype=ir.DataType.BOOL):
    pass


class FLOAT16(TensorType, dtype=ir.DataType.FLOAT16):
    pass


class DOUBLE(TensorType, dtype=ir.DataType.DOUBLE):
    pass


class UINT32(TensorType, dtype=ir.DataType.UINT32):
    pass


class UINT64(TensorType, dtype=ir.DataType.UINT64):
    pass


class COMPLEX64(TensorType, dtype=ir.DataType.COMPLEX64):
    pass


class COMPLEX128(TensorType, dtype=ir.DataType.COMPLEX128):
    pass


class BFLOAT16(TensorType, dtype=ir.DataType.BFLOAT16):
    pass


class FLOAT8E4M3FN(TensorType, dtype=ir.DataType.FLOAT8E4M3FN):
    pass


class FLOAT8E4M3FNUZ(TensorType, dtype=ir.DataType.FLOAT8E4M3FNUZ):
    pass


class FLOAT8E5M2(TensorType, dtype=ir.DataType.FLOAT8E5M2):
    pass


class FLOAT8E5M2FNUZ(TensorType, dtype=ir.DataType.FLOAT8E5M2FNUZ):
    pass


class INT4(TensorType, dtype=ir.DataType.INT4):
    pass


class UINT4(TensorType, dtype=ir.DataType.UINT4):
    pass


class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
    pass


def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
    """Converts an onnx type into the string representation of the type in *onnxscript*.

    Args:
        onnx_type: an instance of onnx TypeProto
        reversible: if True, the conversion produces only types that are
            recognized by the onnxscript converter.

    Returns:
        The string representation of the type in onnxscript

    Raises:
        ...
    """
    if onnx_type.HasField("tensor_type"):
        elem_type = onnx_type.tensor_type.elem_type
        name = onnx.TensorProto.DataType.Name(elem_type)
        if onnx_type.tensor_type.HasField("shape"):
            shape = []
            for d in onnx_type.tensor_type.shape.dim:
                if d.HasField("dim_value"):
                    shape.append(str(d.dim_value))
                elif d.HasField("dim_param"):
                    shape.append(repr(d.dim_param))
                else:
                    shape.append("None")
            if not shape:
                return name
            return f"{name}[{','.join(shape)}]"
        return f"{name}[...]"
    if not reversible:
        if onnx_type.HasField("sequence_type"):
            elem_type = onnx_type.sequence_type.elem_type
            return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]"
    raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")


# Currently, only tensor types are supported. Need to expand support for other ONNX types.
ONNXType = TensorType
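
To illustrate the class machinery above, a small sketch of how the shape-annotated types behave (illustrative only; assumes onnxscript is importable):

```python
from onnxscript.onnx_types import FLOAT, INT64, onnx_type_to_onnxscript_repr

# Subscripting creates (and caches) a shaped subclass; the class itself
# is never instantiated.
t = FLOAT["M", "N"]                       # rank-2 tensor with symbolic dims
proto = t.to_type_proto()                  # onnx.TypeProto with dim_param M, N
print(proto.tensor_type.elem_type)         # 1 == TensorProto.FLOAT
print(onnx_type_to_onnxscript_repr(proto)) # FLOAT['M','N']

scalar = INT64.to_type_proto()             # no shape -> rank-0 scalar
print(len(scalar.tensor_type.shape.dim))   # 0
```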
pythonProject/.venv/Lib/site-packages/onnxscript/py.typed
ADDED
@@ -0,0 +1 @@
# Marker file for PEP-561 (inline types)
pythonProject/.venv/Lib/site-packages/onnxscript/sourceinfo.py
ADDED
@@ -0,0 +1,59 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

"""Source code information used for diagnostic messages."""

from __future__ import annotations

import ast
from typing import Callable, Optional


class SourceInfo:
    """Information about onnxscript source fragment, used for diagnostic messages."""

    def __init__(
        self,
        ast_node: ast.AST,
        code: Optional[str] = None,
        function_name: Optional[str] = None,
    ):
        self.ast_node = ast_node
        self.code = code
        self.function_name = function_name

    @property
    def lineno(self):
        return self.ast_node.lineno

    def msg(self, error_message: str) -> str:
        lineno = self.lineno
        if self.function_name:
            source_loc = f"Function '{self.function_name}', line {lineno}"
        else:
            source_loc = f"Line {lineno}"

        if self.code:
            lines = self.code.split("\n")
            line = lines[lineno - 1]
            marker_prefix = " " * (self.ast_node.col_offset)
            source_line = f"{line}\n{marker_prefix}^\n"
        else:
            source_line = ""

        return f"ERROR: {error_message}\nat: {source_loc}\n{source_line}"

    def __str__(self) -> str:
        raise ValueError("Cannot happen!")


Formatter = Callable[[ast.AST, str], str]


def formatter(source_code: Optional[str]) -> Formatter:
    def format(node: ast.AST, message: str) -> str:
        return SourceInfo(node, source_code).msg(message)

    return format
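
A quick sketch of the diagnostic formatting above (illustrative only):

```python
import ast

from onnxscript.sourceinfo import SourceInfo

code = "y = op.Abs(x)"
node = ast.parse(code).body[0]  # the assignment statement, lineno 1, col 0
info = SourceInfo(node, code=code, function_name="f")
print(info.msg("undefined variable 'x'"))
# ERROR: undefined variable 'x'
# at: Function 'f', line 1
# y = op.Abs(x)
# ^
```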
pythonProject/.venv/Lib/site-packages/onnxscript/tensor.py
ADDED
@@ -0,0 +1,225 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Optional

import numpy as np

from onnxscript import ir, onnx_opset
from onnxscript._internal import autocast


class Tensor:
    """An implementation of ONNX Tensors, based on a wrapper around numpy arrays.
    Serves to define overloaded ops with an ONNX/ONNXScript semantics.
    """

    def __init__(self, nparray: Optional[np.ndarray], opset=None):
        if nparray is not None and not isinstance(nparray, np.ndarray):
            raise TypeError(
                f"Unexpected type {type(nparray)}. It must be a numpy array or None."
            )

        self._nparray = nparray
        # FIXME(justinhuby): Create a better way to determine the opset version
        self._opset: Any = opset or onnx_opset.opset18

    @property
    def value(self) -> np.ndarray:
        if self._nparray is None:
            raise ValueError("Tensor does not have a value.")
        return self._nparray

    @property
    def rank(self) -> int:
        return len(self.value.shape)

    @property
    def is_scalar(self) -> bool:
        return self.rank == 0

    @property
    def shape(self) -> tuple[int, ...]:
        return self.value.shape

    @property
    def dtype(self) -> np.dtype:
        return self.value.dtype

    @property
    def onnx_dtype(self) -> int:
        return ir.DataType.from_numpy(self.dtype)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def __bool__(self) -> bool:
        return bool(self.value)

    def __int__(self) -> int:
        return int(self.value)

    def __float__(self) -> float:
        return float(self.value)

    def __len__(self) -> int:
        return self.shape[0]

    def __index__(self) -> int:
        return self.value.__index__()

    def __getitem__(self, index):
        op = self._opset
        if op.version < 13:
            raise RuntimeError("Indexing requires opset 13 or later.")
        if not isinstance(index, tuple):
            # Normalize representation to a tuple.
            # A single index-value is equivalent to a tuple with a single element.
            index = (index,)
        if len(index) > self.rank:
            raise ValueError(
                f"Number of indices {len(index)} is greater than rank {self.rank}"
            )

        # Promote integer indices to tensors of rank 0
        index = [autocast.cast_pyvalue_to_os_tensor(x) for x in index]
        # Process all elements in index
        shape = self.shape
        sliced_indices = []
        scalar_indices = []
        to_squeeze = []
        non_scalar_indices = []
        for axis_, s in enumerate(index):
            if isinstance(s, slice):
                if s.start is None and s.stop is None and s.step is None:
                    continue
                if s.step is None or s.step > 0:
                    sliced_indices.append(
                        [
                            s.start or 0,
                            s.stop if s.stop is not None else shape[axis_],
                            axis_,
                            s.step or 1,
                        ]
                    )
                else:
                    sliced_indices.append(
                        [
                            s.start if s.start is not None else (shape[axis_] - 1),
                            s.stop if s.stop is not None else -(shape[axis_] + 1),
                            axis_,
                            s.step,
                        ]
                    )
            elif isinstance(s, Tensor):
                if s.is_scalar:
                    scalar_indices.append([s, s + 1, axis_, 1])
                    to_squeeze.append(axis_)
                else:
                    non_scalar_indices.append((axis_, s))
            else:
                raise TypeError(f"Unexpected type {type(s)}: slice or int expected.")

        # Non-scalar-indexing requires the use of ONNX Gather operation.
        # Slicing can be implemented efficiently using ONNX's Slice operation.
        # Scalar-indexing can be implemented using either Gather or with the Slice operation.
        # We map scalar-indexing into the Slice operation, except in the special case
        # of a single scalar-index (with no other sliced_index), which we map directly
        # to a Gather.

        if not (sliced_indices or scalar_indices or non_scalar_indices):
            # Edge case: no index specified. Eg. A[:, :]
            return op.Identity(self)
        if not sliced_indices and len(scalar_indices) == 1:
            # Special case of indexing along a single axis: A[i], A[:, i], A[:, :, i] etc.
            # promote integer input to tensor
            axis = to_squeeze[0]
            index_value = index[axis]
            # use Gather to perform indexing
            result = op.Gather(self, index_value, axis=axis)
        elif sliced_indices or scalar_indices:
            sliced_indices = sliced_indices + scalar_indices
            indices = np.array(sliced_indices, dtype=np.int64).T
            starts = Tensor(indices[0])
            ends = Tensor(indices[1])
            axes = Tensor(indices[2])
            steps = Tensor(indices[3])
            result = op.Slice(self, starts, ends, axes, steps)
            if to_squeeze:
                result = Tensor(np.squeeze(result.value, axis=tuple(to_squeeze)))
        else:
            result = self
        for axis, value in non_scalar_indices:
            result = op.Gather(result, value, axis=axis)

        return result

    def __mod__(self, other):
        if self.onnx_dtype in {
            ir.DataType.FLOAT,
            ir.DataType.DOUBLE,
            ir.DataType.FLOAT16,
            ir.DataType.BFLOAT16,
        }:
            return self._opset.Mod(self, other, fmod=1)
        return self._opset.Mod(self, other)

    def __ne__(self, other):
        temp = self._opset.Equal(self, other)
        return self._opset.Not(temp)

    def __neg__(self):
        return self._opset.Neg(self)

    def __add__(self, other):
        return self._opset.Add(self, other)

    def __radd__(self, other):
        return self._opset.Add(other, self)

    def __and__(self, other):
        return self._opset.And(self, other)

    def __rand__(self, other):
        return self._opset.And(other, self)

    def __mul__(self, other):
        return self._opset.Mul(self, other)

    def __rmul__(self, other):
        return self._opset.Mul(other, self)

    def __matmul__(self, other):
        return self._opset.MatMul(self, other)

    def __or__(self, other):
        return self._opset.Or(self, other)

    def __pow__(self, other):
        return self._opset.Pow(self, other)

    def __sub__(self, other):
        return self._opset.Sub(self, other)

    def __rsub__(self, other):
        return self._opset.Sub(other, self)

    def __truediv__(self, other):
        return self._opset.Div(self, other)

    def __lt__(self, other):
        return self._opset.Less(self, other)

    def __le__(self, other):
        return self._opset.LessOrEqual(self, other)

    def __eq__(self, other):
        return self._opset.Equal(self, other)

    def __ge__(self, other):
        return self._opset.GreaterOrEqual(self, other)

    def __gt__(self, other):
        return self._opset.Greater(self, other)
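
A minimal sketch of the eager-mode semantics the wrapper above enables (illustrative only; the overloaded operators dispatch through the default opset's evaluator, which typically requires onnxruntime to be installed):

```python
import numpy as np

from onnxscript.tensor import Tensor

a = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
b = Tensor(np.array([[10.0, 20.0], [30.0, 40.0]], dtype=np.float32))

# Python operators lower to ONNX ops (Add, MatMul, ...), so eager
# execution mirrors the semantics a converted graph would have.
print((a + b).value)   # elementwise Add
print((a @ b).value)   # MatMul
print(a[0].value)      # single scalar index lowers to ONNX Gather
```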
pythonProject/.venv/Lib/site-packages/onnxscript/type_annotation.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import collections
|
| 6 |
+
import inspect
|
| 7 |
+
import typing
|
| 8 |
+
from typing import Optional, Sequence, Union
|
| 9 |
+
|
| 10 |
+
import onnx
|
| 11 |
+
|
| 12 |
+
from onnxscript import onnx_types
|
| 13 |
+
from onnxscript._internal import version_utils
|
| 14 |
+
|
| 15 |
+
# TypeAnnotationValue represents the (value of) valid type-annotations recognized
|
| 16 |
+
# by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
|
| 17 |
+
# - float, int, str (primitive attribute types)
|
| 18 |
+
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
|
| 19 |
+
# - Tensor types
|
| 20 |
+
# - Sequence[Tensor] types
|
| 21 |
+
# - Union of above 2
|
| 22 |
+
# - TypeVars with above bounds
|
| 23 |
+
# - Above types with annotation attached
|
| 24 |
+
TypeAnnotationValue = typing.Any
|
| 25 |
+
|
| 26 |
+
# Map from python type to corresponding ONNX AttributeProto type
|
| 27 |
+
_PYTYPE_TO_ATTRTYPE_MAP = {
|
| 28 |
+
float: onnx.AttributeProto.FLOAT,
|
| 29 |
+
int: onnx.AttributeProto.INT,
|
| 30 |
+
str: onnx.AttributeProto.STRING,
|
| 31 |
+
bool: onnx.AttributeProto.INT, # experimental
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# Map from python type to corresponding ONNX AttributeProto type,
|
| 35 |
+
# for repeated (i.e., list of) values
|
| 36 |
+
_LISTTYPE_TO_ATTRTYPE_MAP = {
|
| 37 |
+
float: onnx.AttributeProto.FLOATS,
|
| 38 |
+
int: onnx.AttributeProto.INTS,
|
| 39 |
+
str: onnx.AttributeProto.STRINGS,
|
| 40 |
+
bool: onnx.AttributeProto.INTS, # experimental
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])
|
| 44 |
+
|
| 45 |
+
# Map from ONNX AttributeProto type to its representation (in ONNX Script).
|
| 46 |
+
_ATTRTYPE_TO_REPR = {
|
| 47 |
+
onnx.AttributeProto.FLOAT: "float",
|
| 48 |
+
onnx.AttributeProto.INT: "int",
|
| 49 |
+
onnx.AttributeProto.STRING: "str",
|
| 50 |
+
onnx.AttributeProto.FLOATS: "Sequence[float]",
|
| 51 |
+
onnx.AttributeProto.INTS: "Sequence[int]",
|
| 52 |
+
onnx.AttributeProto.STRINGS: "Sequence[str]",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
|
| 57 |
+
if attr_type not in _ATTRTYPE_TO_REPR:
|
| 58 |
+
supported = ", ".join(
|
| 59 |
+
f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
|
| 60 |
+
)
|
| 61 |
+
raise ValueError(f"Unsupported attribute type {attr_type}: only {supported} allowed.")
|
| 62 |
+
return _ATTRTYPE_TO_REPR[attr_type]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# A sorted list of all type strings used in an OpSchema
|
| 66 |
+
ALL_TENSOR_TYPE_STRINGS = tuple(
|
| 67 |
+
sorted(
|
| 68 |
+
tensor_type.to_string()
|
| 69 |
+
for tensor_type in onnx_types.tensor_type_registry.values()
|
| 70 |
+
# Skip FLOAT4E2M1 for versions older than 1.18
|
| 71 |
+
# TODO(after onnx requirement bump): Remove this check
|
| 72 |
+
if not (version_utils.onnx_older_than("1.18") and tensor_type == onnx_types.FLOAT4E2M1)
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
|
| 78 |
+
"""Remove Annotated wrapper if present, otherwise return typeinfo as is."""
|
| 79 |
+
if hasattr(typing, "Annotated"):
|
| 80 |
+
# Present in Python 3.9+
|
| 81 |
+
if typing.get_origin(typeinfo) is typing.Annotated:
|
| 82 |
+
return typing.get_args(typeinfo)[0]
|
| 83 |
+
return typeinfo
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+

def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
    return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


def pytype_to_attrtype(
    pytype: TypeAnnotationValue,
) -> Optional[onnx.AttributeProto.AttributeType]:
    pytype = _remove_annotation(pytype)
    if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
        return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
    type_constructor = typing.get_origin(pytype)
    # Remove the Optional wrapper if present; it is represented as a Union[..., type(None)]
    if type_constructor is typing.Union:
        # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
        args = [x for x in typing.get_args(pytype) if x is not type(None)]
        if len(args) == 1:
            return pytype_to_attrtype(args[0])
    if type_constructor in _LIST_CONSTRUCTORS:
        elt_type = typing.get_args(pytype)[0]
        if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
            return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
    return None
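
# Minimal sketch of pytype_to_attrtype (assuming the imports above; onnx_types.INT32
# stands for any tensor type from this package):
#   pytype_to_attrtype(int)                      # -> onnx.AttributeProto.INT
#   pytype_to_attrtype(Optional[Sequence[str]])  # -> onnx.AttributeProto.STRINGS
#   pytype_to_attrtype(onnx_types.INT32)         # -> None (a value type, not an attribute)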

def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
    """Returns True if the base type of pytype is bool, False otherwise."""
    pytype = _remove_annotation(pytype)
    if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
        return pytype is bool
    type_constructor = typing.get_origin(pytype)
    if type_constructor in _LIST_CONSTRUCTORS:
        element_type = typing.get_args(pytype)[0]
        return element_type is bool
    # Remove the Optional wrapper if present. Note that typing.get_origin never
    # returns typing.Optional (Optional[X] evaluates to Union[X, type(None)]),
    # so checking against Union is sufficient.
    if type_constructor is Union:
        args = [x for x in typing.get_args(pytype) if x is not type(None)]
        if len(args) == 1:
            return base_type_is_bool(args[0])

    return False
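
# Sketch: base_type_is_bool looks through Sequence and Optional wrappers:
#   base_type_is_bool(bool)            # -> True
#   base_type_is_bool(Sequence[bool])  # -> True
#   base_type_is_bool(Optional[bool])  # -> True
#   base_type_is_bool(int)             # -> False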

def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
    if isinstance(typeinfo, onnx_types.TensorType):
        return True
    if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
        return True
    return False


def is_value_type(typeinfo: TypeAnnotationValue) -> bool:
    """Returns True if typeinfo represents a value type, False if it is an attribute type.

    Raises ValueError if typeinfo is not a supported type annotation.
    """
    typeinfo = _remove_annotation(typeinfo)
    if _is_tensor_type(typeinfo):
        return True
    if _is_primitive_attr_type(typeinfo):
        return False
    type_constructor = typing.get_origin(typeinfo)
    # Handle List-like type-constructors.
    # E.g. List[INT32] is a value type, while List[int] is an attribute type.
    if type_constructor in _LIST_CONSTRUCTORS:
        elt_type = typing.get_args(typeinfo)[0]
        return is_value_type(elt_type)
    # Handle Union and Optional type-constructors
    if type_constructor is typing.Union:
        # Filter out None, since typing.Optional[X] evaluates to Union[X, None]
        args = [x for x in typing.get_args(typeinfo) if x is not type(None)]
        args_value_check = [is_value_type(x) for x in args]
        if all(args_value_check):
            # Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE]
            return True
        elif (len(args) == 1) and args_value_check[0] is False:
            # Handle the case of an optional attribute, e.g. Optional[int].
            # Note that we do not allow Union[int, float] for attributes.
            return False
        else:
            raise ValueError(f"Unsupported type annotation '{typeinfo}'")
    # Handle TypeVars:
    if isinstance(typeinfo, typing.TypeVar):
        if hasattr(typeinfo, "__bound__"):
            bound = typeinfo.__bound__
            return is_value_type(bound)
    raise ValueError(f"Unsupported type annotation {typeinfo}")
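
# Sketch of the value-type vs. attribute-type split (onnx_types.INT32 is used
# here as a representative tensor type):
#   is_value_type(onnx_types.INT32)            # -> True  (tensor input/output)
#   is_value_type(Sequence[onnx_types.INT32])  # -> True  (sequence of tensors)
#   is_value_type(int)                         # -> False (attribute)
#   is_value_type(Optional[int])               # -> False (optional attribute)
#   is_value_type(Union[int, float])           # raises ValueError (mixed
#                                              #   attribute unions are rejected)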

def is_attr_type(pytype: TypeAnnotationValue):
    return is_value_type(pytype) is False


def is_valid_type(typeinfo: TypeAnnotationValue):
    try:
        return is_value_type(typeinfo) in {True, False}
    except ValueError:
        return False
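
# Sketch: both helpers above defer to is_value_type:
#   is_attr_type(int)                # -> True
#   is_valid_type(onnx_types.INT32)  # -> True
#   is_valid_type(dict)              # -> False (is_value_type raises; caught here)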

def is_optional(pytype) -> bool:
    """Returns whether a pytype is an Optional."""
    import types  # stdlib; imported locally since the module header lies above this excerpt

    if typing.get_origin(pytype) is Union and type(None) in typing.get_args(pytype):
        # typing.Optional[X] evaluates to Union[X, type(None)]
        return True
    # Python >= 3.10: "X | None" (PEP 604) unions have origin types.UnionType,
    # not typing.Union; typing.get_origin never returns typing.Optional.
    union_type = getattr(types, "UnionType", None)  # absent before Python 3.10
    if union_type is not None and typing.get_origin(pytype) is union_type:
        return type(None) in typing.get_args(pytype)
    return False
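
# Sketch:
#   is_optional(Optional[int])     # -> True
#   is_optional(Union[str, None])  # -> True
#   is_optional(int)               # -> False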

def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]:
    """Converts return-type annotation into a sequence of types.

    The return type annotation can be either a single type (for a single output)
    or a Tuple type (for multiple outputs). This function normalizes the
    representation so that it is always a sequence of types, even for a single
    output.
    """
    if isinstance(typeinfo, typing.Sequence):
        return typeinfo
    if typing.get_origin(typeinfo) is tuple:
        return typing.get_args(typeinfo)
    return (typeinfo,)
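
# Sketch of the normalization (onnx_types.FLOAT / INT64 are representative
# tensor types; INT64 is assumed to exist in onnx_types like the types above):
#   get_return_types(onnx_types.FLOAT)                           # -> (FLOAT,)
#   get_return_types(typing.Tuple[onnx_types.FLOAT, onnx_types.INT64])
#                                                                # -> (FLOAT, INT64)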

def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]:
    """Returns a list of type-strings corresponding to a given type annotation.

    Args:
        pytype: A type annotation.

    Returns:
        A list of all supported input types for the given type annotation.
        Ensures that the list is sorted in the same order as ALL_TENSOR_TYPE_STRINGS.
    """
    if pytype is None:
        return list(ALL_TENSOR_TYPE_STRINGS)
    if pytype is onnx_types.TensorType:
        return list(ALL_TENSOR_TYPE_STRINGS)
    if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType):
        return [pytype.to_string()]
    if isinstance(pytype, onnx_types.TensorType):
        return [pytype.to_string()]
    if isinstance(pytype, typing.TypeVar):
        constraints = pytype.__constraints__
        if constraints:
            return pytype_to_type_strings(Union.__getitem__(constraints))  # pylint: disable=unnecessary-dunder-call
        bound = pytype.__bound__
        if bound is None:
            return list(ALL_TENSOR_TYPE_STRINGS)
        return pytype_to_type_strings(bound)
    if typing.get_origin(pytype) is Union:
        options = []
        subtypes = typing.get_args(pytype)
        # A None type in a Union is equivalent to an optional type
        optional = is_optional(pytype)
        for subtype in subtypes:
            if subtype is type(None):
                # Skip the None type because it is handled via is_optional
                continue
            if optional:
                options += [
                    *pytype_to_type_strings(subtype),
                    *[f"optional({s})" for s in pytype_to_type_strings(subtype)],
                ]
            else:
                options += pytype_to_type_strings(subtype)
        # Remove duplicates
        return sorted(set(options))
    if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
        subtypes = typing.get_args(pytype)
        return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])]

    raise ValueError(f"Unsupported type: {pytype}")
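
# Sketch (exact spellings come from each type's to_string(); "tensor(...)",
# "seq(...)" and "optional(...)" follow the conventions used above):
#   pytype_to_type_strings(onnx_types.INT32)            # -> ["tensor(int32)"]
#   pytype_to_type_strings(Optional[onnx_types.INT32])  # -> ["optional(tensor(int32))",
#                                                       #     "tensor(int32)"]
#   pytype_to_type_strings(Sequence[onnx_types.FLOAT])  # -> ["seq(tensor(float))"]
#   pytype_to_type_strings(None)                        # -> list(ALL_TENSOR_TYPE_STRINGS)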

def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]:
    """Returns the name of the type constraint for a given type annotation.

    Args:
        pytype: A type annotation.

    Returns:
        The name of the type constraint if it is a TypeVar.
        - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar].
        - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
        - Returns None if the type annotation does not have a type constraint.
    """
    if isinstance(pytype, typing.TypeVar):
        return pytype.__name__
    if is_optional(pytype):
        subtypes = typing.get_args(pytype)
        for subtype in subtypes:
            if subtype is type(None):
                continue
            type_param_name = get_type_constraint_name(subtype)
            return f"Optional_{type_param_name}" if type_param_name else None
    if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
        subtypes = typing.get_args(pytype)
        type_param_name = get_type_constraint_name(subtypes[0])
        return f"Sequence_{type_param_name}" if type_param_name else None
    return None
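
# Sketch, with a hypothetical TypeVar T bound to tensor types:
#   T = typing.TypeVar("T", bound=onnx_types.TensorType)
#   get_type_constraint_name(T)                   # -> "T"
#   get_type_constraint_name(Optional[T])         # -> "Optional_T"
#   get_type_constraint_name(typing.Sequence[T])  # -> "Sequence_T"
#   get_type_constraint_name(int)                 # -> None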